Commit ac5b402c authored by acud's avatar acud Committed by GitHub

file: seekjoiner cleanup (#847)

parent 99af2c62
...@@ -17,7 +17,7 @@ import ( ...@@ -17,7 +17,7 @@ import (
cmdfile "github.com/ethersphere/bee/cmd/internal/file" cmdfile "github.com/ethersphere/bee/cmd/internal/file"
"github.com/ethersphere/bee/pkg/collection/entry" "github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/seekjoiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter" "github.com/ethersphere/bee/pkg/file/splitter"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -59,8 +59,12 @@ func getEntry(cmd *cobra.Command, args []string) (err error) { ...@@ -59,8 +59,12 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
writeCloser := cmdfile.NopWriteCloser(buf) writeCloser := cmdfile.NopWriteCloser(buf)
limitBuf := cmdfile.NewLimitWriteCloser(writeCloser, limitMetadataLength) limitBuf := cmdfile.NewLimitWriteCloser(writeCloser, limitMetadataLength)
j := seekjoiner.NewSimpleJoiner(store) j, _, err := joiner.New(cmd.Context(), store, addr)
_, err = file.JoinReadAll(cmd.Context(), j, addr, limitBuf) if err != nil {
return err
}
_, err = file.JoinReadAll(cmd.Context(), j, limitBuf)
if err != nil { if err != nil {
return err return err
} }
...@@ -70,8 +74,14 @@ func getEntry(cmd *cobra.Command, args []string) (err error) { ...@@ -70,8 +74,14 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
return err return err
} }
j, _, err = joiner.New(cmd.Context(), store, e.Metadata())
if err != nil {
return err
}
buf = bytes.NewBuffer(nil) buf = bytes.NewBuffer(nil)
_, err = file.JoinReadAll(cmd.Context(), j, e.Metadata(), buf)
_, err = file.JoinReadAll(cmd.Context(), j, buf)
if err != nil { if err != nil {
return err return err
} }
...@@ -117,7 +127,13 @@ func getEntry(cmd *cobra.Command, args []string) (err error) { ...@@ -117,7 +127,13 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
return err return err
} }
defer outFile.Close() defer outFile.Close()
_, err = file.JoinReadAll(cmd.Context(), j, e.Reference(), outFile)
j, _, err = joiner.New(cmd.Context(), store, e.Reference())
if err != nil {
return err
}
_, err = file.JoinReadAll(cmd.Context(), j, outFile)
return err return err
} }
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
cmdfile "github.com/ethersphere/bee/cmd/internal/file" cmdfile "github.com/ethersphere/bee/cmd/internal/file"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/seekjoiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -82,8 +82,11 @@ func Join(cmd *cobra.Command, args []string) (err error) { ...@@ -82,8 +82,11 @@ func Join(cmd *cobra.Command, args []string) (err error) {
} }
// create the join and get its data reader // create the join and get its data reader
j := seekjoiner.NewSimpleJoiner(store) j, _, err := joiner.New(cmd.Context(), store, addr)
_, err = file.JoinReadAll(cmd.Context(), j, addr, outFile) if err != nil {
return err
}
_, err = file.JoinReadAll(cmd.Context(), j, outFile)
return err return err
} }
......
...@@ -17,7 +17,7 @@ import ( ...@@ -17,7 +17,7 @@ import (
"github.com/ethersphere/bee/pkg/collection/entry" "github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/seekjoiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/manifest" "github.com/ethersphere/bee/pkg/manifest"
"github.com/ethersphere/bee/pkg/sctx" "github.com/ethersphere/bee/pkg/sctx"
...@@ -53,9 +53,16 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -53,9 +53,16 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
toDecrypt := len(address.Bytes()) == 64 toDecrypt := len(address.Bytes()) == 64
// read manifest entry // read manifest entry
j := seekjoiner.NewSimpleJoiner(s.Storer) j, _, err := joiner.New(ctx, s.Storer, address)
if err != nil {
logger.Debugf("bzz download: joiner manifest entry %s: %v", address, err)
logger.Errorf("bzz download: joiner %s", address)
jsonhttp.NotFound(w, nil)
return
}
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
_, err = file.JoinReadAll(ctx, j, address, buf) _, err = file.JoinReadAll(ctx, j, buf)
if err != nil { if err != nil {
logger.Debugf("bzz download: read entry %s: %v", address, err) logger.Debugf("bzz download: read entry %s: %v", address, err)
logger.Errorf("bzz download: read entry %s", address) logger.Errorf("bzz download: read entry %s", address)
...@@ -71,9 +78,18 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -71,9 +78,18 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
// read metadata
j, _, err = joiner.New(ctx, s.Storer, e.Metadata())
if err != nil {
logger.Debugf("bzz download: joiner metadata %s: %v", address, err)
logger.Errorf("bzz download: joiner %s", address)
jsonhttp.NotFound(w, nil)
return
}
// read metadata // read metadata
buf = bytes.NewBuffer(nil) buf = bytes.NewBuffer(nil)
_, err = file.JoinReadAll(ctx, j, e.Metadata(), buf) _, err = file.JoinReadAll(ctx, j, buf)
if err != nil { if err != nil {
logger.Debugf("bzz download: read metadata %s: %v", address, err) logger.Debugf("bzz download: read metadata %s: %v", address, err)
logger.Errorf("bzz download: read metadata %s", address) logger.Errorf("bzz download: read metadata %s", address)
...@@ -114,7 +130,7 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -114,7 +130,7 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
// index document exists // index document exists
logger.Debugf("bzz download: serving path: %s", pathWithIndex) logger.Debugf("bzz download: serving path: %s", pathWithIndex)
s.serveManifestEntry(w, r, j, address, indexDocumentManifestEntry.Reference()) s.serveManifestEntry(w, r, address, indexDocumentManifestEntry.Reference())
return return
} }
} }
...@@ -154,7 +170,7 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -154,7 +170,7 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
// index document exists // index document exists
logger.Debugf("bzz download: serving path: %s", pathWithIndex) logger.Debugf("bzz download: serving path: %s", pathWithIndex)
s.serveManifestEntry(w, r, j, address, indexDocumentManifestEntry.Reference()) s.serveManifestEntry(w, r, address, indexDocumentManifestEntry.Reference())
return return
} }
} }
...@@ -168,7 +184,7 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -168,7 +184,7 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
// error document exists // error document exists
logger.Debugf("bzz download: serving path: %s", errorDocumentPath) logger.Debugf("bzz download: serving path: %s", errorDocumentPath)
s.serveManifestEntry(w, r, j, address, errorDocumentManifestEntry.Reference()) s.serveManifestEntry(w, r, address, errorDocumentManifestEntry.Reference())
return return
} }
} }
...@@ -182,16 +198,26 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -182,16 +198,26 @@ func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
} }
// serve requested path // serve requested path
s.serveManifestEntry(w, r, j, address, me.Reference()) s.serveManifestEntry(w, r, address, me.Reference())
} }
func (s *server) serveManifestEntry(w http.ResponseWriter, r *http.Request, j file.JoinSeeker, address, manifestEntryAddress swarm.Address) { func (s *server) serveManifestEntry(w http.ResponseWriter, r *http.Request, address, manifestEntryAddress swarm.Address) {
logger := tracing.NewLoggerWithTraceID(r.Context(), s.Logger) var (
ctx := r.Context() logger = tracing.NewLoggerWithTraceID(r.Context(), s.Logger)
ctx = r.Context()
buf = bytes.NewBuffer(nil)
)
// read file entry // read file entry
buf := bytes.NewBuffer(nil) j, _, err := joiner.New(ctx, s.Storer, manifestEntryAddress)
_, err := file.JoinReadAll(ctx, j, manifestEntryAddress, buf) if err != nil {
logger.Debugf("bzz download: joiner read file entry %s: %v", address, err)
logger.Errorf("bzz download: joiner read file entry %s", address)
jsonhttp.NotFound(w, nil)
return
}
_, err = file.JoinReadAll(ctx, j, buf)
if err != nil { if err != nil {
logger.Debugf("bzz download: read file entry %s: %v", address, err) logger.Debugf("bzz download: read file entry %s: %v", address, err)
logger.Errorf("bzz download: read file entry %s", address) logger.Errorf("bzz download: read file entry %s", address)
...@@ -208,8 +234,16 @@ func (s *server) serveManifestEntry(w http.ResponseWriter, r *http.Request, j fi ...@@ -208,8 +234,16 @@ func (s *server) serveManifestEntry(w http.ResponseWriter, r *http.Request, j fi
} }
// read file metadata // read file metadata
j, _, err = joiner.New(ctx, s.Storer, fe.Metadata())
if err != nil {
logger.Debugf("bzz download: joiner read file entry %s: %v", address, err)
logger.Errorf("bzz download: joiner read file entry %s", address)
jsonhttp.NotFound(w, nil)
return
}
buf = bytes.NewBuffer(nil) buf = bytes.NewBuffer(nil)
_, err = file.JoinReadAll(ctx, j, fe.Metadata(), buf) _, err = file.JoinReadAll(ctx, j, buf)
if err != nil { if err != nil {
logger.Debugf("bzz download: read file metadata %s: %v", address, err) logger.Debugf("bzz download: read file metadata %s: %v", address, err)
logger.Errorf("bzz download: read file metadata %s", address) logger.Errorf("bzz download: read file metadata %s", address)
......
...@@ -18,7 +18,7 @@ import ( ...@@ -18,7 +18,7 @@ import (
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/collection/entry" "github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/seekjoiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
...@@ -261,10 +261,13 @@ func TestDirs(t *testing.T) { ...@@ -261,10 +261,13 @@ func TestDirs(t *testing.T) {
} }
// read manifest metadata // read manifest metadata
j := seekjoiner.NewSimpleJoiner(storer) j, _, err := joiner.New(context.Background(), storer, resp.Reference)
if err != nil {
t.Fatal(err)
}
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
_, err = file.JoinReadAll(context.Background(), j, resp.Reference, buf) _, err = file.JoinReadAll(context.Background(), j, buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -21,8 +21,8 @@ import ( ...@@ -21,8 +21,8 @@ import (
"github.com/ethersphere/bee/pkg/collection/entry" "github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/pipeline/builder" "github.com/ethersphere/bee/pkg/file/pipeline/builder"
"github.com/ethersphere/bee/pkg/file/seekjoiner"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/sctx" "github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -246,10 +246,17 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -246,10 +246,17 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(sctx.SetTargets(r.Context(), targets)) r = r.WithContext(sctx.SetTargets(r.Context(), targets))
} }
// read entry. // read entry
j := seekjoiner.NewSimpleJoiner(s.Storer) j, _, err := joiner.New(r.Context(), s.Storer, address)
if err != nil {
logger.Debugf("file download: joiner %s: %v", address, err)
logger.Errorf("file download: joiner %s", address)
jsonhttp.NotFound(w, nil)
return
}
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
_, err = file.JoinReadAll(r.Context(), j, address, buf) _, err = file.JoinReadAll(r.Context(), j, buf)
if err != nil { if err != nil {
logger.Debugf("file download: read entry %s: %v", address, err) logger.Debugf("file download: read entry %s: %v", address, err)
logger.Errorf("file download: read entry %s", address) logger.Errorf("file download: read entry %s", address)
...@@ -275,9 +282,17 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -275,9 +282,17 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
// Read metadata. // read metadata
j, _, err = joiner.New(r.Context(), s.Storer, e.Metadata())
if err != nil {
logger.Debugf("file download: joiner %s: %v", address, err)
logger.Errorf("file download: joiner %s", address)
jsonhttp.NotFound(w, nil)
return
}
buf = bytes.NewBuffer(nil) buf = bytes.NewBuffer(nil)
_, err = file.JoinReadAll(r.Context(), j, e.Metadata(), buf) _, err = file.JoinReadAll(r.Context(), j, buf)
if err != nil { if err != nil {
logger.Debugf("file download: read metadata %s: %v", nameOrHex, err) logger.Debugf("file download: read metadata %s: %v", nameOrHex, err)
logger.Errorf("file download: read metadata %s", nameOrHex) logger.Errorf("file download: read metadata %s", nameOrHex)
...@@ -309,8 +324,7 @@ func (s *server) downloadHandler(w http.ResponseWriter, r *http.Request, referen ...@@ -309,8 +324,7 @@ func (s *server) downloadHandler(w http.ResponseWriter, r *http.Request, referen
r = r.WithContext(sctx.SetTargets(r.Context(), targets)) r = r.WithContext(sctx.SetTargets(r.Context(), targets))
} }
rs := seekjoiner.NewSimpleJoiner(s.Storer) reader, l, err := joiner.New(r.Context(), s.Storer, reference)
reader, l, err := rs.Join(r.Context(), reference)
if err != nil { if err != nil {
if errors.Is(err, storage.ErrNotFound) { if errors.Is(err, storage.ErrNotFound) {
logger.Debugf("api download: not found %s: %v", reference, err) logger.Debugf("api download: not found %s: %v", reference, err)
......
...@@ -7,8 +7,6 @@ package file ...@@ -7,8 +7,6 @@ package file
import ( import (
"context" "context"
"errors"
"fmt"
"io" "io"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -19,10 +17,11 @@ type Reader interface { ...@@ -19,10 +17,11 @@ type Reader interface {
io.ReaderAt io.ReaderAt
} }
// JoinSeeker provides a Joiner that can seek. // Joiner provides the inverse functionality of the Splitter.
type JoinSeeker interface { type Joiner interface {
Join(ctx context.Context, address swarm.Address) (dataOut Reader, dataLength int64, err error) Reader
Size(ctx context.Context, address swarm.Address) (dataLength int64, err error) // Size returns the span of the hash trie represented by the joiner's root hash.
Size() int64
} }
// Splitter starts a new file splitting job. // Splitter starts a new file splitting job.
...@@ -33,68 +32,3 @@ type JoinSeeker interface { ...@@ -33,68 +32,3 @@ type JoinSeeker interface {
type Splitter interface { type Splitter interface {
Split(ctx context.Context, dataIn io.ReadCloser, dataLength int64, toEncrypt bool) (addr swarm.Address, err error) Split(ctx context.Context, dataIn io.ReadCloser, dataLength int64, toEncrypt bool) (addr swarm.Address, err error)
} }
// JoinReadAll reads all output from the provided SeekJoiner.
func JoinReadAll(ctx context.Context, j JoinSeeker, addr swarm.Address, outFile io.Writer) (int64, error) {
r, l, err := j.Join(ctx, addr)
if err != nil {
return 0, err
}
// join, rinse, repeat until done
data := make([]byte, swarm.ChunkSize)
var total int64
for i := int64(0); i < l; i += swarm.ChunkSize {
cr, err := r.Read(data)
if err != nil {
return total, err
}
total += int64(cr)
cw, err := outFile.Write(data[:cr])
if err != nil {
return total, err
}
if cw != cr {
return total, fmt.Errorf("short wrote %d of %d for chunk %d", cw, cr, i)
}
}
if total != l {
return total, fmt.Errorf("received only %d of %d total bytes", total, l)
}
return total, nil
}
// SplitWriteAll writes all input from provided reader to the provided splitter
func SplitWriteAll(ctx context.Context, s Splitter, r io.Reader, l int64, toEncrypt bool) (swarm.Address, error) {
chunkPipe := NewChunkPipe()
errC := make(chan error)
go func() {
buf := make([]byte, swarm.ChunkSize)
c, err := io.CopyBuffer(chunkPipe, r, buf)
if err != nil {
errC <- err
}
if c != l {
errC <- errors.New("read count mismatch")
}
err = chunkPipe.Close()
if err != nil {
errC <- err
}
close(errC)
}()
addr, err := s.Split(ctx, chunkPipe, l, toEncrypt)
if err != nil {
return swarm.ZeroAddress, err
}
select {
case err := <-errC:
if err != nil {
return swarm.ZeroAddress, err
}
case <-ctx.Done():
return swarm.ZeroAddress, ctx.Err()
}
return addr, nil
}
...@@ -13,8 +13,8 @@ import ( ...@@ -13,8 +13,8 @@ import (
"testing" "testing"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/pipeline/builder" "github.com/ethersphere/bee/pkg/file/pipeline/builder"
"github.com/ethersphere/bee/pkg/file/seekjoiner"
test "github.com/ethersphere/bee/pkg/file/testing" test "github.com/ethersphere/bee/pkg/file/testing"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
...@@ -45,7 +45,6 @@ func testSplitThenJoin(t *testing.T) { ...@@ -45,7 +45,6 @@ func testSplitThenJoin(t *testing.T) {
dataIdx, _ = strconv.ParseInt(paramstring[1], 10, 0) dataIdx, _ = strconv.ParseInt(paramstring[1], 10, 0)
store = mock.NewStorer() store = mock.NewStorer()
p = builder.NewPipelineBuilder(context.Background(), store, storage.ModePutUpload, false) p = builder.NewPipelineBuilder(context.Background(), store, storage.ModePutUpload, false)
j = seekjoiner.NewSimpleJoiner(store)
data, _ = test.GetVector(t, int(dataIdx)) data, _ = test.GetVector(t, int(dataIdx))
) )
...@@ -59,7 +58,7 @@ func testSplitThenJoin(t *testing.T) { ...@@ -59,7 +58,7 @@ func testSplitThenJoin(t *testing.T) {
} }
// then join // then join
r, l, err := j.Join(ctx, resultAddress) r, l, err := joiner.New(ctx, store, resultAddress)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -6,9 +6,13 @@ package file ...@@ -6,9 +6,13 @@ package file
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt"
"io" "io"
"github.com/ethersphere/bee/pkg/swarm"
) )
// simpleReadCloser wraps a byte slice in a io.ReadCloser implementation. // simpleReadCloser wraps a byte slice in a io.ReadCloser implementation.
...@@ -40,3 +44,66 @@ func (s *simpleReadCloser) Close() error { ...@@ -40,3 +44,66 @@ func (s *simpleReadCloser) Close() error {
s.closed = true s.closed = true
return nil return nil
} }
// JoinReadAll reads all output from the provided Joiner.
func JoinReadAll(ctx context.Context, j Joiner, outFile io.Writer) (int64, error) {
l := j.Size()
// join, rinse, repeat until done
data := make([]byte, swarm.ChunkSize)
var total int64
for i := int64(0); i < l; i += swarm.ChunkSize {
cr, err := j.Read(data)
if err != nil {
return total, err
}
total += int64(cr)
cw, err := outFile.Write(data[:cr])
if err != nil {
return total, err
}
if cw != cr {
return total, fmt.Errorf("short wrote %d of %d for chunk %d", cw, cr, i)
}
}
if total != l {
return total, fmt.Errorf("received only %d of %d total bytes", total, l)
}
return total, nil
}
// SplitWriteAll writes all input from provided reader to the provided splitter
func SplitWriteAll(ctx context.Context, s Splitter, r io.Reader, l int64, toEncrypt bool) (swarm.Address, error) {
chunkPipe := NewChunkPipe()
errC := make(chan error)
go func() {
buf := make([]byte, swarm.ChunkSize)
c, err := io.CopyBuffer(chunkPipe, r, buf)
if err != nil {
errC <- err
}
if c != l {
errC <- errors.New("read count mismatch")
}
err = chunkPipe.Close()
if err != nil {
errC <- err
}
close(errC)
}()
addr, err := s.Split(ctx, chunkPipe, l, toEncrypt)
if err != nil {
return swarm.ZeroAddress, err
}
select {
case err := <-errC:
if err != nil {
return swarm.ZeroAddress, err
}
case <-ctx.Done():
return swarm.ZeroAddress, ctx.Err()
}
return addr, nil
}
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package internal // Package joiner provides implementations of the file.Joiner interface
package joiner
import ( import (
"context" "context"
...@@ -11,12 +12,14 @@ import ( ...@@ -11,12 +12,14 @@ import (
"io" "io"
"sync/atomic" "sync/atomic"
"github.com/ethersphere/bee/pkg/encryption/store"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
type SimpleJoiner struct { type joiner struct {
addr swarm.Address addr swarm.Address
rootData []byte rootData []byte
span int64 span int64
...@@ -27,8 +30,9 @@ type SimpleJoiner struct { ...@@ -27,8 +30,9 @@ type SimpleJoiner struct {
getter storage.Getter getter storage.Getter
} }
// NewSimpleJoiner creates a new SimpleJoiner. // New creates a new Joiner. A Joiner provides Read, Seek and Size functionalities.
func NewSimpleJoiner(ctx context.Context, getter storage.Getter, address swarm.Address) (*SimpleJoiner, int64, error) { func New(ctx context.Context, getter storage.Getter, address swarm.Address) (file.Joiner, int64, error) {
getter = store.New(getter)
// retrieve the root chunk to read the total data length the be retrieved // retrieve the root chunk to read the total data length the be retrieved
rootChunk, err := getter.Get(ctx, storage.ModeGetRequest, address) rootChunk, err := getter.Get(ctx, storage.ModeGetRequest, address)
if err != nil { if err != nil {
...@@ -39,7 +43,7 @@ func NewSimpleJoiner(ctx context.Context, getter storage.Getter, address swarm.A ...@@ -39,7 +43,7 @@ func NewSimpleJoiner(ctx context.Context, getter storage.Getter, address swarm.A
span := int64(binary.LittleEndian.Uint64(chunkData[:swarm.SpanSize])) span := int64(binary.LittleEndian.Uint64(chunkData[:swarm.SpanSize]))
j := &SimpleJoiner{ j := &joiner{
addr: rootChunk.Address(), addr: rootChunk.Address(),
refLength: len(address.Bytes()), refLength: len(address.Bytes()),
ctx: ctx, ctx: ctx,
...@@ -53,7 +57,7 @@ func NewSimpleJoiner(ctx context.Context, getter storage.Getter, address swarm.A ...@@ -53,7 +57,7 @@ func NewSimpleJoiner(ctx context.Context, getter storage.Getter, address swarm.A
// Read is called by the consumer to retrieve the joined data. // Read is called by the consumer to retrieve the joined data.
// It must be called with a buffer equal to the maximum chunk size. // It must be called with a buffer equal to the maximum chunk size.
func (j *SimpleJoiner) Read(b []byte) (n int, err error) { func (j *joiner) Read(b []byte) (n int, err error) {
read, err := j.ReadAt(b, j.off) read, err := j.ReadAt(b, j.off)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return read, err return read, err
...@@ -63,7 +67,7 @@ func (j *SimpleJoiner) Read(b []byte) (n int, err error) { ...@@ -63,7 +67,7 @@ func (j *SimpleJoiner) Read(b []byte) (n int, err error) {
return read, err return read, err
} }
func (j *SimpleJoiner) ReadAt(b []byte, off int64) (read int, err error) { func (j *joiner) ReadAt(b []byte, off int64) (read int, err error) {
// since offset is int64 and swarm spans are uint64 it means we cannot seek beyond int64 max value // since offset is int64 and swarm spans are uint64 it means we cannot seek beyond int64 max value
if off >= j.span { if off >= j.span {
return 0, io.EOF return 0, io.EOF
...@@ -85,7 +89,7 @@ func (j *SimpleJoiner) ReadAt(b []byte, off int64) (read int, err error) { ...@@ -85,7 +89,7 @@ func (j *SimpleJoiner) ReadAt(b []byte, off int64) (read int, err error) {
return int(atomic.LoadInt64(&bytesRead)), nil return int(atomic.LoadInt64(&bytesRead)), nil
} }
func (j *SimpleJoiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffset, bytesToRead int64, bytesRead *int64, eg *errgroup.Group) { func (j *joiner) readAtOffset(b, data []byte, cur, subTrieSize, off, bufferOffset, bytesToRead int64, bytesRead *int64, eg *errgroup.Group) {
// we are at a leaf data chunk // we are at a leaf data chunk
if subTrieSize <= int64(len(data)) { if subTrieSize <= int64(len(data)) {
dataOffsetStart := off - cur dataOffsetStart := off - cur
...@@ -179,7 +183,7 @@ func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64 ...@@ -179,7 +183,7 @@ func subtrieSection(data []byte, startIdx, refLen int, subtrieSize int64) int64
var errWhence = errors.New("seek: invalid whence") var errWhence = errors.New("seek: invalid whence")
var errOffset = errors.New("seek: invalid offset") var errOffset = errors.New("seek: invalid offset")
func (j *SimpleJoiner) Seek(offset int64, whence int) (int64, error) { func (j *joiner) Seek(offset int64, whence int) (int64, error) {
switch whence { switch whence {
case 0: case 0:
offset += 0 offset += 0
...@@ -206,18 +210,8 @@ func (j *SimpleJoiner) Seek(offset int64, whence int) (int64, error) { ...@@ -206,18 +210,8 @@ func (j *SimpleJoiner) Seek(offset int64, whence int) (int64, error) {
} }
func (j *SimpleJoiner) Size() (int64, error) { func (j *joiner) Size() int64 {
if j.rootData == nil { return j.span
chunk, err := j.getter.Get(j.ctx, storage.ModeGetRequest, j.addr)
if err != nil {
return 0, err
}
j.rootData = chunk.Data()
}
s := chunkToSpan(j.rootData)
return int64(s), nil
} }
func chunkToSpan(data []byte) uint64 { func chunkToSpan(data []byte) uint64 {
......
...@@ -2,25 +2,223 @@ ...@@ -2,25 +2,223 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package internal_test package joiner_test
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
mrand "math/rand" mrand "math/rand"
"testing" "testing"
"time" "time"
"github.com/ethersphere/bee/pkg/file/seekjoiner/internal" "github.com/ethersphere/bee/pkg/encryption/store"
"github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/pipeline/builder"
"github.com/ethersphere/bee/pkg/file/splitter" "github.com/ethersphere/bee/pkg/file/splitter"
filetest "github.com/ethersphere/bee/pkg/file/testing" filetest "github.com/ethersphere/bee/pkg/file/testing"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"gitlab.com/nolash/go-mockbytes"
) )
func TestJoiner_ErrReferenceLength(t *testing.T) {
store := mock.NewStorer()
_, _, err := joiner.New(context.Background(), store, swarm.ZeroAddress)
if !errors.Is(err, storage.ErrReferenceLength) {
t.Fatalf("expected ErrReferenceLength %x but got %v", swarm.ZeroAddress, err)
}
}
// TestJoinerSingleChunk verifies that a newly created joiner returns the data stored
// in the store when the reference is one single chunk.
func TestJoinerSingleChunk(t *testing.T) {
store := mock.NewStorer()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// create the chunk to
mockAddrHex := fmt.Sprintf("%064s", "2a")
mockAddr := swarm.MustParseHexAddress(mockAddrHex)
mockData := []byte("foo")
mockDataLengthBytes := make([]byte, 8)
mockDataLengthBytes[0] = 0x03
mockChunk := swarm.NewChunk(mockAddr, append(mockDataLengthBytes, mockData...))
_, err := store.Put(ctx, storage.ModePutUpload, mockChunk)
if err != nil {
t.Fatal(err)
}
// read back data and compare
joinReader, l, err := joiner.New(ctx, store, mockAddr)
if err != nil {
t.Fatal(err)
}
if l != int64(len(mockData)) {
t.Fatalf("expected join data length %d, got %d", len(mockData), l)
}
joinData, err := ioutil.ReadAll(joinReader)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(joinData, mockData) {
t.Fatalf("retrieved data '%x' not like original data '%x'", joinData, mockData)
}
}
// TestJoinerDecryptingStore_NormalChunk verifies the the mock store that uses
// the decrypting store manages to retrieve a normal chunk which is not encrypted
func TestJoinerDecryptingStore_NormalChunk(t *testing.T) {
st := mock.NewStorer()
store := store.New(st)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// create the chunk to
mockAddrHex := fmt.Sprintf("%064s", "2a")
mockAddr := swarm.MustParseHexAddress(mockAddrHex)
mockData := []byte("foo")
mockDataLengthBytes := make([]byte, 8)
mockDataLengthBytes[0] = 0x03
mockChunk := swarm.NewChunk(mockAddr, append(mockDataLengthBytes, mockData...))
_, err := st.Put(ctx, storage.ModePutUpload, mockChunk)
if err != nil {
t.Fatal(err)
}
// read back data and compare
joinReader, l, err := joiner.New(ctx, store, mockAddr)
if err != nil {
t.Fatal(err)
}
if l != int64(len(mockData)) {
t.Fatalf("expected join data length %d, got %d", len(mockData), l)
}
joinData, err := ioutil.ReadAll(joinReader)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(joinData, mockData) {
t.Fatalf("retrieved data '%x' not like original data '%x'", joinData, mockData)
}
}
// TestJoinerWithReference verifies that a chunk reference is correctly resolved
// and the underlying data is returned.
func TestJoinerWithReference(t *testing.T) {
store := mock.NewStorer()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// create root chunk and two data chunks referenced in the root chunk
rootChunk := filetest.GenerateTestRandomFileChunk(swarm.ZeroAddress, swarm.ChunkSize*2, swarm.SectionSize*2)
_, err := store.Put(ctx, storage.ModePutUpload, rootChunk)
if err != nil {
t.Fatal(err)
}
firstAddress := swarm.NewAddress(rootChunk.Data()[8 : swarm.SectionSize+8])
firstChunk := filetest.GenerateTestRandomFileChunk(firstAddress, swarm.ChunkSize, swarm.ChunkSize)
_, err = store.Put(ctx, storage.ModePutUpload, firstChunk)
if err != nil {
t.Fatal(err)
}
secondAddress := swarm.NewAddress(rootChunk.Data()[swarm.SectionSize+8:])
secondChunk := filetest.GenerateTestRandomFileChunk(secondAddress, swarm.ChunkSize, swarm.ChunkSize)
_, err = store.Put(ctx, storage.ModePutUpload, secondChunk)
if err != nil {
t.Fatal(err)
}
// read back data and compare
joinReader, l, err := joiner.New(ctx, store, rootChunk.Address())
if err != nil {
t.Fatal(err)
}
if l != int64(swarm.ChunkSize*2) {
t.Fatalf("expected join data length %d, got %d", swarm.ChunkSize*2, l)
}
resultBuffer := make([]byte, swarm.ChunkSize)
n, err := joinReader.Read(resultBuffer)
if err != nil {
t.Fatal(err)
}
if n != len(resultBuffer) {
t.Fatalf("expected read count %d, got %d", len(resultBuffer), n)
}
if !bytes.Equal(resultBuffer, firstChunk.Data()[8:]) {
t.Fatalf("expected resultbuffer %v, got %v", resultBuffer, firstChunk.Data()[:len(resultBuffer)])
}
}
func TestEncryptDecrypt(t *testing.T) {
var tests = []struct {
chunkLength int
}{
{10},
{100},
{1000},
{4095},
{4096},
{4097},
{1000000},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Encrypt %d bytes", tt.chunkLength), func(t *testing.T) {
store := mock.NewStorer()
g := mockbytes.New(0, mockbytes.MockTypeStandard).WithModulus(255)
testData, err := g.SequentialBytes(tt.chunkLength)
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
pipe := builder.NewPipelineBuilder(ctx, store, storage.ModePutUpload, true)
testDataReader := bytes.NewReader(testData)
resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader, int64(len(testData)))
if err != nil {
t.Fatal(err)
}
reader, l, err := joiner.New(context.Background(), store, resultAddress)
if err != nil {
t.Fatal(err)
}
if l != int64(len(testData)) {
t.Fatalf("expected join data length %d, got %d", len(testData), l)
}
totalGot := make([]byte, tt.chunkLength)
index := 0
resultBuffer := make([]byte, swarm.ChunkSize)
for index < tt.chunkLength {
n, err := reader.Read(resultBuffer)
if err != nil && err != io.EOF {
t.Fatal(err)
}
copy(totalGot[index:], resultBuffer[:n])
index += n
}
if !bytes.Equal(testData, totalGot) {
t.Fatal("input data and output data does not match")
}
})
}
}
func TestSeek(t *testing.T) { func TestSeek(t *testing.T) {
seed := time.Now().UnixNano() seed := time.Now().UnixNano()
...@@ -80,7 +278,7 @@ func TestSeek(t *testing.T) { ...@@ -80,7 +278,7 @@ func TestSeek(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
j, _, err := internal.NewSimpleJoiner(ctx, store, addr) j, _, err := joiner.New(ctx, store, addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -358,7 +556,7 @@ func TestPrefetch(t *testing.T) { ...@@ -358,7 +556,7 @@ func TestPrefetch(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
j, _, err := internal.NewSimpleJoiner(ctx, store, addr) j, _, err := joiner.New(ctx, store, addr)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -378,8 +576,7 @@ func TestPrefetch(t *testing.T) { ...@@ -378,8 +576,7 @@ func TestPrefetch(t *testing.T) {
} }
} }
// TestSimpleJoinerReadAt func TestJoinerReadAt(t *testing.T) {
func TestSimpleJoinerReadAt(t *testing.T) {
store := mock.NewStorer() store := mock.NewStorer()
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
...@@ -406,7 +603,7 @@ func TestSimpleJoinerReadAt(t *testing.T) { ...@@ -406,7 +603,7 @@ func TestSimpleJoinerReadAt(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
j, _, err := internal.NewSimpleJoiner(ctx, store, rootChunk.Address()) j, _, err := joiner.New(ctx, store, rootChunk.Address())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -422,9 +619,9 @@ func TestSimpleJoinerReadAt(t *testing.T) { ...@@ -422,9 +619,9 @@ func TestSimpleJoinerReadAt(t *testing.T) {
} }
} }
// TestSimpleJoinerOneLevel tests the retrieval of two data chunks immediately // TestJoinerOneLevel tests the retrieval of two data chunks immediately
// below the root chunk level. // below the root chunk level.
func TestSimpleJoinerOneLevel(t *testing.T) { func TestJoinerOneLevel(t *testing.T) {
store := mock.NewStorer() store := mock.NewStorer()
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
...@@ -451,7 +648,7 @@ func TestSimpleJoinerOneLevel(t *testing.T) { ...@@ -451,7 +648,7 @@ func TestSimpleJoinerOneLevel(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
j, _, err := internal.NewSimpleJoiner(ctx, store, rootChunk.Address()) j, _, err := joiner.New(ctx, store, rootChunk.Address())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -493,10 +690,10 @@ func TestSimpleJoinerOneLevel(t *testing.T) { ...@@ -493,10 +690,10 @@ func TestSimpleJoinerOneLevel(t *testing.T) {
} }
} }
// TestSimpleJoinerTwoLevelsAcrossChunk tests the retrieval of data chunks below // TestJoinerTwoLevelsAcrossChunk tests the retrieval of data chunks below
// first intermediate level across two intermediate chunks. // first intermediate level across two intermediate chunks.
// Last chunk has sub-chunk length. // Last chunk has sub-chunk length.
func TestSimpleJoinerTwoLevelsAcrossChunk(t *testing.T) { func TestJoinerTwoLevelsAcrossChunk(t *testing.T) {
store := mock.NewStorer() store := mock.NewStorer()
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
...@@ -543,7 +740,7 @@ func TestSimpleJoinerTwoLevelsAcrossChunk(t *testing.T) { ...@@ -543,7 +740,7 @@ func TestSimpleJoinerTwoLevelsAcrossChunk(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
j, _, err := internal.NewSimpleJoiner(ctx, store, rootChunk.Address()) j, _, err := joiner.New(ctx, store, rootChunk.Address())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package joiner provides implementations of the file.Joiner interface
package seekjoiner
import (
"context"
"encoding/binary"
"fmt"
"github.com/ethersphere/bee/pkg/encryption/store"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/seekjoiner/internal"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
)
// simpleJoiner wraps a non-optimized implementation of file.SeekJoiner.
type simpleJoiner struct {
getter storage.Getter
}
// NewSimpleJoiner creates a new simpleJoiner.
func NewSimpleJoiner(getter storage.Getter) file.JoinSeeker {
return &simpleJoiner{
getter: store.New(getter),
}
}
func (s *simpleJoiner) Size(ctx context.Context, address swarm.Address) (int64, error) {
// retrieve the root chunk to read the total data length the be retrieved
rootChunk, err := s.getter.Get(ctx, storage.ModeGetRequest, address)
if err != nil {
return 0, err
}
chunkData := rootChunk.Data()
if l := len(chunkData); l < 8 {
return 0, fmt.Errorf("invalid chunk content of %d bytes", l)
}
dataLength := binary.LittleEndian.Uint64(chunkData[:8])
return int64(dataLength), nil
}
// Join implements the file.Joiner interface.
//
// It uses a non-optimized internal component that only retrieves a data chunk
// after the previous has been read.
func (s *simpleJoiner) Join(ctx context.Context, address swarm.Address) (dataOut file.Reader, dataSize int64, err error) {
return internal.NewSimpleJoiner(ctx, s.getter, address)
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package seekjoiner_test
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"testing"
"github.com/ethersphere/bee/pkg/encryption/store"
"github.com/ethersphere/bee/pkg/file/pipeline/builder"
"github.com/ethersphere/bee/pkg/file/seekjoiner"
joiner "github.com/ethersphere/bee/pkg/file/seekjoiner"
filetest "github.com/ethersphere/bee/pkg/file/testing"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
"gitlab.com/nolash/go-mockbytes"
)
func TestJoiner_ErrReferenceLength(t *testing.T) {
store := mock.NewStorer()
joiner := joiner.NewSimpleJoiner(store)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var err error
_, _, err = joiner.Join(ctx, swarm.ZeroAddress)
if !errors.Is(err, storage.ErrReferenceLength) {
t.Fatalf("expected ErrReferenceLength %x but got %v", swarm.ZeroAddress, err)
}
}
// TestJoiner verifies that a newly created joiner returns the data stored
// in the store when the reference is one single chunk.
func TestJoinerSingleChunk(t *testing.T) {
store := mock.NewStorer()
joiner := joiner.NewSimpleJoiner(store)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// create the chunk to
mockAddrHex := fmt.Sprintf("%064s", "2a")
mockAddr := swarm.MustParseHexAddress(mockAddrHex)
mockData := []byte("foo")
mockDataLengthBytes := make([]byte, 8)
mockDataLengthBytes[0] = 0x03
mockChunk := swarm.NewChunk(mockAddr, append(mockDataLengthBytes, mockData...))
_, err := store.Put(ctx, storage.ModePutUpload, mockChunk)
if err != nil {
t.Fatal(err)
}
// read back data and compare
joinReader, l, err := joiner.Join(ctx, mockAddr)
if err != nil {
t.Fatal(err)
}
if l != int64(len(mockData)) {
t.Fatalf("expected join data length %d, got %d", len(mockData), l)
}
joinData, err := ioutil.ReadAll(joinReader)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(joinData, mockData) {
t.Fatalf("retrieved data '%x' not like original data '%x'", joinData, mockData)
}
}
// TestJoinerDecryptingStore_NormalChunk verifies the the mock store that uses
// the decrypting store manages to retrieve a normal chunk which is not encrypted
func TestJoinerDecryptingStore_NormalChunk(t *testing.T) {
st := mock.NewStorer()
store := store.New(st)
joiner := joiner.NewSimpleJoiner(store)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// create the chunk to
mockAddrHex := fmt.Sprintf("%064s", "2a")
mockAddr := swarm.MustParseHexAddress(mockAddrHex)
mockData := []byte("foo")
mockDataLengthBytes := make([]byte, 8)
mockDataLengthBytes[0] = 0x03
mockChunk := swarm.NewChunk(mockAddr, append(mockDataLengthBytes, mockData...))
_, err := st.Put(ctx, storage.ModePutUpload, mockChunk)
if err != nil {
t.Fatal(err)
}
// read back data and compare
joinReader, l, err := joiner.Join(ctx, mockAddr)
if err != nil {
t.Fatal(err)
}
if l != int64(len(mockData)) {
t.Fatalf("expected join data length %d, got %d", len(mockData), l)
}
joinData, err := ioutil.ReadAll(joinReader)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(joinData, mockData) {
t.Fatalf("retrieved data '%x' not like original data '%x'", joinData, mockData)
}
}
// TestJoinerWithReference verifies that a chunk reference is correctly resolved
// and the underlying data is returned.
func TestJoinerWithReference(t *testing.T) {
store := mock.NewStorer()
joiner := joiner.NewSimpleJoiner(store)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// create root chunk and two data chunks referenced in the root chunk
rootChunk := filetest.GenerateTestRandomFileChunk(swarm.ZeroAddress, swarm.ChunkSize*2, swarm.SectionSize*2)
_, err := store.Put(ctx, storage.ModePutUpload, rootChunk)
if err != nil {
t.Fatal(err)
}
firstAddress := swarm.NewAddress(rootChunk.Data()[8 : swarm.SectionSize+8])
firstChunk := filetest.GenerateTestRandomFileChunk(firstAddress, swarm.ChunkSize, swarm.ChunkSize)
_, err = store.Put(ctx, storage.ModePutUpload, firstChunk)
if err != nil {
t.Fatal(err)
}
secondAddress := swarm.NewAddress(rootChunk.Data()[swarm.SectionSize+8:])
secondChunk := filetest.GenerateTestRandomFileChunk(secondAddress, swarm.ChunkSize, swarm.ChunkSize)
_, err = store.Put(ctx, storage.ModePutUpload, secondChunk)
if err != nil {
t.Fatal(err)
}
// read back data and compare
joinReader, l, err := joiner.Join(ctx, rootChunk.Address())
if err != nil {
t.Fatal(err)
}
if l != int64(swarm.ChunkSize*2) {
t.Fatalf("expected join data length %d, got %d", swarm.ChunkSize*2, l)
}
resultBuffer := make([]byte, swarm.ChunkSize)
n, err := joinReader.Read(resultBuffer)
if err != nil {
t.Fatal(err)
}
if n != len(resultBuffer) {
t.Fatalf("expected read count %d, got %d", len(resultBuffer), n)
}
if !bytes.Equal(resultBuffer, firstChunk.Data()[8:]) {
t.Fatalf("expected resultbuffer %v, got %v", resultBuffer, firstChunk.Data()[:len(resultBuffer)])
}
}
func TestEncryptDecrypt(t *testing.T) {
var tests = []struct {
chunkLength int
}{
{10},
{100},
{1000},
{4095},
{4096},
{4097},
{1000000},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Encrypt %d bytes", tt.chunkLength), func(t *testing.T) {
store := mock.NewStorer()
joiner := seekjoiner.NewSimpleJoiner(store)
g := mockbytes.New(0, mockbytes.MockTypeStandard).WithModulus(255)
testData, err := g.SequentialBytes(tt.chunkLength)
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
pipe := builder.NewPipelineBuilder(ctx, store, storage.ModePutUpload, true)
testDataReader := bytes.NewReader(testData)
resultAddress, err := builder.FeedPipeline(ctx, pipe, testDataReader, int64(len(testData)))
if err != nil {
t.Fatal(err)
}
reader, l, err := joiner.Join(context.Background(), resultAddress)
if err != nil {
t.Fatal(err)
}
if l != int64(len(testData)) {
t.Fatalf("expected join data length %d, got %d", len(testData), l)
}
totalGot := make([]byte, tt.chunkLength)
index := 0
resultBuffer := make([]byte, swarm.ChunkSize)
for index < tt.chunkLength {
n, err := reader.Read(resultBuffer)
if err != nil && err != io.EOF {
t.Fatal(err)
}
copy(totalGot[index:], resultBuffer[:n])
index += n
}
if !bytes.Equal(testData, totalGot) {
t.Fatal("input data and output data does not match")
}
})
}
}
...@@ -11,8 +11,8 @@ import ( ...@@ -11,8 +11,8 @@ import (
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/pipeline/builder" "github.com/ethersphere/bee/pkg/file/pipeline/builder"
"github.com/ethersphere/bee/pkg/file/seekjoiner"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/manifest/mantaray" "github.com/ethersphere/manifest/mantaray"
...@@ -161,10 +161,13 @@ func newMantaraySaver( ...@@ -161,10 +161,13 @@ func newMantaraySaver(
func (ls *mantarayLoadSaver) Load(ref []byte) ([]byte, error) { func (ls *mantarayLoadSaver) Load(ref []byte) ([]byte, error) {
ctx := ls.ctx ctx := ls.ctx
j := seekjoiner.NewSimpleJoiner(ls.storer) j, _, err := joiner.New(ctx, ls.storer, swarm.NewAddress(ref))
if err != nil {
return nil, err
}
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
_, err := file.JoinReadAll(ctx, j, swarm.NewAddress(ref), buf) _, err = file.JoinReadAll(ctx, j, buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -11,8 +11,8 @@ import ( ...@@ -11,8 +11,8 @@ import (
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/pipeline/builder" "github.com/ethersphere/bee/pkg/file/pipeline/builder"
"github.com/ethersphere/bee/pkg/file/seekjoiner"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/manifest/simple" "github.com/ethersphere/manifest/simple"
...@@ -120,10 +120,13 @@ func (m *simpleManifest) Store(ctx context.Context, mode storage.ModePut) (swarm ...@@ -120,10 +120,13 @@ func (m *simpleManifest) Store(ctx context.Context, mode storage.ModePut) (swarm
} }
func (m *simpleManifest) load(ctx context.Context, reference swarm.Address) error { func (m *simpleManifest) load(ctx context.Context, reference swarm.Address) error {
j := seekjoiner.NewSimpleJoiner(m.storer) j, _, err := joiner.New(ctx, m.storer, reference)
if err != nil {
return fmt.Errorf("new joiner: %w", err)
}
buf := bytes.NewBuffer(nil) buf := bytes.NewBuffer(nil)
_, err := file.JoinReadAll(ctx, j, reference, buf) _, err = file.JoinReadAll(ctx, j, buf)
if err != nil { if err != nil {
return fmt.Errorf("manifest load error: %w", err) return fmt.Errorf("manifest load error: %w", err)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment