Commit f3cb3f2e authored by Esad Akar's avatar Esad Akar Committed by GitHub

api,file: remove data length arg from pipeline builder (#1671)

parent e0765db3
......@@ -72,7 +72,6 @@ var (
errNoResolver = errors.New("no resolver connected")
errInvalidRequest = errors.New("could not validate request")
errInvalidContentType = errors.New("invalid content-type")
errInvalidContentLength = errors.New("invalid content-length")
errDirectoryStore = errors.New("could not store directory")
errFileStore = errors.New("could not store file")
errInvalidPostageBatch = errors.New("invalid postage batch id")
......@@ -339,13 +338,13 @@ func (p *stamperPutter) Put(ctx context.Context, mode storage.ModePut, chs ...sw
return p.Storer.Put(ctx, mode, chs...)
}
type pipelineFunc func(context.Context, io.Reader, int64) (swarm.Address, error)
type pipelineFunc func(context.Context, io.Reader) (swarm.Address, error)
func requestPipelineFn(s storage.Putter, r *http.Request) pipelineFunc {
mode, encrypt := requestModePut(r), requestEncrypt(r)
return func(ctx context.Context, r io.Reader, l int64) (swarm.Address, error) {
return func(ctx context.Context, r io.Reader) (swarm.Address, error) {
pipe := builder.NewPipelineBuilder(ctx, s, mode, encrypt)
return builder.FeedPipeline(ctx, pipe, r, l)
return builder.FeedPipeline(ctx, pipe, r)
}
}
......
......@@ -66,7 +66,7 @@ func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) {
}
p := requestPipelineFn(putter, r)
address, err := p(ctx, r.Body, r.ContentLength)
address, err := p(ctx, r.Body)
if err != nil {
logger.Debugf("bytes upload: split write all: %v", err)
logger.Error("bytes upload: split write all")
......
......@@ -10,12 +10,9 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"mime"
"net/http"
"os"
"path"
"strconv"
"strings"
"time"
......@@ -81,9 +78,8 @@ type bzzUploadResponse struct {
func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request, storer storage.Storer) {
logger := tracing.NewLoggerWithTraceID(r.Context(), s.logger)
var (
reader io.Reader
fileName, contentLength string
fileSize uint64
reader io.Reader
fileName string
)
// Content-Type has already been validated by this time
......@@ -114,49 +110,12 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request, store
ctx := sctx.SetTag(r.Context(), tag)
fileName = r.URL.Query().Get("name")
contentLength = r.Header.Get("Content-Length")
reader = r.Body
if contentLength != "" {
fileSize, err = strconv.ParseUint(contentLength, 10, 64)
if err != nil {
logger.Debugf("bzz upload file: content length, file %q: %v", fileName, err)
logger.Errorf("bzz upload file: content length, file %q", fileName)
jsonhttp.BadRequest(w, errInvalidContentLength)
return
}
} else {
// copy the part to a tmp file to get its size
tmp, err := ioutil.TempFile("", "bee-multipart")
if err != nil {
logger.Debugf("bzz upload file: create temporary file: %v", err)
logger.Errorf("bzz upload file: create temporary file")
jsonhttp.InternalServerError(w, nil)
return
}
defer os.Remove(tmp.Name())
defer tmp.Close()
n, err := io.Copy(tmp, reader)
if err != nil {
logger.Debugf("bzz upload file: write temporary file: %v", err)
logger.Error("bzz upload file: write temporary file")
jsonhttp.InternalServerError(w, nil)
return
}
if _, err := tmp.Seek(0, io.SeekStart); err != nil {
logger.Debugf("bzz upload file: seek to beginning of temporary file: %v", err)
logger.Error("bzz upload file: seek to beginning of temporary file")
jsonhttp.InternalServerError(w, nil)
return
}
fileSize = uint64(n)
reader = tmp
}
p := requestPipelineFn(storer, r)
// first store the file and get its reference
fr, err := p(ctx, reader, int64(fileSize))
fr, err := p(ctx, reader)
if err != nil {
logger.Debugf("bzz upload file: file store, file %q: %v", fileName, err)
logger.Errorf("bzz upload file: file store, file %q", fileName)
......
......@@ -161,7 +161,7 @@ func storeDir(
}
}
fileReference, err := p(ctx, fileInfo.Reader, fileInfo.Size)
fileReference, err := p(ctx, fileInfo.Reader)
if err != nil {
return swarm.ZeroAddress, fmt.Errorf("store dir file: %w", err)
}
......
......@@ -52,7 +52,7 @@ func testSplitThenJoin(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dataReader := file.NewSimpleReadCloser(data)
resultAddress, err := builder.FeedPipeline(ctx, p, dataReader, int64(len(data)))
resultAddress, err := builder.FeedPipeline(ctx, p, dataReader)
if err != nil {
t.Fatal(err)
}
......
......@@ -187,7 +187,7 @@ func TestEncryptDecrypt(t *testing.T) {
ctx := context.Background()
pipe := builder.NewPipelineBuilder(ctx, store, storage.ModePutUpload, true)
testDataReader := bytes.NewReader(testData)
resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader, int64(len(testData)))
resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader)
if err != nil {
t.Fatal(err)
}
......
......@@ -51,7 +51,7 @@ func (ls *loadSave) Load(ctx context.Context, ref []byte) ([]byte, error) {
func (ls *loadSave) Save(ctx context.Context, data []byte) ([]byte, error) {
pipe := builder.NewPipelineBuilder(ctx, ls.storer, ls.mode, ls.encrypted)
address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
if err != nil {
return swarm.ZeroAddress.Bytes(), err
}
......
......@@ -72,17 +72,12 @@ func newShortEncryptionPipelineFunc(ctx context.Context, s storage.Putter, mode
// FeedPipeline feeds the pipeline with the given reader until EOF is reached.
// It returns the cryptographic root hash of the content.
func FeedPipeline(ctx context.Context, pipeline pipeline.Interface, r io.Reader, dataLength int64) (addr swarm.Address, err error) {
var total int64
func FeedPipeline(ctx context.Context, pipeline pipeline.Interface, r io.Reader) (addr swarm.Address, err error) {
data := make([]byte, swarm.ChunkSize)
for {
c, err := r.Read(data)
total += int64(c)
if err != nil {
if err == io.EOF {
if total < dataLength {
return swarm.ZeroAddress, fmt.Errorf("pipline short write: read %d out of %d bytes", total, dataLength)
}
if c > 0 {
cc, err := pipeline.Write(data[:c])
if err != nil {
......
......@@ -32,7 +32,7 @@ func TestPinningService(t *testing.T) {
)
pipe := builder.NewPipelineBuilder(ctx, storerMock, storage.ModePutUpload, false)
ref, err := builder.FeedPipeline(ctx, pipe, strings.NewReader(content), int64(len(content)))
ref, err := builder.FeedPipeline(ctx, pipe, strings.NewReader(content))
if err != nil {
t.Fatal(err)
}
......
......@@ -156,7 +156,7 @@ func TestTraversalBytes(t *testing.T) {
defer cancel()
pipe := builder.NewPipelineBuilder(ctx, storerMock, storage.ModePutUpload, false)
address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
if err != nil {
t.Fatal(err)
}
......@@ -246,7 +246,7 @@ func TestTraversalFiles(t *testing.T) {
defer cancel()
pipe := builder.NewPipelineBuilder(ctx, storerMock, storage.ModePutUpload, false)
fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
if err != nil {
t.Fatal(err)
}
......@@ -415,7 +415,7 @@ func TestTraversalManifest(t *testing.T) {
data := generateSample(f.size)
pipe := builder.NewPipelineBuilder(ctx, storerMock, storage.ModePutUpload, false)
fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
if err != nil {
t.Fatal(err)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment