Commit f3cb3f2e authored by Esad Akar, committed by GitHub

api,file: remove data length arg from pipeline builder (#1671)

parent e0765db3
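The commit drops the data-length argument everywhere: pipelineFunc and builder.FeedPipeline now take only a reader and consume it until EOF. A minimal caller-side sketch of the new shape follows; the mock storer and sample content are illustrative, not part of this commit:

// Sketch of the new call shape after this commit, assuming the in-memory
// mock storer from the bee test packages; names here are illustrative.
package main

import (
	"context"
	"fmt"
	"strings"

	"github.com/ethersphere/bee/pkg/file/pipeline/builder"
	"github.com/ethersphere/bee/pkg/storage"
	"github.com/ethersphere/bee/pkg/storage/mock"
)

func main() {
	ctx := context.Background()
	storer := mock.NewStorer()

	// Build an unencrypted upload pipeline over the storer.
	pipe := builder.NewPipelineBuilder(ctx, storer, storage.ModePutUpload, false)

	// FeedPipeline now takes only a reader: no up-front length is needed;
	// the pipeline reads until EOF and returns the content's root hash.
	addr, err := builder.FeedPipeline(ctx, pipe, strings.NewReader("hello swarm"))
	if err != nil {
		panic(err)
	}
	fmt.Println("root hash:", addr)
}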
@@ -72,7 +72,6 @@ var (
 	errNoResolver = errors.New("no resolver connected")
 	errInvalidRequest = errors.New("could not validate request")
 	errInvalidContentType = errors.New("invalid content-type")
-	errInvalidContentLength = errors.New("invalid content-length")
 	errDirectoryStore = errors.New("could not store directory")
 	errFileStore = errors.New("could not store file")
 	errInvalidPostageBatch = errors.New("invalid postage batch id")
@@ -339,13 +338,13 @@ func (p *stamperPutter) Put(ctx context.Context, mode storage.ModePut, chs ...sw
 	return p.Storer.Put(ctx, mode, chs...)
 }
 
-type pipelineFunc func(context.Context, io.Reader, int64) (swarm.Address, error)
+type pipelineFunc func(context.Context, io.Reader) (swarm.Address, error)
 
 func requestPipelineFn(s storage.Putter, r *http.Request) pipelineFunc {
 	mode, encrypt := requestModePut(r), requestEncrypt(r)
-	return func(ctx context.Context, r io.Reader, l int64) (swarm.Address, error) {
+	return func(ctx context.Context, r io.Reader) (swarm.Address, error) {
 		pipe := builder.NewPipelineBuilder(ctx, s, mode, encrypt)
-		return builder.FeedPipeline(ctx, pipe, r, l)
+		return builder.FeedPipeline(ctx, pipe, r)
 	}
 }
...
@@ -66,7 +66,7 @@ func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) {
 	}
 
 	p := requestPipelineFn(putter, r)
-	address, err := p(ctx, r.Body, r.ContentLength)
+	address, err := p(ctx, r.Body)
 	if err != nil {
 		logger.Debugf("bytes upload: split write all: %v", err)
 		logger.Error("bytes upload: split write all")
...
@@ -10,12 +10,9 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"mime"
 	"net/http"
-	"os"
 	"path"
-	"strconv"
 	"strings"
 	"time"
@@ -82,8 +79,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request, store
 	logger := tracing.NewLoggerWithTraceID(r.Context(), s.logger)
 	var (
 		reader io.Reader
-		fileName, contentLength string
-		fileSize uint64
+		fileName string
 	)
 
 	// Content-Type has already been validated by this time
@@ -114,49 +110,12 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request, store
 	ctx := sctx.SetTag(r.Context(), tag)
 	fileName = r.URL.Query().Get("name")
-	contentLength = r.Header.Get("Content-Length")
 	reader = r.Body
 
-	if contentLength != "" {
-		fileSize, err = strconv.ParseUint(contentLength, 10, 64)
-		if err != nil {
-			logger.Debugf("bzz upload file: content length, file %q: %v", fileName, err)
-			logger.Errorf("bzz upload file: content length, file %q", fileName)
-			jsonhttp.BadRequest(w, errInvalidContentLength)
-			return
-		}
-	} else {
-		// copy the part to a tmp file to get its size
-		tmp, err := ioutil.TempFile("", "bee-multipart")
-		if err != nil {
-			logger.Debugf("bzz upload file: create temporary file: %v", err)
-			logger.Errorf("bzz upload file: create temporary file")
-			jsonhttp.InternalServerError(w, nil)
-			return
-		}
-		defer os.Remove(tmp.Name())
-		defer tmp.Close()
-		n, err := io.Copy(tmp, reader)
-		if err != nil {
-			logger.Debugf("bzz upload file: write temporary file: %v", err)
-			logger.Error("bzz upload file: write temporary file")
-			jsonhttp.InternalServerError(w, nil)
-			return
-		}
-		if _, err := tmp.Seek(0, io.SeekStart); err != nil {
-			logger.Debugf("bzz upload file: seek to beginning of temporary file: %v", err)
-			logger.Error("bzz upload file: seek to beginning of temporary file")
-			jsonhttp.InternalServerError(w, nil)
-			return
-		}
-		fileSize = uint64(n)
-		reader = tmp
-	}
-
 	p := requestPipelineFn(storer, r)
 
 	// first store the file and get its reference
-	fr, err := p(ctx, reader, int64(fileSize))
+	fr, err := p(ctx, reader)
 	if err != nil {
 		logger.Debugf("bzz upload file: file store, file %q: %v", fileName, err)
 		logger.Errorf("bzz upload file: file store, file %q", fileName)
...
@@ -161,7 +161,7 @@ func storeDir(
 		}
 	}
 
-	fileReference, err := p(ctx, fileInfo.Reader, fileInfo.Size)
+	fileReference, err := p(ctx, fileInfo.Reader)
 	if err != nil {
 		return swarm.ZeroAddress, fmt.Errorf("store dir file: %w", err)
 	}
...
@@ -52,7 +52,7 @@ func testSplitThenJoin(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	dataReader := file.NewSimpleReadCloser(data)
-	resultAddress, err := builder.FeedPipeline(ctx, p, dataReader, int64(len(data)))
+	resultAddress, err := builder.FeedPipeline(ctx, p, dataReader)
 	if err != nil {
 		t.Fatal(err)
 	}
...
@@ -187,7 +187,7 @@ func TestEncryptDecrypt(t *testing.T) {
 	ctx := context.Background()
 	pipe := builder.NewPipelineBuilder(ctx, store, storage.ModePutUpload, true)
 	testDataReader := bytes.NewReader(testData)
-	resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader, int64(len(testData)))
+	resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader)
 	if err != nil {
 		t.Fatal(err)
 	}
...
@@ -51,7 +51,7 @@ func (ls *loadSave) Load(ctx context.Context, ref []byte) ([]byte, error) {
 
 func (ls *loadSave) Save(ctx context.Context, data []byte) ([]byte, error) {
 	pipe := builder.NewPipelineBuilder(ctx, ls.storer, ls.mode, ls.encrypted)
-	address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
+	address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
 	if err != nil {
 		return swarm.ZeroAddress.Bytes(), err
 	}
...
@@ -72,17 +72,12 @@ func newShortEncryptionPipelineFunc(ctx context.Context, s storage.Putter, mode
 
 // FeedPipeline feeds the pipeline with the given reader until EOF is reached.
 // It returns the cryptographic root hash of the content.
-func FeedPipeline(ctx context.Context, pipeline pipeline.Interface, r io.Reader, dataLength int64) (addr swarm.Address, err error) {
+func FeedPipeline(ctx context.Context, pipeline pipeline.Interface, r io.Reader) (addr swarm.Address, err error) {
-	var total int64
 	data := make([]byte, swarm.ChunkSize)
 	for {
 		c, err := r.Read(data)
-		total += int64(c)
 		if err != nil {
 			if err == io.EOF {
-				if total < dataLength {
-					return swarm.ZeroAddress, fmt.Errorf("pipline short write: read %d out of %d bytes", total, dataLength)
-				}
 				if c > 0 {
 					cc, err := pipeline.Write(data[:c])
 					if err != nil {
...
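Because FeedPipeline now reads until io.EOF instead of comparing against a declared length (the short-write check goes away with the argument), it can consume streams whose size is unknown when the upload starts. A sketch under that assumption, using io.Pipe as a stand-in for such a stream; the uploadStream helper is hypothetical:

// Hypothetical example file; not part of this commit.
package main

import (
	"context"
	"io"

	"github.com/ethersphere/bee/pkg/file/pipeline/builder"
	"github.com/ethersphere/bee/pkg/storage"
	"github.com/ethersphere/bee/pkg/swarm"
)

// uploadStream feeds the pipeline from a stream whose total size is not
// known up front, which the old length-checked signature could not express.
func uploadStream(ctx context.Context, storer storage.Putter) (swarm.Address, error) {
	pr, pw := io.Pipe()
	go func() {
		// Closing the write end delivers io.EOF to FeedPipeline's read loop.
		defer pw.Close()
		pw.Write([]byte("streamed content of unknown size"))
	}()

	pipe := builder.NewPipelineBuilder(ctx, storer, storage.ModePutUpload, false)
	return builder.FeedPipeline(ctx, pipe, pr)
}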
@@ -32,7 +32,7 @@ func TestPinningService(t *testing.T) {
 	)
 
 	pipe := builder.NewPipelineBuilder(ctx, storerMock, storage.ModePutUpload, false)
-	ref, err := builder.FeedPipeline(ctx, pipe, strings.NewReader(content), int64(len(content)))
+	ref, err := builder.FeedPipeline(ctx, pipe, strings.NewReader(content))
 	if err != nil {
 		t.Fatal(err)
 	}
...
@@ -156,7 +156,7 @@ func TestTraversalBytes(t *testing.T) {
 	defer cancel()
 
 	pipe := builder.NewPipelineBuilder(ctx, storerMock, storage.ModePutUpload, false)
-	address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
+	address, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -246,7 +246,7 @@ func TestTraversalFiles(t *testing.T) {
 	defer cancel()
 
 	pipe := builder.NewPipelineBuilder(ctx, storerMock, storage.ModePutUpload, false)
-	fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
+	fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -415,7 +415,7 @@ func TestTraversalManifest(t *testing.T) {
 	data := generateSample(f.size)
 
 	pipe := builder.NewPipelineBuilder(ctx, storerMock, storage.ModePutUpload, false)
-	fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
+	fr, err := builder.FeedPipeline(ctx, pipe, bytes.NewReader(data))
 	if err != nil {
 		t.Fatal(err)
 	}
...