Commit 4538fdfb authored by acud, committed by GitHub

pipeline: add and integrate encryption pipeline (#613)

* add encryption pipeline (#613) + integration (#618)
parent d28619b8
@@ -19,8 +19,9 @@ import (
 )

 const (
 	SwarmPinHeader     = "Swarm-Pin"
 	SwarmTagUidHeader  = "Swarm-Tag-Uid"
+	SwarmEncryptHeader = "Swarm-Encrypt"
 )

 type Service interface {
@@ -86,3 +87,7 @@ func requestModePut(r *http.Request) storage.ModePut {
 	}
 	return storage.ModePutUpload
 }
+
+func requestEncrypt(r *http.Request) bool {
+	return strings.ToLower(r.Header.Get(SwarmEncryptHeader)) == "true"
+}
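Note: clients opt into encryption per request through the new Swarm-Encrypt header, which requestEncrypt matches case-insensitively. A minimal client-side sketch in Go (the node URL and port are assumptions, not part of this change):

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// POST a payload with encryption switched on; requestEncrypt on the
	// server side lowercases the header value before comparing to "true".
	body := bytes.NewReader([]byte("hello swarm"))
	req, err := http.NewRequest(http.MethodPost, "http://localhost:1633/bytes", body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Swarm-Encrypt", "true")
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer res.Body.Close()
	fmt.Println(res.Status)
}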
@@ -32,7 +32,7 @@ func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) {

 	// Add the tag to the context
 	ctx := sctx.SetTag(r.Context(), tag)

-	pipe := pipeline.NewPipeline(ctx, s.Storer, requestModePut(r))
+	pipe := pipeline.NewPipelineBuilder(ctx, s.Storer, requestModePut(r), requestEncrypt(r))
 	address, err := pipeline.FeedPipeline(ctx, pipe, r.Body, r.ContentLength)
 	if err != nil {
 		s.Logger.Debugf("bytes upload: split write all: %v", err)
...
@@ -39,7 +39,7 @@ func TestBzz(t *testing.T) {
 			Logger: logging.New(ioutil.Discard, 5),
 		})
 		pipeWriteAll = func(r io.Reader, l int64) (swarm.Address, error) {
-			pipe := pipeline.NewPipeline(ctx, storer, storage.ModePutUpload)
+			pipe := pipeline.NewPipelineBuilder(ctx, storer, storage.ModePutUpload, false)
 			return pipeline.FeedPipeline(ctx, pipe, r, l)
 		}
 	)
...
@@ -15,7 +15,6 @@ import (
 	"mime"
 	"net/http"
 	"path/filepath"
-	"strings"

 	"github.com/ethersphere/bee/pkg/collection/entry"
 	"github.com/ethersphere/bee/pkg/file/pipeline"
@@ -32,11 +31,9 @@ const (
 	contentTypeTar = "application/x-tar"
 )

-type toEncryptContextKey struct{}
-
 // dirUploadHandler uploads a directory supplied as a tar in an HTTP request
 func (s *server) dirUploadHandler(w http.ResponseWriter, r *http.Request) {
-	ctx, err := validateRequest(r)
+	err := validateRequest(r)
 	if err != nil {
 		s.Logger.Errorf("dir upload, validate request")
 		s.Logger.Debugf("dir upload, validate request err: %v", err)
@@ -53,9 +50,9 @@ func (s *server) dirUploadHandler(w http.ResponseWriter, r *http.Request) {
 	}

 	// Add the tag to the context
-	ctx = sctx.SetTag(ctx, tag)
+	ctx := sctx.SetTag(r.Context(), tag)

-	reference, err := storeDir(ctx, r.Body, s.Storer, requestModePut(r), s.Logger)
+	reference, err := storeDir(ctx, r.Body, s.Storer, requestModePut(r), s.Logger, requestEncrypt(r))
 	if err != nil {
 		s.Logger.Debugf("dir upload, store dir err: %v", err)
 		s.Logger.Errorf("dir upload, store dir")
@@ -72,31 +69,26 @@ func (s *server) dirUploadHandler(w http.ResponseWriter, r *http.Request) {
 }

 // validateRequest validates an HTTP request for a directory to be uploaded
-// it returns a context based on the given request
-func validateRequest(r *http.Request) (context.Context, error) {
-	ctx := r.Context()
+func validateRequest(r *http.Request) error {
 	if r.Body == http.NoBody {
-		return nil, errors.New("request has no body")
+		return errors.New("request has no body")
 	}
 	contentType := r.Header.Get(contentTypeHeader)
 	mediaType, _, err := mime.ParseMediaType(contentType)
 	if err != nil {
-		return nil, err
+		return err
 	}
 	if mediaType != contentTypeTar {
-		return nil, errors.New("content-type not set to tar")
+		return errors.New("content-type not set to tar")
 	}
-	toEncrypt := strings.ToLower(r.Header.Get(EncryptHeader)) == "true"
-	return context.WithValue(ctx, toEncryptContextKey{}, toEncrypt), nil
+	return nil
 }

 // storeDir stores all files recursively contained in the directory given as a tar
 // it returns the hash for the uploaded manifest corresponding to the uploaded dir
-func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, mode storage.ModePut, logger logging.Logger) (swarm.Address, error) {
-	v := ctx.Value(toEncryptContextKey{})
-	toEncrypt, _ := v.(bool) // default is false
-	dirManifest, err := manifest.NewDefaultManifest(toEncrypt, s)
+func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, mode storage.ModePut, logger logging.Logger, encrypt bool) (swarm.Address, error) {
+	dirManifest, err := manifest.NewDefaultManifest(encrypt, s)
 	if err != nil {
 		return swarm.ZeroAddress, err
 	}
@@ -134,7 +126,7 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, mode
 			contentType: contentType,
 			reader:      tarReader,
 		}
-		fileReference, err := storeFile(ctx, fileInfo, s, mode)
+		fileReference, err := storeFile(ctx, fileInfo, s, mode, encrypt)
 		if err != nil {
 			return swarm.ZeroAddress, fmt.Errorf("store dir file: %w", err)
 		}
@@ -168,7 +160,7 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, mode
 		return swarm.ZeroAddress, fmt.Errorf("metadata marshal: %w", err)
 	}

-	pipe := pipeline.NewPipeline(ctx, s, mode)
+	pipe := pipeline.NewPipelineBuilder(ctx, s, mode, encrypt)
 	mr, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(metadataBytes), int64(len(metadataBytes)))
 	if err != nil {
 		return swarm.ZeroAddress, fmt.Errorf("split metadata: %w", err)
@@ -181,7 +173,7 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, mode
 		return swarm.ZeroAddress, fmt.Errorf("entry marshal: %w", err)
 	}

-	pipe = pipeline.NewPipeline(ctx, s, mode)
+	pipe = pipeline.NewPipelineBuilder(ctx, s, mode, encrypt)
 	manifestFileReference, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)))
 	if err != nil {
 		return swarm.ZeroAddress, fmt.Errorf("split entry: %w", err)
@@ -192,9 +184,9 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, mode

 // storeFile uploads the given file and returns its reference
 // this function was extracted from `fileUploadHandler` and should eventually replace its current code
-func storeFile(ctx context.Context, fileInfo *fileUploadInfo, s storage.Storer, mode storage.ModePut) (swarm.Address, error) {
+func storeFile(ctx context.Context, fileInfo *fileUploadInfo, s storage.Storer, mode storage.ModePut, encrypt bool) (swarm.Address, error) {
 	// first store the file and get its reference
-	pipe := pipeline.NewPipeline(ctx, s, mode)
+	pipe := pipeline.NewPipelineBuilder(ctx, s, mode, encrypt)
 	fr, err := pipeline.FeedPipeline(ctx, pipe, fileInfo.reader, fileInfo.size)
 	if err != nil {
 		return swarm.ZeroAddress, fmt.Errorf("split file: %w", err)
@@ -213,7 +205,7 @@ func storeFile(ctx context.Context, fileInfo *fileUploadInfo, s storage.Storer,
 		return swarm.ZeroAddress, fmt.Errorf("metadata marshal: %w", err)
 	}

-	pipe = pipeline.NewPipeline(ctx, s, mode)
+	pipe = pipeline.NewPipelineBuilder(ctx, s, mode, encrypt)
 	mr, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(metadataBytes), int64(len(metadataBytes)))
 	if err != nil {
 		return swarm.ZeroAddress, fmt.Errorf("split metadata: %w", err)
@@ -225,7 +217,7 @@ func storeFile(ctx context.Context, fileInfo *fileUploadInfo, s storage.Storer,
 	if err != nil {
 		return swarm.ZeroAddress, fmt.Errorf("entry marshal: %w", err)
 	}

-	pipe = pipeline.NewPipeline(ctx, s, mode)
+	pipe = pipeline.NewPipelineBuilder(ctx, s, mode, encrypt)
 	reference, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)))
 	if err != nil {
 		return swarm.ZeroAddress, fmt.Errorf("split entry: %w", err)
...
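Note: a pipeline appears to be single-use here; storeDir and storeFile construct a fresh builder for every payload (file content, metadata, entry), all sharing the same encrypt flag. A condensed sketch of that three-pass shape, with the metadata and entry marshalling stubbed out as plain byte slices (the real code builds them with the entry package):

package main

import (
	"bytes"
	"context"

	"github.com/ethersphere/bee/pkg/file/pipeline"
	"github.com/ethersphere/bee/pkg/storage"
	"github.com/ethersphere/bee/pkg/storage/mock"
	"github.com/ethersphere/bee/pkg/swarm"
)

// storeThree mirrors the shape of storeFile: one pipeline pass each for
// content, metadata and entry, with a new builder per pass.
func storeThree(ctx context.Context, s storage.Storer, mode storage.ModePut, encrypt bool, content, metadata, entry []byte) (swarm.Address, error) {
	pipe := pipeline.NewPipelineBuilder(ctx, s, mode, encrypt)
	fr, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(content), int64(len(content)))
	if err != nil {
		return swarm.ZeroAddress, err
	}
	_ = fr // the real code embeds fr and mr in the entry

	pipe = pipeline.NewPipelineBuilder(ctx, s, mode, encrypt)
	mr, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(metadata), int64(len(metadata)))
	if err != nil {
		return swarm.ZeroAddress, err
	}
	_ = mr

	pipe = pipeline.NewPipelineBuilder(ctx, s, mode, encrypt)
	return pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(entry), int64(len(entry)))
}

func main() {
	_, _ = storeThree(context.Background(), mock.NewStorer(), storage.ModePutUpload, false,
		[]byte("file bytes"), []byte("metadata bytes"), []byte("entry bytes"))
}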
@@ -34,7 +34,6 @@ import (

 const (
 	multiPartFormData = "multipart/form-data"
-	EncryptHeader     = "swarm-encrypt"
 )

 // fileUploadResponse is returned when an HTTP request to upload a file is successful
@@ -153,7 +152,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
 	}

 	// first store the file and get its reference
-	pipe := pipeline.NewPipeline(ctx, s.Storer, mode)
+	pipe := pipeline.NewPipelineBuilder(ctx, s.Storer, mode, requestEncrypt(r))
 	fr, err := pipeline.FeedPipeline(ctx, pipe, reader, int64(fileSize))
 	if err != nil {
 		s.Logger.Debugf("file upload: file store, file %q: %v", fileName, err)
@@ -177,7 +176,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
 		jsonhttp.InternalServerError(w, "metadata marshal error")
 		return
 	}
-	pipe = pipeline.NewPipeline(ctx, s.Storer, mode)
+	pipe = pipeline.NewPipelineBuilder(ctx, s.Storer, mode, requestEncrypt(r))
 	mr, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(metadataBytes), int64(len(metadataBytes)))
 	if err != nil {
 		s.Logger.Debugf("file upload: metadata store, file %q: %v", fileName, err)
@@ -195,7 +194,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
 		jsonhttp.InternalServerError(w, "entry marshal error")
 		return
 	}
-	pipe = pipeline.NewPipeline(ctx, s.Storer, mode)
+	pipe = pipeline.NewPipelineBuilder(ctx, s.Storer, mode, requestEncrypt(r))
 	reference, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)))
 	if err != nil {
 		s.Logger.Debugf("file upload: entry store, file %q: %v", fileName, err)
...
@@ -65,7 +65,7 @@ func TestFiles(t *testing.T) {
 		var resp api.FileUploadResponse

 		jsonhttptest.Request(t, client, http.MethodPost, fileUploadResource+"?name="+fileName, http.StatusOK,
 			jsonhttptest.WithRequestBody(bytes.NewReader(simpleData)),
-			jsonhttptest.WithRequestHeader(api.EncryptHeader, "True"),
+			jsonhttptest.WithRequestHeader(api.SwarmEncryptHeader, "True"),
 			jsonhttptest.WithRequestHeader("Content-Type", "image/jpeg; charset=utf-8"),
 			jsonhttptest.WithUnmarshalJSONResponse(&resp),
 		)
...
@@ -12,9 +12,8 @@ import (
 	"io/ioutil"
 	"testing"

-	"github.com/ethersphere/bee/pkg/file"
 	"github.com/ethersphere/bee/pkg/file/joiner"
-	"github.com/ethersphere/bee/pkg/file/splitter"
+	"github.com/ethersphere/bee/pkg/file/pipeline"
 	filetest "github.com/ethersphere/bee/pkg/file/testing"
 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/storage/mock"
@@ -146,9 +145,10 @@ func TestEncryptionAndDecryption(t *testing.T) {
 		t.Fatal(err)
 	}

-	s := splitter.NewSimpleSplitter(store, storage.ModePutUpload)
-	testDataReader := file.NewSimpleReadCloser(testData)
-	resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), true)
+	ctx := context.Background()
+	pipe := pipeline.NewPipelineBuilder(ctx, store, storage.ModePutUpload, true)
+	testDataReader := bytes.NewReader(testData)
+	resultAddress, err := pipeline.FeedPipeline(ctx, pipe, testDataReader, int64(len(testData)))
 	if err != nil {
 		t.Fatal(err)
 	}
...
+package pipeline
+
+import (
+	"github.com/ethersphere/bee/pkg/encryption"
+	"github.com/ethersphere/bee/pkg/swarm"
+
+	"golang.org/x/crypto/sha3"
+)
+
+type encryptionWriter struct {
+	next chainWriter
+}
+
+func newEncryptionWriter(next chainWriter) chainWriter {
+	return &encryptionWriter{
+		next: next,
+	}
+}
+
+// Write assumes that the span is prepended to the actual data before the write!
+func (e *encryptionWriter) chainWrite(p *pipeWriteArgs) error {
+	key, encryptedSpan, encryptedData, err := encrypt(p.data)
+	if err != nil {
+		return err
+	}
+	c := make([]byte, len(encryptedSpan)+len(encryptedData))
+	copy(c[:8], encryptedSpan)
+	copy(c[8:], encryptedData)
+	p.data = c // replace the verbatim data with the encrypted data
+	p.key = key
+	return e.next.chainWrite(p)
+}
+
+func (e *encryptionWriter) sum() ([]byte, error) {
+	return e.next.sum()
+}
+
+func encrypt(chunkData []byte) (encryption.Key, []byte, []byte, error) {
+	key := encryption.GenerateRandomKey(encryption.KeyLength)
+	encryptedSpan, err := newSpanEncryption(key).Encrypt(chunkData[:8])
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	encryptedData, err := newDataEncryption(key).Encrypt(chunkData[8:])
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	return key, encryptedSpan, encryptedData, nil
+}
+
+func newSpanEncryption(key encryption.Key) *encryption.Encryption {
+	refSize := int64(swarm.HashSize + encryption.KeyLength)
+	return encryption.New(key, 0, uint32(swarm.ChunkSize/refSize), sha3.NewLegacyKeccak256)
+}
+
+func newDataEncryption(key encryption.Key) *encryption.Encryption {
+	return encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256)
+}
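Note: the encryption writer above keeps the span-prefixed chunk layout intact: the 8-byte span and the payload are encrypted separately under one freshly generated key, so the buffer handed on to the BMT writer has the same span+data shape as before. A hedged sketch of the size bookkeeping for a full chunk (the little-endian span encoding and the padding behaviour of the encryption package are assumptions drawn from the constructor arguments):

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/ethersphere/bee/pkg/encryption"
	"github.com/ethersphere/bee/pkg/swarm"

	"golang.org/x/crypto/sha3"
)

func main() {
	// Build a span-prefixed chunk the way the feeder hands it to the chain.
	payload := make([]byte, swarm.ChunkSize)
	chunk := make([]byte, 8+len(payload))
	binary.LittleEndian.PutUint64(chunk[:8], uint64(len(payload)))
	copy(chunk[8:], payload)

	// Mirror of encrypt(): one random key, separate span and data encryptors.
	key := encryption.GenerateRandomKey(encryption.KeyLength)
	refSize := int64(swarm.HashSize + encryption.KeyLength)
	spanEnc := encryption.New(key, 0, uint32(swarm.ChunkSize/refSize), sha3.NewLegacyKeccak256)
	dataEnc := encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256)

	encryptedSpan, err := spanEnc.Encrypt(chunk[:8])
	if err != nil {
		panic(err)
	}
	encryptedData, err := dataEnc.Encrypt(chunk[8:])
	if err != nil {
		panic(err)
	}
	// Expected: 8 and 4096 for a full chunk, preserving the span+data shape.
	fmt.Println(len(encryptedSpan), len(encryptedData))
}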
@@ -38,16 +38,24 @@ func newHashTrieWriter(chunkSize, branching, refLen int, pipelineFn pipelineFunc

 // accepts writes of hashes from the previous writer in the chain, by definition these writes
 // are on level 1
 func (h *hashTrieWriter) chainWrite(p *pipeWriteArgs) error {
-	return h.writeToLevel(1, p.span, p.ref)
+	oneRef := h.refSize + swarm.SpanSize
+	l := len(p.span) + len(p.ref) + len(p.key)
+	if l%oneRef != 0 {
+		return errInconsistentRefs
+	}
+	return h.writeToLevel(1, p.span, p.ref, p.key)
 }

-func (h *hashTrieWriter) writeToLevel(level int, span, ref []byte) error {
+func (h *hashTrieWriter) writeToLevel(level int, span, ref, key []byte) error {
 	copy(h.buffer[h.cursors[level]:h.cursors[level]+len(span)], span) // copy the span alongside
 	h.cursors[level] += len(span)
 	copy(h.buffer[h.cursors[level]:h.cursors[level]+len(ref)], ref)
 	h.cursors[level] += len(ref)
+	copy(h.buffer[h.cursors[level]:h.cursors[level]+len(key)], key)
+	h.cursors[level] += len(key)

 	howLong := (h.refSize + swarm.SpanSize) * h.branching
 	if h.levelSize(level) == howLong {
 		return h.wrapFullLevel(level)
 	}
@@ -88,7 +96,7 @@ func (h *hashTrieWriter) wrapFullLevel(level int) error {
 	if err != nil {
 		return err
 	}
-	err = h.writeToLevel(level+1, args.span, args.ref)
+	err = h.writeToLevel(level+1, args.span, args.ref, args.key)
 	if err != nil {
 		return err
 	}
@@ -154,8 +162,8 @@ func (h *hashTrieWriter) hoistLevels(target int) ([]byte, error) {
 		span: spb,
 	}
 	err := writer.chainWrite(&args)
-	return args.ref, err
+	ref := append(args.ref, args.key...)
+	return ref, err
 }

 func (h *hashTrieWriter) levelSize(level int) int {
...
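Note on the numbers behind the hashTrieWriter change: in the encrypted trie every child reference travels with its decryption key, so a reference is swarm.HashSize+encryption.KeyLength = 64 bytes instead of 32, and only half as many fit into one intermediate chunk; that is the branching factor of 64 passed by newEncryptionPipeline below. The arithmetic, spelled out:

package main

import "fmt"

func main() {
	const (
		chunkSize = 4096 // swarm.ChunkSize
		hashSize  = 32   // swarm.HashSize
		keyLength = 32   // encryption.KeyLength
	)
	plainRef := hashSize           // bytes per child reference, plain trie
	encRef := hashSize + keyLength // bytes per child when the key rides along
	fmt.Println(chunkSize / plainRef) // 128, i.e. swarm.Branches
	fmt.Println(chunkSize / encRef)   // 64, the encrypted branching factor
}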
@@ -9,20 +9,30 @@ import (
 	"fmt"
 	"io"

+	"github.com/ethersphere/bee/pkg/encryption"
 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/swarm"
 )

 type pipeWriteArgs struct {
-	ref  []byte
-	span []byte
-	data []byte // data includes the span too
+	ref  []byte // reference, generated by bmt
+	key  []byte // encryption key
+	span []byte // always unencrypted span uint64
+	data []byte // data includes the span too, but it may be encrypted when the pipeline is encrypted
 }

-// NewPipeline creates a standard pipeline that only hashes content with BMT to create
+// NewPipelineBuilder returns the appropriate pipeline according to the specified parameters
+func NewPipelineBuilder(ctx context.Context, s storage.Storer, mode storage.ModePut, encrypt bool) Interface {
+	if encrypt {
+		return newEncryptionPipeline(ctx, s, mode)
+	}
+	return newPipeline(ctx, s, mode)
+}
+
+// newPipeline creates a standard pipeline that only hashes content with BMT to create
 // a merkle-tree of hashes that represent the given arbitrary size byte stream. Partial
 // writes are supported. The pipeline flow is: Data -> Feeder -> BMT -> Storage -> HashTrie.
-func NewPipeline(ctx context.Context, s storage.Storer, mode storage.ModePut) Interface {
+func newPipeline(ctx context.Context, s storage.Storer, mode storage.ModePut) Interface {
 	tw := newHashTrieWriter(swarm.ChunkSize, swarm.Branches, swarm.HashSize, newShortPipelineFunc(ctx, s, mode))
 	lsw := newStoreWriter(ctx, s, mode, tw)
 	b := newBmtWriter(128, lsw)
@@ -40,6 +50,29 @@ func newShortPipelineFunc(ctx context.Context, s storage.Storer, mode storage.Mo
 	}
 }

+// newEncryptionPipeline creates an encryption pipeline that encrypts using CTR, hashes content with BMT to create
+// a merkle-tree of hashes that represent the given arbitrary size byte stream. Partial
+// writes are supported. The pipeline flow is: Data -> Feeder -> Encryption -> BMT -> Storage -> HashTrie.
+// Note that the encryption writer will mutate the data to contain the encrypted span, but the span field
+// with the unencrypted span is preserved.
+func newEncryptionPipeline(ctx context.Context, s storage.Storer, mode storage.ModePut) Interface {
+	tw := newHashTrieWriter(swarm.ChunkSize, 64, swarm.HashSize+encryption.KeyLength, newShortEncryptionPipelineFunc(ctx, s, mode))
+	lsw := newStoreWriter(ctx, s, mode, tw)
+	b := newBmtWriter(128, lsw)
+	enc := newEncryptionWriter(b)
+	return newChunkFeederWriter(swarm.ChunkSize, enc)
+}
+
+// newShortEncryptionPipelineFunc returns a constructor function for an ephemeral hashing pipeline
+// needed by the hashTrieWriter.
+func newShortEncryptionPipelineFunc(ctx context.Context, s storage.Storer, mode storage.ModePut) func() chainWriter {
+	return func() chainWriter {
+		lsw := newStoreWriter(ctx, s, mode, nil)
+		b := newBmtWriter(128, lsw)
+		return newEncryptionWriter(b)
+	}
+}
+
 // FeedPipeline feeds the pipeline with the given reader until EOF is reached.
 // It returns the cryptographic root hash of the content.
 func FeedPipeline(ctx context.Context, pipeline Interface, r io.Reader, dataLength int64) (addr swarm.Address, err error) {
...
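Note: putting it together, callers never touch the writer chain directly; they pick a pipeline through the builder and stream data through FeedPipeline. An end-to-end sketch against the mock storer used by the tests (in the encrypted case the root reference is 64 bytes, since hoistLevels appends the key to the root hash):

package main

import (
	"bytes"
	"context"
	"fmt"

	"github.com/ethersphere/bee/pkg/file/pipeline"
	"github.com/ethersphere/bee/pkg/storage"
	"github.com/ethersphere/bee/pkg/storage/mock"
)

func main() {
	ctx := context.Background()
	store := mock.NewStorer()
	data := []byte("hello world")

	// encrypt=true selects newEncryptionPipeline; the returned address is
	// hash plus decryption key, 64 bytes instead of 32.
	pipe := pipeline.NewPipelineBuilder(ctx, store, storage.ModePutUpload, true)
	addr, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
	if err != nil {
		panic(err)
	}
	fmt.Println(addr)
}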
@@ -2,7 +2,7 @@
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.

-package pipeline
+package pipeline_test

 import (
 	"bytes"
@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"testing"

+	"github.com/ethersphere/bee/pkg/file/pipeline"
 	test "github.com/ethersphere/bee/pkg/file/testing"
 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/storage/mock"
@@ -19,7 +20,7 @@ import (

 func TestPartialWrites(t *testing.T) {
 	m := mock.NewStorer()
-	p := NewPipeline(context.Background(), m, storage.ModePutUpload)
+	p := pipeline.NewPipelineBuilder(context.Background(), m, storage.ModePutUpload, false)
 	_, _ = p.Write([]byte("hello "))
 	_, _ = p.Write([]byte("world"))
@@ -35,7 +36,7 @@ func TestPartialWrites(t *testing.T) {

 func TestHelloWorld(t *testing.T) {
 	m := mock.NewStorer()
-	p := NewPipeline(context.Background(), m, storage.ModePutUpload)
+	p := pipeline.NewPipelineBuilder(context.Background(), m, storage.ModePutUpload, false)
 	data := []byte("hello world")
 	_, err := p.Write(data)
@@ -58,7 +59,7 @@ func TestAllVectors(t *testing.T) {
 		data, expect := test.GetVector(t, i)
 		t.Run(fmt.Sprintf("data length %d, vector %d", len(data), i), func(t *testing.T) {
 			m := mock.NewStorer()
-			p := NewPipeline(context.Background(), m, storage.ModePutUpload)
+			p := pipeline.NewPipelineBuilder(context.Background(), m, storage.ModePutUpload, false)
 			_, err := p.Write(data)
 			if err != nil {
...
@@ -106,9 +106,8 @@ func (m *simpleManifest) Store(ctx context.Context, mode storage.ModePut) (swarm
 		return swarm.ZeroAddress, fmt.Errorf("manifest marshal error: %w", err)
 	}

-	pipe := pipeline.NewPipeline(ctx, m.storer, mode)
+	pipe := pipeline.NewPipelineBuilder(ctx, m.storer, mode, m.encrypted)
 	address, err := pipeline.FeedPipeline(ctx, pipe, bytes.NewReader(data), int64(len(data)))
-	_ = m.encrypted // need this field for encryption but this is to avoid linter complaints
 	if err != nil {
 		return swarm.ZeroAddress, fmt.Errorf("manifest save error: %w", err)
 	}
...