Commit d39c5638 authored by Zahoor Mohamed's avatar Zahoor Mohamed Committed by GitHub

Porting Swarm encryption in Bee (#320)

Encryption Support for Bee
parent 0bc0be56
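In short: uploads opt into encryption via the new `swarm-encrypt` header, and downloads detect an encrypted reference by its length (64 bytes: root hash plus decryption key, instead of the plain 32-byte hash). A minimal client sketch of the new behaviour; the port and the `/bytes` route are assumptions for illustration, not confirmed by this diff:

```go
package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"net/http"
)

func main() {
	data := []byte("hello, encrypted swarm")

	// upload: the swarm-encrypt header toggles encryption server-side
	// (node address and route are assumptions)
	req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/bytes", bytes.NewReader(data))
	if err != nil {
		panic(err)
	}
	req.Header.Set("swarm-encrypt", "true")

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer res.Body.Close()

	// the JSON response carries a 128-hex-character reference:
	// the 32-byte root hash followed by the 32-byte decryption key
	body, _ := ioutil.ReadAll(res.Body)
	fmt.Println(string(body))
}
```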
......@@ -59,7 +59,7 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
writeCloser := cmdfile.NopWriteCloser(buf)
limitBuf := cmdfile.NewLimitWriteCloser(writeCloser, limitMetadataLength)
j := joiner.NewSimpleJoiner(store)
_, err = file.JoinReadAll(j, addr, limitBuf)
_, err = file.JoinReadAll(j, addr, limitBuf, false)
if err != nil {
return err
}
......@@ -70,7 +70,7 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
}
buf = bytes.NewBuffer(nil)
_, err = file.JoinReadAll(j, e.Metadata(), buf)
_, err = file.JoinReadAll(j, e.Metadata(), buf, false)
if err != nil {
return err
}
......@@ -116,7 +116,7 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
return err
}
defer outFile.Close()
_, err = file.JoinReadAll(j, e.Reference(), outFile)
_, err = file.JoinReadAll(j, e.Reference(), outFile, false)
return err
}
......@@ -167,7 +167,7 @@ func putEntry(cmd *cobra.Command, args []string) (err error) {
metadataBuf := bytes.NewBuffer(metadataBytes)
metadataReader := io.LimitReader(metadataBuf, int64(len(metadataBytes)))
metadataReadCloser := ioutil.NopCloser(metadataReader)
metadataAddr, err := s.Split(ctx, metadataReadCloser, int64(len(metadataBytes)))
metadataAddr, err := s.Split(ctx, metadataReadCloser, int64(len(metadataBytes)), false)
if err != nil {
return err
}
......@@ -182,7 +182,7 @@ func putEntry(cmd *cobra.Command, args []string) (err error) {
fileEntryBuf := bytes.NewBuffer(fileEntryBytes)
fileEntryReader := io.LimitReader(fileEntryBuf, int64(len(fileEntryBytes)))
fileEntryReadCloser := ioutil.NopCloser(fileEntryReader)
fileEntryAddr, err := s.Split(ctx, fileEntryReadCloser, int64(len(fileEntryBytes)))
fileEntryAddr, err := s.Split(ctx, fileEntryReadCloser, int64(len(fileEntryBytes)), false)
if err != nil {
return err
}
......
......@@ -83,7 +83,7 @@ func Join(cmd *cobra.Command, args []string) (err error) {
// create the joiner and get its data reader
j := joiner.NewSimpleJoiner(store)
_, err = file.JoinReadAll(j, addr, outFile)
_, err = file.JoinReadAll(j, addr, outFile, false)
return err
}
......
......@@ -96,7 +96,7 @@ func Split(cmd *cobra.Command, args []string) (err error) {
s := splitter.NewSimpleSplitter(stores)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
addr, err := s.Split(ctx, infile, inputLength)
addr, err := s.Split(ctx, infile, inputLength, false)
if err != nil {
return err
}
......
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo=
cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
......@@ -919,7 +918,6 @@ golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73r
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190227160552-c95aed5357e7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
......@@ -1145,4 +1143,4 @@ rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck=
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=
sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=
\ No newline at end of file
......@@ -10,7 +10,9 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter"
......@@ -27,8 +29,10 @@ type bytesPostResponse struct {
// bytesUploadHandler handles upload of raw binary data of arbitrary length.
func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
toEncrypt := strings.ToLower(r.Header.Get(EncryptHeader)) == "true"
sp := splitter.NewSimpleSplitter(s.Storer)
address, err := file.SplitWriteAll(ctx, sp, r.Body, r.ContentLength)
address, err := file.SplitWriteAll(ctx, sp, r.Body, r.ContentLength, toEncrypt)
if err != nil {
s.Logger.Debugf("bytes upload: %v", err)
jsonhttp.InternalServerError(w, nil)
......@@ -52,8 +56,8 @@ func (s *server) bytesGetHandler(w http.ResponseWriter, r *http.Request) {
return
}
toDecrypt := len(address.Bytes()) == (swarm.HashSize + encryption.KeyLength)
j := joiner.NewSimpleJoiner(s.Storer)
dataSize, err := j.Size(ctx, address)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
......@@ -69,7 +73,7 @@ func (s *server) bytesGetHandler(w http.ResponseWriter, r *http.Request) {
}
outBuffer := bytes.NewBuffer(nil)
c, err := file.JoinReadAll(j, address, outBuffer)
c, err := file.JoinReadAll(j, address, outBuffer, toDecrypt)
if err != nil && c == 0 {
s.Logger.Debugf("bytes download: data join %s: %v", address, err)
s.Logger.Errorf("bytes download: data join %s", address)
......
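The download side never sees the header; it infers decryption purely from the reference length, as the `toDecrypt` assignment above shows. A hypothetical helper (not part of this diff) capturing that check:

```go
package main

import (
	"fmt"

	"github.com/ethersphere/bee/pkg/encryption"
	"github.com/ethersphere/bee/pkg/swarm"
)

// isEncryptedReference reports whether a reference carries a decryption key:
// a 32-byte hash plus a 32-byte key, 64 bytes in total.
func isEncryptedReference(addr swarm.Address) bool {
	return len(addr.Bytes()) == swarm.HashSize+encryption.KeyLength
}

func main() {
	plain := swarm.MustParseHexAddress("f2e761160deda91c1fbfab065a5abf530b0766b3e102b51fbd626ba37c3bc581")
	fmt.Println(isEncryptedReference(plain)) // false: 32 bytes, no key appended
}
```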
......@@ -17,8 +17,10 @@ import (
"net/http"
"os"
"strconv"
"strings"
"github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter"
......@@ -28,7 +30,10 @@ import (
"github.com/gorilla/mux"
)
const multipartFormDataMediaType = "multipart/form-data"
const (
multiPartFormData = "multipart/form-data"
EncryptHeader = "swarm-encrypt"
)
type fileUploadResponse struct {
Reference swarm.Address `json:"reference"`
......@@ -38,6 +43,7 @@ type fileUploadResponse struct {
// - multipart http message
// - other content types as complete file body
func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
toEncrypt := strings.ToLower(r.Header.Get(EncryptHeader)) == "true"
contentType := r.Header.Get("Content-Type")
mediaType, params, err := mime.ParseMediaType(contentType)
if err != nil {
......@@ -52,7 +58,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
var fileName, contentLength string
var fileSize uint64
if mediaType == multipartFormDataMediaType {
if mediaType == multiPartFormData {
mr := multipart.NewReader(r.Body, params["boundary"])
// read only the first part, as only one file upload is supported
......@@ -133,7 +139,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
// first store the file and get its reference
sp := splitter.NewSimpleSplitter(s.Storer)
fr, err := file.SplitWriteAll(ctx, sp, reader, int64(fileSize))
fr, err := file.SplitWriteAll(ctx, sp, reader, int64(fileSize), toEncrypt)
if err != nil {
s.Logger.Debugf("file upload: file store, file %q: %v", fileName, err)
s.Logger.Errorf("file upload: file store, file %q", fileName)
......@@ -157,7 +163,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
return
}
sp = splitter.NewSimpleSplitter(s.Storer)
mr, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(metadataBytes), int64(len(metadataBytes)))
mr, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(metadataBytes), int64(len(metadataBytes)), toEncrypt)
if err != nil {
s.Logger.Debugf("file upload: metadata store, file %q: %v", fileName, err)
s.Logger.Errorf("file upload: metadata store, file %q", fileName)
......@@ -174,9 +180,8 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
jsonhttp.InternalServerError(w, "entry marshal error")
return
}
sp = splitter.NewSimpleSplitter(s.Storer)
reference, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)))
reference, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)), toEncrypt)
if err != nil {
s.Logger.Debugf("file upload: entry store, file %q: %v", fileName, err)
s.Logger.Errorf("file upload: entry store, file %q", fileName)
......@@ -200,10 +205,12 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
return
}
toDecrypt := len(address.Bytes()) == (swarm.HashSize + encryption.KeyLength)
// read entry.
j := joiner.NewSimpleJoiner(s.Storer)
buf := bytes.NewBuffer(nil)
_, err = file.JoinReadAll(j, address, buf)
_, err = file.JoinReadAll(j, address, buf, toDecrypt)
if err != nil {
s.Logger.Debugf("file download: read entry %s: %v", addr, err)
s.Logger.Errorf("file download: read entry %s", addr)
......@@ -231,7 +238,7 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
// Read metadata.
buf = bytes.NewBuffer(nil)
_, err = file.JoinReadAll(j, e.Metadata(), buf)
_, err = file.JoinReadAll(j, e.Metadata(), buf, toDecrypt)
if err != nil {
s.Logger.Debugf("file download: read metadata %s: %v", addr, err)
s.Logger.Errorf("file download: read metadata %s", addr)
......@@ -276,7 +283,7 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
}()
go func() {
_, err := file.JoinReadAll(j, e.Reference(), pw)
_, err := file.JoinReadAll(j, e.Reference(), pw, toDecrypt)
if err := pw.CloseWithError(err); err != nil {
s.Logger.Debugf("file download: data join close %s: %v", addr, err)
s.Logger.Errorf("file download: data join close %s", addr)
......
......@@ -49,6 +49,31 @@ func TestFiles(t *testing.T) {
})
})
t.Run("encrypt-decrypt", func(t *testing.T) {
fileName := "my-pictures.jpeg"
rootHash := "f2e761160deda91c1fbfab065a5abf530b0766b3e102b51fbd626ba37c3bc581"
headers := make(http.Header)
headers.Add("EncryptHeader", "True")
headers.Add("Content-Type", "image/jpeg; charset=utf-8")
_ = jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, client, http.MethodPost, fileUploadResource+"?name="+fileName, bytes.NewReader(simpleData), http.StatusOK, api.FileUploadResponse{
Reference: swarm.MustParseHexAddress(rootHash),
}, headers)
rcvdHeader := jsonhttptest.ResponseDirectCheckBinaryResponse(t, client, http.MethodGet, fileDownloadResource(rootHash), nil, http.StatusOK, simpleData, nil)
cd := rcvdHeader.Get("Content-Disposition")
_, params, err := mime.ParseMediaType(cd)
if err != nil {
t.Fatal(err)
}
if params["filename"] != fileName {
t.Fatal("Invalid file name detected")
}
if rcvdHeader.Get("Content-Type") != "image/jpeg; charset=utf-8" {
t.Fatal("Invalid content type detected")
}
})
t.Run("check-content-type-detection", func(t *testing.T) {
fileName := "my-pictures.jpeg"
rootHash := "f2e761160deda91c1fbfab065a5abf530b0766b3e102b51fbd626ba37c3bc581"
......
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package encryption
import (
"crypto/rand"
"encoding/binary"
"fmt"
"hash"
"sync"
)
const KeyLength = 32
type Key []byte
type Encryptor interface {
Encrypt(data []byte) ([]byte, error)
Decrypt(data []byte) ([]byte, error)
Reset()
}
type Encryption struct {
key Key // the encryption key (hashSize bytes long)
keyLen int // length of the key = length of blockcipher block
padding int // encryption will pad the data up to this size if > 0
index int // counter index
initCtr uint32 // initial counter used for counter mode blockcipher
hashFunc func() hash.Hash // hasher constructor function
}
// New constructs a new encryptor/decryptor
func New(key Key, padding int, initCtr uint32, hashFunc func() hash.Hash) *Encryption {
return &Encryption{
key: key,
keyLen: len(key),
padding: padding,
initCtr: initCtr,
hashFunc: hashFunc,
}
}
// Encrypt encrypts the data and does padding if specified
func (e *Encryption) Encrypt(data []byte) ([]byte, error) {
length := len(data)
outLength := length
isFixedPadding := e.padding > 0
if isFixedPadding {
if length > e.padding {
return nil, fmt.Errorf("data length longer than padding, data length %v padding %v", length, e.padding)
}
outLength = e.padding
}
out := make([]byte, outLength)
err := e.transform(data, out)
if err != nil {
return nil, err
}
return out, nil
}
// Decrypt decrypts the data; if padding was used, the caller must know the original length and truncate accordingly
func (e *Encryption) Decrypt(data []byte) ([]byte, error) {
length := len(data)
if e.padding > 0 && length != e.padding {
return nil, fmt.Errorf("data length different than padding, data length %v padding %v", length, e.padding)
}
out := make([]byte, length)
err := e.transform(data, out)
if err != nil {
return nil, err
}
return out, nil
}
// Reset resets the counter. It is only safe to call after an encryption operation is completed.
// After Reset is called, the Encryption object can be re-used for other data.
func (e *Encryption) Reset() {
e.index = 0
}
// split up input into key-length segments and encrypt them sequentially
func (e *Encryption) transform(in, out []byte) error {
inLength := len(in)
wg := sync.WaitGroup{}
wg.Add((inLength-1)/e.keyLen + 1)
for i := 0; i < inLength; i += e.keyLen {
errs := make(chan error, 1)
l := min(e.keyLen, inLength-i)
go func(i int, x, y []byte) {
defer wg.Done()
err := e.Transcrypt(i, x, y)
errs <- err
}(e.index, in[i:i+l], out[i:i+l])
e.index++
err := <-errs
if err != nil {
close(errs)
return err
}
}
// pad the rest if out is longer
pad(out[inLength:])
wg.Wait()
return nil
}
// used for segmentwise transformation
// if in is shorter than out, padding is used
func (e *Encryption) Transcrypt(i int, in, out []byte) error {
// first hash key with counter (initial counter + i)
hasher := e.hashFunc()
_, err := hasher.Write(e.key)
if err != nil {
return err
}
ctrBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(ctrBytes, uint32(i)+e.initCtr)
_, err = hasher.Write(ctrBytes)
if err != nil {
return err
}
ctrHash := hasher.Sum(nil)
hasher.Reset()
// second round of hashing for selective disclosure
_, err = hasher.Write(ctrHash)
if err != nil {
return err
}
segmentKey := hasher.Sum(nil)
hasher.Reset()
// XOR bytes up to the length of in (out must be at least as long)
inLength := len(in)
for j := 0; j < inLength; j++ {
out[j] = in[j] ^ segmentKey[j]
}
// insert padding if out is longer
pad(out[inLength:])
return nil
}
func pad(b []byte) {
l := len(b)
for total := 0; total < l; {
read, _ := rand.Read(b[total:])
total += read
}
}
// GenerateRandomKey generates a random key of length l
func GenerateRandomKey(l int) Key {
key := make([]byte, l)
var total int
for total < l {
read, _ := rand.Read(key[total:])
total += read
}
return key
}
func min(x, y int) int {
if x < y {
return x
}
return y
}
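A minimal roundtrip sketch for this package, assuming the import paths from this diff. The segment counter is stateful, so Reset (or a fresh Encryption) is needed between the encrypt and decrypt passes:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/ethersphere/bee/pkg/encryption"
	"golang.org/x/crypto/sha3"
)

func main() {
	key := encryption.GenerateRandomKey(encryption.KeyLength)
	// padding 0 disables fixed-size padding; the counter starts at 0
	e := encryption.New(key, 0, 0, sha3.NewLegacyKeccak256)

	plain := []byte("some chunk payload")
	ciphertext, err := e.Encrypt(plain)
	if err != nil {
		panic(err)
	}

	// rewind the segment counter so decryption regenerates the same keystream
	e.Reset()
	recovered, err := e.Decrypt(ciphertext)
	if err != nil {
		panic(err)
	}
	fmt.Println(bytes.Equal(plain, recovered)) // true
}
```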
......@@ -24,7 +24,7 @@ var (
// returning the length of the data which will be returned.
// The caller can then read the data on the io.Reader that was provided.
type Joiner interface {
Join(ctx context.Context, address swarm.Address) (dataOut io.ReadCloser, dataLength int64, err error)
Join(ctx context.Context, address swarm.Address, toDecrypt bool) (dataOut io.ReadCloser, dataLength int64, err error)
Size(ctx context.Context, address swarm.Address) (dataLength int64, err error)
}
......@@ -34,12 +34,12 @@ type Joiner interface {
// If the dataLength parameter is 0, data is read until io.EOF is encountered.
// When EOF is received and splitting is done, the resulting Swarm Address is returned.
type Splitter interface {
Split(ctx context.Context, dataIn io.ReadCloser, dataLength int64) (addr swarm.Address, err error)
Split(ctx context.Context, dataIn io.ReadCloser, dataLength int64, toEncrypt bool) (addr swarm.Address, err error)
}
// JoinReadAll reads all output from the provided joiner.
func JoinReadAll(j Joiner, addr swarm.Address, outFile io.Writer) (int64, error) {
r, l, err := j.Join(context.Background(), addr)
func JoinReadAll(j Joiner, addr swarm.Address, outFile io.Writer, toDecrypt bool) (int64, error) {
r, l, err := j.Join(context.Background(), addr, toDecrypt)
if err != nil {
return 0, err
}
......@@ -67,7 +67,7 @@ func JoinReadAll(j Joiner, addr swarm.Address, outFile io.Writer) (int64, error)
}
// SplitWriteAll writes all input from the provided reader to the provided splitter
func SplitWriteAll(ctx context.Context, s Splitter, r io.Reader, l int64) (swarm.Address, error) {
func SplitWriteAll(ctx context.Context, s Splitter, r io.Reader, l int64, toEncrypt bool) (swarm.Address, error) {
chunkPipe := NewChunkPipe()
errC := make(chan error)
go func() {
......@@ -86,7 +86,7 @@ func SplitWriteAll(ctx context.Context, s Splitter, r io.Reader, l int64) (swarm
close(errC)
}()
addr, err := s.Split(ctx, chunkPipe, l)
addr, err := s.Split(ctx, chunkPipe, l, toEncrypt)
if err != nil {
return swarm.ZeroAddress, err
}
......
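With the boolean threaded through both interfaces, an encrypted store-and-retrieve roundtrip looks like the sketch below. It mirrors the TestEncryptionAndDecryption test added later in this diff, with the mock storer standing in for a real chunk store:

```go
package main

import (
	"bytes"
	"context"
	"fmt"

	"github.com/ethersphere/bee/pkg/file"
	"github.com/ethersphere/bee/pkg/file/joiner"
	"github.com/ethersphere/bee/pkg/file/splitter"
	"github.com/ethersphere/bee/pkg/storage/mock"
)

func main() {
	store := mock.NewStorer()
	data := []byte("data to be encrypted, split and joined")

	// split with encryption: the returned reference is hash plus key
	sp := splitter.NewSimpleSplitter(store)
	ref, err := file.SplitWriteAll(context.Background(), sp, bytes.NewReader(data), int64(len(data)), true)
	if err != nil {
		panic(err)
	}

	// join with decryption, driven by the same flag
	j := joiner.NewSimpleJoiner(store)
	var out bytes.Buffer
	if _, err := file.JoinReadAll(j, ref, &out, true); err != nil {
		panic(err)
	}
	fmt.Println(bytes.Equal(data, out.Bytes())) // true
}
```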
......@@ -53,13 +53,13 @@ func testSplitThenJoin(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
dataReader := file.NewSimpleReadCloser(data)
resultAddress, err := s.Split(ctx, dataReader, int64(len(data)))
resultAddress, err := s.Split(ctx, dataReader, int64(len(data)), false)
if err != nil {
t.Fatal(err)
}
// then join
r, l, err := j.Join(ctx, resultAddress)
r, l, err := j.Join(ctx, resultAddress, false)
if err != nil {
t.Fatal(err)
}
......@@ -93,7 +93,7 @@ func TestJoinReadAll(t *testing.T) {
var dataLength int64 = swarm.ChunkSize + 2
j := newMockJoiner(dataLength)
buf := bytes.NewBuffer(nil)
c, err := file.JoinReadAll(j, swarm.ZeroAddress, buf)
c, err := file.JoinReadAll(j, swarm.ZeroAddress, buf, false)
if err != nil {
t.Fatal(err)
}
......@@ -113,7 +113,7 @@ type mockJoiner struct {
}
// Join implements file.Joiner.
func (j *mockJoiner) Join(ctx context.Context, address swarm.Address) (dataOut io.ReadCloser, dataLength int64, err error) {
func (j *mockJoiner) Join(ctx context.Context, address swarm.Address, toDecrypt bool) (dataOut io.ReadCloser, dataLength int64, err error) {
data := make([]byte, j.l)
buf := bytes.NewBuffer(data)
readCloser := ioutil.NopCloser(buf)
......
......@@ -13,10 +13,12 @@ import (
"io/ioutil"
"sync"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"golang.org/x/crypto/sha3"
)
// SimpleJoinerJob encapsulates a single joiner operation, providing the consumer
......@@ -46,10 +48,11 @@ type SimpleJoinerJob struct {
closeDoneOnce sync.Once // make sure done channel is closed only once
err error // read by the main thread to capture error state of the job
logger logging.Logger
toDecrypt bool // to decrypt the chunks or not
}
// NewSimpleJoinerJob creates a new simpleJoinerJob.
func NewSimpleJoinerJob(ctx context.Context, getter storage.Getter, rootChunk swarm.Chunk) *SimpleJoinerJob {
func NewSimpleJoinerJob(ctx context.Context, getter storage.Getter, rootChunk swarm.Chunk, toDecrypt bool) *SimpleJoinerJob {
spanLength := binary.LittleEndian.Uint64(rootChunk.Data()[:8])
levelCount := file.Levels(int64(spanLength), swarm.SectionSize, swarm.Branches)
......@@ -60,6 +63,7 @@ func NewSimpleJoinerJob(ctx context.Context, getter storage.Getter, rootChunk sw
dataC: make(chan []byte),
doneC: make(chan struct{}),
logger: logging.New(ioutil.Discard, 0),
toDecrypt: toDecrypt,
}
// startLevelIndex is the root chunk level
......@@ -87,7 +91,6 @@ func NewSimpleJoinerJob(ctx context.Context, getter storage.Getter, rootChunk sw
// start processes all chunk references of the root chunk, which has already been retrieved.
func (j *SimpleJoinerJob) start(level int) error {
// consume the reference at the current cursor position of the chunk level data
// and start recursive retrieval down to the underlying data chunks
for j.cursors[level] < len(j.data[level]) {
......@@ -104,8 +107,15 @@ func (j *SimpleJoinerJob) start(level int) error {
func (j *SimpleJoinerJob) nextReference(level int) error {
data := j.data[level]
cursor := j.cursors[level]
var encryptionKey encryption.Key
chunkAddress := swarm.NewAddress(data[cursor : cursor+swarm.SectionSize])
err := j.nextChunk(level-1, chunkAddress)
if j.toDecrypt {
encryptionKey = make([]byte, encryption.KeyLength)
copy(encryptionKey, data[cursor+swarm.SectionSize:cursor+swarm.SectionSize+encryption.KeyLength])
}
err := j.nextChunk(level-1, chunkAddress, encryptionKey)
if err != nil {
if err == io.EOF {
return err
......@@ -124,6 +134,9 @@ func (j *SimpleJoinerJob) nextReference(level int) error {
// move the cursor to the next reference
j.cursors[level] += swarm.SectionSize
if j.toDecrypt {
j.cursors[level] += encryption.KeyLength
}
return nil
}
......@@ -132,22 +145,33 @@ func (j *SimpleJoinerJob) nextReference(level int) error {
// the current chunk is an intermediate chunk.
// When a data chunk is found it is passed on the dataC channel to be consumed by the
// io.Reader consumer.
func (j *SimpleJoinerJob) nextChunk(level int, address swarm.Address) error {
func (j *SimpleJoinerJob) nextChunk(level int, address swarm.Address, key encryption.Key) error {
// attempt to retrieve the chunk
ch, err := j.getter.Get(j.ctx, storage.ModeGetRequest, address)
if err != nil {
return err
}
var chunkData []byte
if j.toDecrypt {
decryptedData, err := DecryptChunkData(ch.Data(), key)
if err != nil {
return fmt.Errorf("error decrypting chunk %v: %v", address, err)
}
chunkData = decryptedData[8:]
} else {
chunkData = ch.Data()[8:]
}
j.cursors[level] = 0
j.data[level] = ch.Data()[8:]
j.data[level] = chunkData
// any level higher than 0 means the chunk contains references
// which must be recursively processed
if level > 0 {
for j.cursors[level] < len(j.data[level]) {
if len(j.data[level]) == j.cursors[level] {
j.data[level] = ch.Data()[8:]
j.data[level] = chunkData
j.cursors[level] = 0
}
err = j.nextReference(level)
......@@ -159,7 +183,7 @@ func (j *SimpleJoinerJob) nextChunk(level int, address swarm.Address) error {
// read data and pass to reader only if session is still active
// * context cancelled when client has disappeared, timeout etc
// * doneC receive when gracefully terminated through Close
data := ch.Data()[8:]
data := chunkData
err = j.sendChunkToReader(data)
}
return err
......@@ -213,3 +237,50 @@ func (j *SimpleJoinerJob) closeDone() {
close(j.doneC)
})
}
func DecryptChunkData(chunkData []byte, encryptionKey encryption.Key) ([]byte, error) {
if len(chunkData) < 8 {
return nil, fmt.Errorf("invalid ChunkData, min length 8 got %v", len(chunkData))
}
decryptedSpan, decryptedData, err := decrypt(chunkData, encryptionKey)
if err != nil {
return nil, err
}
// removing extra bytes which were just added for padding
length := binary.LittleEndian.Uint64(decryptedSpan)
refSize := int64(swarm.HashSize + encryption.KeyLength)
for length > swarm.ChunkSize {
length = length + (swarm.ChunkSize - 1)
length = length / swarm.ChunkSize
length *= uint64(refSize)
}
c := make([]byte, length+8)
copy(c[:8], decryptedSpan)
copy(c[8:], decryptedData[:length])
return c, nil
}
func decrypt(chunkData []byte, key encryption.Key) ([]byte, []byte, error) {
encryptedSpan, err := newSpanEncryption(key).Encrypt(chunkData[:8])
if err != nil {
return nil, nil, err
}
encryptedData, err := newDataEncryption(key).Encrypt(chunkData[8:])
if err != nil {
return nil, nil, err
}
return encryptedSpan, encryptedData, nil
}
func newSpanEncryption(key encryption.Key) *encryption.Encryption {
refSize := int64(swarm.HashSize + encryption.KeyLength)
return encryption.New(key, 0, uint32(swarm.ChunkSize/refSize), sha3.NewLegacyKeccak256)
}
func newDataEncryption(key encryption.Key) *encryption.Encryption {
return encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256)
}
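Two details worth noting here. First, decrypt calls Encrypt on the ciphertext: the cipher is a keystream XOR, so applying it twice with the same key and counter is the identity. Second, the loop in DecryptChunkData converts a subtree span back into the number of reference bytes actually held by an intermediate chunk. A worked sketch of that arithmetic with the constants used here (ChunkSize 4096, refSize 64):

```go
package main

import "fmt"

func main() {
	const chunkSize = 4096 // swarm.ChunkSize
	const refSize = 64     // swarm.HashSize + encryption.KeyLength

	// span of a two-level tree root, e.g. 10000 bytes of file data
	length := uint64(10000)
	for length > chunkSize {
		length = (length + chunkSize - 1) / chunkSize // ceil: number of child chunks
		length *= refSize                             // bytes of references in this chunk
	}
	fmt.Println(length) // 192: three 64-byte references in the root chunk
}
```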
......@@ -48,7 +48,7 @@ func TestSimpleJoinerJobBlocksize(t *testing.T) {
}
// this buffer is too small
j := internal.NewSimpleJoinerJob(ctx, store, rootChunk)
j := internal.NewSimpleJoinerJob(ctx, store, rootChunk, false)
b := make([]byte, swarm.SectionSize)
_, err = j.Read(b)
if err == nil {
......@@ -99,7 +99,7 @@ func TestSimpleJoinerJobOneLevel(t *testing.T) {
t.Fatal(err)
}
j := internal.NewSimpleJoinerJob(ctx, store, rootChunk)
j := internal.NewSimpleJoinerJob(ctx, store, rootChunk, false)
// verify first chunk content
outBuffer := make([]byte, 4096)
......@@ -188,7 +188,7 @@ func TestSimpleJoinerJobTwoLevelsAcrossChunk(t *testing.T) {
t.Fatal(err)
}
j := internal.NewSimpleJoinerJob(ctx, store, rootChunk)
j := internal.NewSimpleJoinerJob(ctx, store, rootChunk, false)
// read back all the chunks and verify
b := make([]byte, swarm.ChunkSize)
......
......@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner/internal"
"github.com/ethersphere/bee/pkg/storage"
......@@ -49,21 +50,45 @@ func (s *simpleJoiner) Size(ctx context.Context, address swarm.Address) (dataSiz
//
// It uses a non-optimized internal component that only retrieves a data chunk
// after the previous has been read.
func (s *simpleJoiner) Join(ctx context.Context, address swarm.Address) (dataOut io.ReadCloser, dataSize int64, err error) {
func (s *simpleJoiner) Join(ctx context.Context, address swarm.Address, toDecrypt bool) (dataOut io.ReadCloser, dataSize int64, err error) {
var addr []byte
var key encryption.Key
if toDecrypt {
addr = address.Bytes()[:swarm.HashSize]
key = address.Bytes()[swarm.HashSize : swarm.HashSize+encryption.KeyLength]
} else {
addr = address.Bytes()
}
// retrieve the root chunk to read the total data length to be retrieved
rootChunk, err := s.getter.Get(ctx, storage.ModeGetRequest, address)
rootChunk, err := s.getter.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(addr))
if err != nil {
return nil, 0, err
}
var chunkData []byte
if toDecrypt {
originalData, err := internal.DecryptChunkData(rootChunk.Data(), key)
if err != nil {
return nil, 0, err
}
chunkData = originalData
} else {
chunkData = rootChunk.Data()
}
// if this is a single chunk, short circuit to returning just that chunk
spanLength := binary.LittleEndian.Uint64(rootChunk.Data())
spanLength := binary.LittleEndian.Uint64(chunkData[:8])
chunkToSend := rootChunk
if spanLength <= swarm.ChunkSize {
data := rootChunk.Data()[8:]
data := chunkData[8:]
return file.NewSimpleReadCloser(data), int64(spanLength), nil
}
r := internal.NewSimpleJoinerJob(ctx, s.getter, rootChunk)
if toDecrypt {
chunkToSend = swarm.NewChunk(swarm.NewAddress(addr), chunkData)
}
r := internal.NewSimpleJoinerJob(ctx, s.getter, chunkToSend, toDecrypt)
return r, int64(spanLength), nil
}
......@@ -8,14 +8,18 @@ import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"testing"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter"
filetest "github.com/ethersphere/bee/pkg/file/testing"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
"gitlab.com/nolash/go-mockbytes"
)
// TestJoiner verifies that a newly created joiner returns the data stored
......@@ -29,7 +33,7 @@ func TestJoinerSingleChunk(t *testing.T) {
defer cancel()
var err error
_, _, err = joiner.Join(ctx, swarm.ZeroAddress)
_, _, err = joiner.Join(ctx, swarm.ZeroAddress, false)
if err != storage.ErrNotFound {
t.Fatalf("expected ErrNotFound for %x", swarm.ZeroAddress)
}
......@@ -47,7 +51,7 @@ func TestJoinerSingleChunk(t *testing.T) {
}
// read back data and compare
joinReader, l, err := joiner.Join(ctx, mockAddr)
joinReader, l, err := joiner.Join(ctx, mockAddr, false)
if err != nil {
t.Fatal(err)
}
......@@ -94,7 +98,7 @@ func TestJoinerWithReference(t *testing.T) {
}
// read back data and compare
joinReader, l, err := joiner.Join(ctx, rootChunk.Address())
joinReader, l, err := joiner.Join(ctx, rootChunk.Address(), false)
if err != nil {
t.Fatal(err)
}
......@@ -114,3 +118,63 @@ func TestJoinerWithReference(t *testing.T) {
t.Fatalf("expected resultbuffer %v, got %v", resultBuffer, firstChunk.Data()[:len(resultBuffer)])
}
}
func TestEncryptionAndDecryption(t *testing.T) {
var tests = []struct {
chunkLength int
}{
{10},
{100},
{1000},
{4095},
{4096},
{4097},
{15000},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Encrypt %d bytes", tt.chunkLength), func(t *testing.T) {
store := mock.NewStorer()
j := joiner.NewSimpleJoiner(store)
g := mockbytes.New(0, mockbytes.MockTypeStandard).WithModulus(255)
testData, err := g.SequentialBytes(tt.chunkLength)
if err != nil {
t.Fatal(err)
}
s := splitter.NewSimpleSplitter(store)
testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), true)
if err != nil {
t.Fatal(err)
}
reader, l, err := j.Join(context.Background(), resultAddress, true)
if err != nil {
t.Fatal(err)
}
if l != int64(len(testData)) {
t.Fatalf("expected join data length %d, got %d", len(testData), l)
}
totalGot := make([]byte, tt.chunkLength)
index := 0
resultBuffer := make([]byte, swarm.ChunkSize)
for index < tt.chunkLength {
n, err := reader.Read(resultBuffer)
if err != nil && err != io.EOF {
t.Fatal(err)
}
copy(totalGot[index:], resultBuffer[:n])
index += n
}
if !bytes.Equal(testData, totalGot) {
t.Fatal("input data and output data does not match")
}
})
}
}
......@@ -11,6 +11,7 @@ import (
"fmt"
"hash"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
......@@ -46,12 +47,19 @@ type SimpleSplitterJob struct {
cursors []int // section write position, indexed per level
hasher bmt.Hash // underlying hasher used for hashing the tree
buffer []byte // keeps data and hashes, indexed by cursors
toEncrypt bool // to encrypt the chunks or not
refSize int64
}
// NewSimpleSplitterJob creates a new SimpleSplitterJob.
//
// The spanLength is the length of the data that will be written.
func NewSimpleSplitterJob(ctx context.Context, putter storage.Putter, spanLength int64) *SimpleSplitterJob {
func NewSimpleSplitterJob(ctx context.Context, putter storage.Putter, spanLength int64, toEncrypt bool) *SimpleSplitterJob {
hashSize := swarm.HashSize
refSize := int64(hashSize)
if toEncrypt {
refSize += encryption.KeyLength
}
p := bmtlegacy.NewTreePool(hashFunc, swarm.Branches, bmtlegacy.PoolSize)
return &SimpleSplitterJob{
ctx: ctx,
......@@ -61,6 +69,8 @@ func NewSimpleSplitterJob(ctx context.Context, putter storage.Putter, spanLength
cursors: make([]int, levelBufferLimit),
hasher: bmtlegacy.New(p),
buffer: make([]byte, file.ChunkWithLengthSize*levelBufferLimit*2), // double size as temp workaround for weak calculation of needed buffer space
toEncrypt: toEncrypt,
refSize: refSize,
}
}
......@@ -126,34 +136,48 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
s.sumCounts[lvl]++
spanSize := file.Spans[lvl] * swarm.ChunkSize
span := (s.length-1)%spanSize + 1
sizeToSum := s.cursors[lvl] - s.cursors[lvl+1]
// perform hashing
// perform hashing
s.hasher.Reset()
err := s.hasher.SetSpan(span)
if err != nil {
return nil, err
}
_, err = s.hasher.Write(s.buffer[s.cursors[lvl+1] : s.cursors[lvl+1]+sizeToSum])
var ref encryption.Key
var chunkData []byte
data := s.buffer[s.cursors[lvl+1] : s.cursors[lvl+1]+sizeToSum]
_, err = s.hasher.Write(data)
if err != nil {
return nil, err
}
ref := s.hasher.Sum(nil)
// assemble chunk and put in store
addr := swarm.NewAddress(ref)
ref = s.hasher.Sum(nil)
head := make([]byte, 8)
binary.LittleEndian.PutUint64(head, uint64(span))
tail := s.buffer[s.cursors[lvl+1]:s.cursors[lvl]]
chunkData := append(head, tail...)
ch := swarm.NewChunk(addr, chunkData)
chunkData = append(head, tail...)
// assemble chunk and put in store
addr := swarm.NewAddress(ref)
c := chunkData
var encryptionKey encryption.Key
if s.toEncrypt {
c, encryptionKey, err = s.encryptChunkData(chunkData)
if err != nil {
return nil, err
}
}
ch := swarm.NewChunk(addr, c)
_, err = s.putter.Put(s.ctx, storage.ModePutUpload, ch)
if err != nil {
return nil, err
}
return ref, nil
return append(ch.Address().Bytes(), encryptionKey...), nil
}
// digest returns the calculated digest after a Sum call.
......@@ -164,7 +188,11 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
// The method does not check that the final hash actually has been written, so
// timing is the responsibility of the caller.
func (s *SimpleSplitterJob) digest() []byte {
return s.buffer[:swarm.SectionSize]
if s.toEncrypt {
return s.buffer[:swarm.SectionSize*2]
} else {
return s.buffer[:swarm.SectionSize]
}
}
// hashUnfinished hasher the remaining unhashed chunks at the end of each level if
......@@ -229,3 +257,39 @@ func (s *SimpleSplitterJob) moveDanglingChunk() error {
}
return nil
}
func (s *SimpleSplitterJob) encryptChunkData(chunkData []byte) ([]byte, encryption.Key, error) {
if len(chunkData) < 8 {
return nil, nil, fmt.Errorf("invalid data, min length 8 got %v", len(chunkData))
}
key, encryptedSpan, encryptedData, err := s.encrypt(chunkData)
if err != nil {
return nil, nil, err
}
c := make([]byte, len(encryptedSpan)+len(encryptedData))
copy(c[:8], encryptedSpan)
copy(c[8:], encryptedData)
return c, key, nil
}
func (s *SimpleSplitterJob) encrypt(chunkData []byte) (encryption.Key, []byte, []byte, error) {
key := encryption.GenerateRandomKey(encryption.KeyLength)
encryptedSpan, err := s.newSpanEncryption(key).Encrypt(chunkData[:8])
if err != nil {
return nil, nil, nil, err
}
encryptedData, err := s.newDataEncryption(key).Encrypt(chunkData[8:])
if err != nil {
return nil, nil, nil, err
}
return key, encryptedSpan, encryptedData, nil
}
func (s *SimpleSplitterJob) newSpanEncryption(key encryption.Key) *encryption.Encryption {
return encryption.New(key, 0, uint32(swarm.ChunkSize/s.refSize), sha3.NewLegacyKeccak256)
}
func (s *SimpleSplitterJob) newDataEncryption(key encryption.Key) *encryption.Encryption {
return encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256)
}
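On the split side, each reference written into a parent chunk grows from 32 to 64 bytes, because sumLevel now returns the chunk address with the per-chunk key appended; this is also why digest returns SectionSize*2 when encrypting. A sketch of composing and splitting such a reference (both helper names are hypothetical, not part of this diff):

```go
package main

import (
	"fmt"

	"github.com/ethersphere/bee/pkg/encryption"
	"github.com/ethersphere/bee/pkg/swarm"
)

// composeRef mimics what sumLevel returns when encrypting:
// the chunk address followed by its random per-chunk key.
func composeRef(addr swarm.Address, key encryption.Key) []byte {
	return append(addr.Bytes(), key...)
}

// splitRef is the inverse, as performed by simpleJoiner.Join on the root
// reference and by nextReference on intermediate-chunk payloads.
func splitRef(ref []byte) (swarm.Address, encryption.Key) {
	return swarm.NewAddress(ref[:swarm.HashSize]),
		encryption.Key(ref[swarm.HashSize : swarm.HashSize+encryption.KeyLength])
}

func main() {
	addr := swarm.MustParseHexAddress("f2e761160deda91c1fbfab065a5abf530b0766b3e102b51fbd626ba37c3bc581")
	key := encryption.GenerateRandomKey(encryption.KeyLength)

	ref := composeRef(addr, key)
	fmt.Println(len(ref)) // 64

	a, k := splitRef(ref)
	fmt.Println(a.String() == addr.String(), len(k)) // true 32
}
```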
......@@ -31,7 +31,7 @@ func TestSplitterJobPartialSingleChunk(t *testing.T) {
defer cancel()
data := []byte("foo")
j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data)))
j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data)), false)
c, err := j.Write(data)
if err != nil {
......@@ -74,7 +74,7 @@ func testSplitterJobVector(t *testing.T) {
data, expect := test.GetVector(t, int(dataIdx))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data)))
j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data)), false)
for i := 0; i < len(data); i += swarm.ChunkSize {
l := swarm.ChunkSize
......
......@@ -34,9 +34,8 @@ func NewSimpleSplitter(putter storage.Putter) file.Splitter {
// multiple levels of hashing when building the file hash tree.
//
// It returns the Swarmhash of the data.
func (s *simpleSplitter) Split(ctx context.Context, r io.ReadCloser, dataLength int64) (addr swarm.Address, err error) {
j := internal.NewSimpleSplitterJob(ctx, s.putter, dataLength)
func (s *simpleSplitter) Split(ctx context.Context, r io.ReadCloser, dataLength int64, toEncrypt bool) (addr swarm.Address, err error) {
j := internal.NewSimpleSplitterJob(ctx, s.putter, dataLength, toEncrypt)
var total int64
data := make([]byte, swarm.ChunkSize)
var eof bool
......@@ -49,6 +48,7 @@ func (s *simpleSplitter) Split(ctx context.Context, r io.ReadCloser, dataLength
return swarm.ZeroAddress, fmt.Errorf("splitter only received %d bytes of data, expected %d bytes", total+int64(c), dataLength)
}
eof = true
continue
} else {
return swarm.ZeroAddress, err
}
......
......@@ -26,7 +26,7 @@ func TestSplitIncomplete(t *testing.T) {
s := splitter.NewSimpleSplitter(store)
testDataReader := file.NewSimpleReadCloser(testData)
_, err := s.Split(context.Background(), testDataReader, 41)
_, err := s.Split(context.Background(), testDataReader, 41, false)
if err == nil {
t.Fatalf("expected error on EOF before full length write")
}
......@@ -45,7 +45,7 @@ func TestSplitSingleChunk(t *testing.T) {
s := splitter.NewSimpleSplitter(store)
testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)))
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), false)
if err != nil {
t.Fatal(err)
}
......@@ -68,7 +68,7 @@ func TestSplitSingleChunk(t *testing.T) {
func TestSplitThreeLevels(t *testing.T) {
// edge case selected from internal/job_test.go
g := mockbytes.New(0, mockbytes.MockTypeStandard).WithModulus(255)
testData, err := g.SequentialBytes(swarm.ChunkSize * swarm.Branches)
testData, err := g.SequentialBytes(swarm.ChunkSize * 128)
if err != nil {
t.Fatal(err)
}
......@@ -77,7 +77,7 @@ func TestSplitThreeLevels(t *testing.T) {
s := splitter.NewSimpleSplitter(store)
testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)))
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), false)
if err != nil {
t.Fatal(err)
}
......@@ -136,7 +136,7 @@ func TestUnalignedSplit(t *testing.T) {
doneC := make(chan swarm.Address)
errC := make(chan error)
go func() {
addr, err := sp.Split(ctx, chunkPipe, dataLen)
addr, err := sp.Split(ctx, chunkPipe, dataLen, false)
if err != nil {
errC <- err
} else {
......@@ -180,5 +180,4 @@ func TestUnalignedSplit(t *testing.T) {
case <-timer.C:
t.Fatal("timeout")
}
}