Commit d39c5638 authored by Zahoor Mohamed, committed by GitHub

Porting Swarm encryption in Bee (#320)

Encryption Support for Bee
parent 0bc0be56
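The commit threads an encryption flag through the splitter/joiner pipeline and exposes it over the HTTP API via a swarm-encrypt header; an encrypted upload returns a 64-byte reference (32-byte root hash followed by the 32-byte decryption key) instead of the usual 32-byte hash. A minimal client-side sketch of the new behavior follows; the node URL and the /files endpoint path are assumptions for illustration, not part of this diff.

package main

import (
    "fmt"
    "io/ioutil"
    "net/http"
    "strings"
)

func main() {
    // Hypothetical local node; adjust host/port and path for a real deployment.
    req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/files", strings.NewReader("hello swarm"))
    if err != nil {
        panic(err)
    }
    // The header introduced by this commit; any case-insensitive "true" enables encryption.
    req.Header.Set("swarm-encrypt", "true")
    req.Header.Set("Content-Type", "text/plain")

    res, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer res.Body.Close()

    // For an encrypted upload the returned reference is 64 bytes, hex-encoded:
    // root hash plus decryption key, which the download handler detects by length.
    body, _ := ioutil.ReadAll(res.Body)
    fmt.Println(string(body))
}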
@@ -59,7 +59,7 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
 	writeCloser := cmdfile.NopWriteCloser(buf)
 	limitBuf := cmdfile.NewLimitWriteCloser(writeCloser, limitMetadataLength)
 	j := joiner.NewSimpleJoiner(store)
-	_, err = file.JoinReadAll(j, addr, limitBuf)
+	_, err = file.JoinReadAll(j, addr, limitBuf, false)
 	if err != nil {
 		return err
 	}
@@ -70,7 +70,7 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
 	}
 	buf = bytes.NewBuffer(nil)
-	_, err = file.JoinReadAll(j, e.Metadata(), buf)
+	_, err = file.JoinReadAll(j, e.Metadata(), buf, false)
 	if err != nil {
 		return err
 	}
@@ -116,7 +116,7 @@ func getEntry(cmd *cobra.Command, args []string) (err error) {
 		return err
 	}
 	defer outFile.Close()
-	_, err = file.JoinReadAll(j, e.Reference(), outFile)
+	_, err = file.JoinReadAll(j, e.Reference(), outFile, false)
 	return err
 }
@@ -167,7 +167,7 @@ func putEntry(cmd *cobra.Command, args []string) (err error) {
 	metadataBuf := bytes.NewBuffer(metadataBytes)
 	metadataReader := io.LimitReader(metadataBuf, int64(len(metadataBytes)))
 	metadataReadCloser := ioutil.NopCloser(metadataReader)
-	metadataAddr, err := s.Split(ctx, metadataReadCloser, int64(len(metadataBytes)))
+	metadataAddr, err := s.Split(ctx, metadataReadCloser, int64(len(metadataBytes)), false)
 	if err != nil {
 		return err
 	}
@@ -182,7 +182,7 @@ func putEntry(cmd *cobra.Command, args []string) (err error) {
 	fileEntryBuf := bytes.NewBuffer(fileEntryBytes)
 	fileEntryReader := io.LimitReader(fileEntryBuf, int64(len(fileEntryBytes)))
 	fileEntryReadCloser := ioutil.NopCloser(fileEntryReader)
-	fileEntryAddr, err := s.Split(ctx, fileEntryReadCloser, int64(len(fileEntryBytes)))
+	fileEntryAddr, err := s.Split(ctx, fileEntryReadCloser, int64(len(fileEntryBytes)), false)
 	if err != nil {
 		return err
 	}
...
@@ -83,7 +83,7 @@ func Join(cmd *cobra.Command, args []string) (err error) {
 	// create the join and get its data reader
 	j := joiner.NewSimpleJoiner(store)
-	_, err = file.JoinReadAll(j, addr, outFile)
+	_, err = file.JoinReadAll(j, addr, outFile, false)
 	return err
 }
...
@@ -96,7 +96,7 @@ func Split(cmd *cobra.Command, args []string) (err error) {
 	s := splitter.NewSimpleSplitter(stores)
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	addr, err := s.Split(ctx, infile, inputLength)
+	addr, err := s.Split(ctx, infile, inputLength, false)
 	if err != nil {
 		return err
 	}
...
 cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
-cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo=
 cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
@@ -919,7 +918,6 @@ golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73r
 golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20181220203305-927f97764cc3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
-golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190227160552-c95aed5357e7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20190228165749-92fc7df08ae7/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -1145,4 +1143,4 @@ rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8
 rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
 rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
 sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck=
-sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=
\ No newline at end of file
+sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0=
@@ -10,7 +10,9 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"

+	"github.com/ethersphere/bee/pkg/encryption"
 	"github.com/ethersphere/bee/pkg/file"
 	"github.com/ethersphere/bee/pkg/file/joiner"
 	"github.com/ethersphere/bee/pkg/file/splitter"
@@ -27,8 +29,10 @@ type bytesPostResponse struct {
 // bytesUploadHandler handles upload of raw binary data of arbitrary length.
 func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) {
 	ctx := r.Context()
+	toEncrypt := strings.ToLower(r.Header.Get(EncryptHeader)) == "true"
 	sp := splitter.NewSimpleSplitter(s.Storer)
-	address, err := file.SplitWriteAll(ctx, sp, r.Body, r.ContentLength)
+	address, err := file.SplitWriteAll(ctx, sp, r.Body, r.ContentLength, toEncrypt)
 	if err != nil {
 		s.Logger.Debugf("bytes upload: %v", err)
 		jsonhttp.InternalServerError(w, nil)
@@ -52,8 +56,8 @@ func (s *server) bytesGetHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}

+	toDecrypt := len(address.Bytes()) == (swarm.HashSize + encryption.KeyLength)
 	j := joiner.NewSimpleJoiner(s.Storer)
 	dataSize, err := j.Size(ctx, address)
 	if err != nil {
 		if errors.Is(err, storage.ErrNotFound) {
@@ -69,7 +73,7 @@ func (s *server) bytesGetHandler(w http.ResponseWriter, r *http.Request) {
 	}
 	outBuffer := bytes.NewBuffer(nil)
-	c, err := file.JoinReadAll(j, address, outBuffer)
+	c, err := file.JoinReadAll(j, address, outBuffer, toDecrypt)
 	if err != nil && c == 0 {
 		s.Logger.Debugf("bytes download: data join %s: %v", address, err)
 		s.Logger.Errorf("bytes download: data join %s", address)
...
@@ -17,8 +17,10 @@ import (
 	"net/http"
 	"os"
 	"strconv"
+	"strings"

 	"github.com/ethersphere/bee/pkg/collection/entry"
+	"github.com/ethersphere/bee/pkg/encryption"
 	"github.com/ethersphere/bee/pkg/file"
 	"github.com/ethersphere/bee/pkg/file/joiner"
 	"github.com/ethersphere/bee/pkg/file/splitter"
@@ -28,7 +30,10 @@ import (
 	"github.com/gorilla/mux"
 )

-const multipartFormDataMediaType = "multipart/form-data"
+const (
+	multiPartFormData = "multipart/form-data"
+	EncryptHeader     = "swarm-encrypt"
+)

 type fileUploadResponse struct {
 	Reference swarm.Address `json:"reference"`
@@ -38,6 +43,7 @@ type fileUploadResponse struct {
 // - multipart http message
 // - other content types as complete file body
 func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
+	toEncrypt := strings.ToLower(r.Header.Get(EncryptHeader)) == "true"
 	contentType := r.Header.Get("Content-Type")
 	mediaType, params, err := mime.ParseMediaType(contentType)
 	if err != nil {
@@ -52,7 +58,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
 	var fileName, contentLength string
 	var fileSize uint64
-	if mediaType == multipartFormDataMediaType {
+	if mediaType == multiPartFormData {
 		mr := multipart.NewReader(r.Body, params["boundary"])
 		// read only the first part, as only one file upload is supported
@@ -133,7 +139,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
 	// first store the file and get its reference
 	sp := splitter.NewSimpleSplitter(s.Storer)
-	fr, err := file.SplitWriteAll(ctx, sp, reader, int64(fileSize))
+	fr, err := file.SplitWriteAll(ctx, sp, reader, int64(fileSize), toEncrypt)
 	if err != nil {
 		s.Logger.Debugf("file upload: file store, file %q: %v", fileName, err)
 		s.Logger.Errorf("file upload: file store, file %q", fileName)
@@ -157,7 +163,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	sp = splitter.NewSimpleSplitter(s.Storer)
-	mr, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(metadataBytes), int64(len(metadataBytes)))
+	mr, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(metadataBytes), int64(len(metadataBytes)), toEncrypt)
 	if err != nil {
 		s.Logger.Debugf("file upload: metadata store, file %q: %v", fileName, err)
 		s.Logger.Errorf("file upload: metadata store, file %q", fileName)
@@ -174,9 +180,8 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
 		jsonhttp.InternalServerError(w, "entry marshal error")
 		return
 	}

-	sp = splitter.NewSimpleSplitter(s.Storer)
-	reference, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)))
+	reference, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)), toEncrypt)
 	if err != nil {
 		s.Logger.Debugf("file upload: entry store, file %q: %v", fileName, err)
 		s.Logger.Errorf("file upload: entry store, file %q", fileName)
@@ -200,10 +205,12 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
 		return
 	}

+	toDecrypt := len(address.Bytes()) == (swarm.HashSize + encryption.KeyLength)
 	// read entry.
 	j := joiner.NewSimpleJoiner(s.Storer)
 	buf := bytes.NewBuffer(nil)
-	_, err = file.JoinReadAll(j, address, buf)
+	_, err = file.JoinReadAll(j, address, buf, toDecrypt)
 	if err != nil {
 		s.Logger.Debugf("file download: read entry %s: %v", addr, err)
 		s.Logger.Errorf("file download: read entry %s", addr)
@@ -231,7 +238,7 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
 	// Read metadata.
 	buf = bytes.NewBuffer(nil)
-	_, err = file.JoinReadAll(j, e.Metadata(), buf)
+	_, err = file.JoinReadAll(j, e.Metadata(), buf, toDecrypt)
 	if err != nil {
 		s.Logger.Debugf("file download: read metadata %s: %v", addr, err)
 		s.Logger.Errorf("file download: read metadata %s", addr)
@@ -276,7 +283,7 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
 	}()
 	go func() {
-		_, err := file.JoinReadAll(j, e.Reference(), pw)
+		_, err := file.JoinReadAll(j, e.Reference(), pw, toDecrypt)
 		if err := pw.CloseWithError(err); err != nil {
 			s.Logger.Debugf("file download: data join close %s: %v", addr, err)
 			s.Logger.Errorf("file download: data join close %s", addr)
...
@@ -49,6 +49,31 @@ func TestFiles(t *testing.T) {
 		})
 	})

+	t.Run("encrypt-decrypt", func(t *testing.T) {
+		fileName := "my-pictures.jpeg"
+		rootHash := "f2e761160deda91c1fbfab065a5abf530b0766b3e102b51fbd626ba37c3bc581"
+
+		headers := make(http.Header)
+		headers.Add("EncryptHeader", "True")
+		headers.Add("Content-Type", "image/jpeg; charset=utf-8")
+		_ = jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, client, http.MethodPost, fileUploadResource+"?name="+fileName, bytes.NewReader(simpleData), http.StatusOK, api.FileUploadResponse{
+			Reference: swarm.MustParseHexAddress(rootHash),
+		}, headers)
+
+		rcvdHeader := jsonhttptest.ResponseDirectCheckBinaryResponse(t, client, http.MethodGet, fileDownloadResource(rootHash), nil, http.StatusOK, simpleData, nil)
+		cd := rcvdHeader.Get("Content-Disposition")
+		_, params, err := mime.ParseMediaType(cd)
+		if err != nil {
+			t.Fatal(err)
+		}
+		if params["filename"] != fileName {
+			t.Fatal("Invalid file name detected")
+		}
+		if rcvdHeader.Get("Content-Type") != "image/jpeg; charset=utf-8" {
+			t.Fatal("Invalid content type detected")
+		}
+	})
+
 	t.Run("check-content-type-detection", func(t *testing.T) {
 		fileName := "my-pictures.jpeg"
 		rootHash := "f2e761160deda91c1fbfab065a5abf530b0766b3e102b51fbd626ba37c3bc581"
...
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package encryption
import (
"crypto/rand"
"encoding/binary"
"fmt"
"hash"
"sync"
)
const KeyLength = 32
type Key []byte
type Encryptor interface {
Encrypt(data []byte) ([]byte, error)
Decrypt(data []byte) ([]byte, error)
Reset()
}
type Encryption struct {
key Key // the encryption key (hashSize bytes long)
keyLen int // length of the key = length of blockcipher block
padding int // encryption will pad the data up to this length if > 0
index int // counter index
initCtr uint32 // initial counter used for counter mode blockcipher
hashFunc func() hash.Hash // hasher constructor function
}
// New constructs a new encryptor/decryptor
func New(key Key, padding int, initCtr uint32, hashFunc func() hash.Hash) *Encryption {
return &Encryption{
key: key,
keyLen: len(key),
padding: padding,
initCtr: initCtr,
hashFunc: hashFunc,
}
}
// Encrypt encrypts the data and does padding if specified
func (e *Encryption) Encrypt(data []byte) ([]byte, error) {
length := len(data)
outLength := length
isFixedPadding := e.padding > 0
if isFixedPadding {
if length > e.padding {
return nil, fmt.Errorf("data length longer than padding, data length %v padding %v", length, e.padding)
}
outLength = e.padding
}
out := make([]byte, outLength)
err := e.transform(data, out)
if err != nil {
return nil, err
}
return out, nil
}
// Decrypt decrypts the data. If padding was used, the caller must know the original length and truncate the result accordingly.
func (e *Encryption) Decrypt(data []byte) ([]byte, error) {
length := len(data)
if e.padding > 0 && length != e.padding {
return nil, fmt.Errorf("data length different than padding, data length %v padding %v", length, e.padding)
}
out := make([]byte, length)
err := e.transform(data, out)
if err != nil {
return nil, err
}
return out, nil
}
// Reset resets the counter. It is only safe to call after an encryption operation is completed
// After Reset is called, the Encryption object can be re-used for other data
func (e *Encryption) Reset() {
e.index = 0
}
// split up input into keylength segments and encrypt sequentially
func (e *Encryption) transform(in, out []byte) error {
inLength := len(in)
wg := sync.WaitGroup{}
wg.Add((inLength-1)/e.keyLen + 1)
for i := 0; i < inLength; i += e.keyLen {
errs := make(chan error, 1)
l := min(e.keyLen, inLength-i)
go func(i int, x, y []byte) {
defer wg.Done()
err := e.Transcrypt(i, x, y)
errs <- err
}(e.index, in[i:i+l], out[i:i+l])
e.index++
err := <-errs
if err != nil {
close(errs)
return err
}
}
// pad the rest if out is longer
pad(out[inLength:])
wg.Wait()
return nil
}
// used for segmentwise transformation
// if in is shorter than out, padding is used
func (e *Encryption) Transcrypt(i int, in, out []byte) error {
// first hash key with counter (initial counter + i)
hasher := e.hashFunc()
_, err := hasher.Write(e.key)
if err != nil {
return err
}
ctrBytes := make([]byte, 4)
binary.LittleEndian.PutUint32(ctrBytes, uint32(i)+e.initCtr)
_, err = hasher.Write(ctrBytes)
if err != nil {
return err
}
ctrHash := hasher.Sum(nil)
hasher.Reset()
// second round of hashing for selective disclosure
_, err = hasher.Write(ctrHash)
if err != nil {
return err
}
segmentKey := hasher.Sum(nil)
hasher.Reset()
// XOR bytes up to the length of in (out must be at least as long)
inLength := len(in)
for j := 0; j < inLength; j++ {
out[j] = in[j] ^ segmentKey[j]
}
// insert padding if out is longer
pad(out[inLength:])
return nil
}
func pad(b []byte) {
l := len(b)
for total := 0; total < l; {
read, _ := rand.Read(b[total:])
total += read
}
}
// GenerateRandomKey generates a random key of length l
func GenerateRandomKey(l int) Key {
key := make([]byte, l)
var total int
for total < l {
read, _ := rand.Read(key[total:])
total += read
}
return key
}
func min(x, y int) int {
if x < y {
return x
}
return y
}
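Since the package above is self-contained, its exported API (GenerateRandomKey, New, Encrypt, Decrypt, Reset) can be exercised directly. A small round-trip sketch, assuming only the import path used elsewhere in this commit:

package main

import (
    "bytes"
    "fmt"

    "github.com/ethersphere/bee/pkg/encryption"
    "golang.org/x/crypto/sha3"
)

func main() {
    key := encryption.GenerateRandomKey(encryption.KeyLength)

    // Pad plaintexts up to 4096 bytes; counter starts at 0, Keccak-256 keystream.
    e := encryption.New(key, 4096, 0, sha3.NewLegacyKeccak256)

    plain := []byte("some chunk payload")
    cipher, err := e.Encrypt(plain) // returns 4096 bytes; the tail is random padding
    if err != nil {
        panic(err)
    }

    e.Reset() // rewind the counter before reusing the instance for decryption
    decrypted, err := e.Decrypt(cipher)
    if err != nil {
        panic(err)
    }

    // The caller must know the original length and truncate the padding.
    fmt.Println(bytes.Equal(plain, decrypted[:len(plain)])) // true
}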
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package encryption
import (
"bytes"
crand "crypto/rand"
"encoding/hex"
"math/rand"
"testing"
"github.com/ethersphere/bee/pkg/swarm"
"golang.org/x/crypto/sha3"
)
var expectedTransformedHex = "352187af3a843decc63ceca6cb01ea39dbcf77caf0a8f705f5c30d557044ceec9392b94a79376f1e5c10cd0c0f2a98e5353bf22b3ea4fdac6677ee553dec192e3db64e179d0474e96088fb4abd2babd67de123fb398bdf84d818f7bda2c1ab60b3ea0e0569ae54aa969658eb4844e6960d2ff44d7c087ee3aaffa1c0ee5df7e50b615f7ad90190f022934ad5300c7d1809bfe71a11cc04cece5274eb97a5f20350630522c1dbb7cebaf4f97f84e03f5cfd88f2b48880b25d12f4d5e75c150f704ef6b46c72e07db2b705ac3644569dccd22fd8f964f6ef787fda63c46759af334e6f665f70eac775a7017acea49f3c7696151cb1b9434fa4ac27fb803921ffb5ec58dafa168098d7d5b97e384be3384cf5bc235c3d887fef89fe76c0065f9b8d6ad837b442340d9e797b46ef5709ea3358bc415df11e4830de986ef0f1c418ffdcc80e9a3cda9bea0ab5676c0d4240465c43ba527e3b4ea50b4f6255b510e5d25774a75449b0bd71e56c537ade4fcf0f4d63c99ae1dbb5a844971e2c19941b8facfcfc8ee3056e7cb3c7114c5357e845b52f7103cb6e00d2308c37b12baa5b769e1cc7b00fc06f2d16e70cc27a82cb9c1a4e40cb0d43907f73df2c9db44f1b51a6b0bc6d09f77ac3be14041fae3f9df2da42df43ae110904f9ecee278030185254d7c6e918a5512024d047f77a992088cb3190a6587aa54d0c7231c1cd2e455e0d4c07f74bece68e29cd8ba0190c0bcfb26d24634af5d91a81ef5d4dd3d614836ce942ddbf7bb1399317f4c03faa675f325f18324bf9433844bfe5c4cc04130c8d5c329562b7cd66e72f7355de8f5375a72202971613c32bd7f3fcdcd51080758cd1d0a46dbe8f0374381dbc359f5864250c63dde8131cbd7c98ae2b0147d6ea4bf65d1443d511b18e6d608bbb46ac036353b4c51df306a10a6f6939c38629a5c18aaf89cac04bd3ad5156e6b92011c88341cb08551bab0a89e6a46538f5af33b86121dba17e3a434c273f385cd2e8cb90bdd32747d8425d929ccbd9b0815c73325988855549a8489dfd047daf777aaa3099e54cf997175a5d9e1edfe363e3b68c70e02f6bf4fcde6a0f3f7d0e7e98bde1a72ae8b6cd27b32990680cc4a04fc467f41c5adcaddabfc71928a3f6872c360c1d765260690dd28b269864c8e380d9c92ef6b89b0094c8f9bb22608b4156381b19b920e9583c9616ce5693b4d2a6c689f02e6a91584a8e501e107403d2689dd0045269dd9946c0e969fb656a3b39d84a798831f5f9290f163eb2f97d3ae25071324e95e2256d9c1e56eb83c26397855323edc202d56ad05894333b7f0ed3c1e4734782eb8bd5477242fd80d7a89b12866f85cfae476322f032465d6b1253993033fccd4723530630ab97a1566460af9c90c9da843c229406e65f3fa578bd6bf04dee9b6153807ddadb8ceefc5c601a8ab26023c67b1ab1e8e0f29ce94c78c308005a781853e7a2e0e51738939a657c987b5e611f32f47b5ff461c52e63e0ea390515a8e1f5393dae54ea526934b5f310b76e3fa050e40718cb4c8a20e58946d6ee1879f08c52764422fe542b3240e75eccb7aa75b1f8a651e37a3bc56b0932cdae0e985948468db1f98eb4b77b82081ea25d8a762db00f7898864984bd80e2f3f35f236bf57291dec28f550769943bcfb6f884b7687589b673642ef7fe5d7d5a87d3eca5017f83ccb9a3310520474479464cb3f433440e7e2f1e28c0aef700a45848573409e7ab66e0cfd4fe5d2147ace81bc65fd8891f6245cd69246bbf5c27830e5ab882dd1d02aba34ff6ca9af88df00fd602892f02fedbdc65dedec203faf3f8ff4a97314e0ddb58b9ab756a61a562597f4088b445fcc3b28a708ca7b1485dcd791b779fbf2b3ef1ec5c6205f595fbe45a02105034147e5a146089c200a49dae33ae051a08ea5f974a21540aaeffa7f9d9e3d35478016fb27b871036eb27217a5b834b461f535752fb5f1c8dded3ae14ce3a2ef6639e2fe41939e3509e46e347a95d50b2080f1ba42c804b290ddc912c952d1cec3f2661369f738feacc0dbf1ea27429c644e45f9e26f30c341acd34c7519b2a1663e334621691e810767e9918c2c547b2e23cce915f97d26aac8d0d2fcd3edb7986ad4e2b8a852edebad534cb6c0e9f0797d3563e5409d7e068e48356c67ce519246cd9c560e881453df97cbba562018811e6cf8c327f399d1d1253ab47a19f4a0ccc7c6d86a9603e0551da310ea595d71305c4aad96819120a92cdbaf1f77ec8df9cc7c838c0d4de1e8692dd81da38268d1d71324bcffdafbe5122e4b81828e021e936d83ae8021eac592aa52cd296b5ce392c7173d622f8e07d18f59bb1b08ba15211af6703463b09b593af3c37735296816d9f2e7a369354a5374ea3955e14ca8ac56d5bfe4aef7a21bd825d6ae85530bee5d2aaaa4914981b3dfdb2e92ec2a27c83d74b59e84ff5c056f7d8945745f2efc3dcf28f288c6cd8383" +
	"700fb2312f7001f24dd40015e436ae23e052fe9070ea9535b9c989898a9bda3d5382cf10e432fae6ccf0c825b3e6436edd3a9f8846e5606f8563931b5f29ba407c5236e5730225dda211a8504ec1817bc935e1fd9a532b648c502df302ed2063aed008fd5676131ac9e95998e9447b02bd29d77e38fcfd2959f2de929b31970335eb2a74348cc6918bc35b9bf749eab0fe304c946cd9e1ca284e6853c42646e60b6b39e0d3fb3c260abfc5c1b4ca3c3770f344118ca7c7f5c1ad1f123f8f369cd60afc3cdb3e9e81968c5c9fa7c8b014ffe0508dd4f0a2a976d5d1ca8fc9ad7a237d92cfe7b41413d934d6e142824b252699397e48e4bac4e91ebc10602720684bd0863773c548f9a2f9724245e47b129ecf65afd7252aac48c8a8d6fd3d888af592a01fb02dc71ed7538a700d3d16243e4621e0fcf9f8ed2b4e11c9fa9a95338bb1dac74a7d9bc4eb8cbf900b634a2a56469c00f5994e4f0934bdb947640e6d67e47d0b621aacd632bfd3c800bd7d93bd329f494a90e06ed51535831bd6e07ac1b4b11434ef3918fa9511813a002913f33f836454798b8d1787fea9a4c4743ba091ed192ed92f4d33e43a226bf9503e1a83a16dd340b3cbbf38af6db0d99201da8de529b4225f3d2fa2aad6621afc6c79ef3537720591edfc681ae6d00ede53ed724fc71b23b90d2e9b7158aaee98d626a4fe029107df2cb5f90147e07ebe423b1519d848af18af365c71bfd0665db46be493bbe99b79a188de0cf3594aef2299f0324075bdce9eb0b87bc29d62401ba4fd6ae48b1ba33261b5b845279becf38ee03e3dc5c45303321c5fac96fd02a3ad8c9e3b02127b320501333c9e6360440d1ad5e64a6239501502dde1a49c9abe33b66098458eee3d611bb06ffcd234a1b9aef4af5021cd61f0de6789f822ee116b5078aae8c129e8391d8987500d322b58edd1595dc570b57341f2df221b94a96ab7fbcf32a8ca9684196455694024623d7ed49f7d66e8dd453c0bae50e0d8b34377b22d0ece059e2c385dfc70b9089fcd27577c51f4d870b5738ee2b68c361a67809c105c7848b68860a829f29930857a9f9d40b14fd2384ac43bafdf43c0661103794c4bd07d1cfdd4681b6aeaefad53d4c1473359bcc5a83b09189352e5bb9a7498dd0effb89c35aad26954551f8b0621374b449bf515630bd3974dca982279733470fdd059aa9c3df403d8f22b38c4709c82d8f12b888e22990350490e16179caf406293cc9e65f116bafcbe96af132f679877061107a2f690a82a8cb46eea57a90abd23798c5937c6fe6b17be3f9bfa01ce117d2c268181b9095bf49f395fea07ca03838de0588c5e2db633e836d64488c1421e653ea52d810d096048c092d0da6e02fa6613890219f51a76148c8588c2487b171a28f17b7a299204874af0131725d793481333be5f08e86ca837a226850b0c1060891603bfecf9e55cddd22c0dbb28d495342d9cc3de8409f72e52a0115141cffe755c74f061c1a770428ccb0ae59536ee6fc074fbfc6cacb51a549d327527e20f8407477e60355863f1153f9ce95641198663c968874e7fdb29407bd771d94fdda8180cbb0358f5874738db705924b8cbe0cd5e1484aeb64542fe8f38667b7c34baf818c63b1e18440e9fba575254d063fd49f24ef26432f4eb323f3836972dca87473e3e9bb26dc3be236c3aae6bc8a6da567442309da0e8450e242fc9db836e2964f2c76a3b80a2c677979882dda7d7ebf62c93664018bcf4ec431fe6b403d49b3b36618b9c07c2d0d4569cb8d52223903debc72ec113955b206c34f1ae5300990ccfc0180f47d91afdb542b6312d12aeff7e19c645dc0b9fe6e3288e9539f6d5870f99882df187bfa6d24d179dfd1dac22212c8b5339f7171a3efc15b760fed8f68538bc5cbd845c2d1ab41f3a6c692820653eaef7930c02fbe6061d93805d73decdbb945572a7c44ed0241982a6e4d2d730898f82b3d9877cb7bca41cc6dcee67aa0c3d6db76f0b0a708ace0031113e48429de5d886c10e9200f68f32263a2fbf44a5992c2459fda7b8796ba796e3a0804fc25992ed2c9a5fe0580a6b809200ecde6caa0364b58be11564dcb9a616766dd7906db5636ee708b0204f38d309466d8d4a162965dd727e29f5a6c133e9b4ed5bafe803e479f9b2a7640c942c4a40b14ac7dc9828546052761a070f6404008f1ec3605836339c3da95a00b4fd81b2cabf88b51d2087d5b83e8c5b69bf96d8c72cbd278dad3bbb42b404b436f84ad688a22948adf60a81090f1e904291503c16e9f54b05fc76c881a5f95f0e732949e95d3f1bae2d3652a14fe0dda2d68879604657171856ef72637def2a96ac47d7b3fe86eb3198f5e0e626f06be86232305f2ae79ffcd2725e48208f9d8d63523f81915acc957563ab627cd6bc68c2a37d59fb0ed77a90aa9d085d6914a8ebada22a2c2d471b5163aeddd799d90fbb10ed6851ace2c4af504b7d572686700a59d6db46d5e42bb83f8e0c0" +
	"ffe1dfa6582cc0b34c921ff6e85e83188d24906d5c08bb90069639e713051b3102b53e6f703e8210017878add5df68e6f2b108de279c5490e9eef5590185c4a1c744d4e00d244e1245a8805bd30407b1bc488db44870ccfd75a8af104df78efa2fb7ba31f048a263efdb3b63271fff4922bece9a71187108f65744a24f4947dc556b7440cb4fa45d296bb7f724588d1f245125b21ea063500029bd49650237f53899daf1312809552c81c5827341263cc807a29fe84746170cdfa1ff3838399a5645319bcaff674bb70efccdd88b3d3bb2f2d98111413585dc5d5bd5168f43b3f55e58972a5b2b9b3733febf02f931bd436648cb617c3794841aab961fe41277ab07812e1d3bc4ff6f4350a3e615bfba08c3b9480ef57904d3a16f7e916345202e3f93d11f7a7305170cb8c4eb9ac88ace8bbd1f377bdd5855d3162d6723d4435e84ce529b8f276a8927915ac759a0d04e5ca4a9d3da6291f0333b475df527e99fe38f7a4082662e8125936640c26dd1d17cf284ce6e2b17777a05aa0574f7793a6a062cc6f7263f7ab126b4528a17becfdec49ac0f7d8705aa1704af97fb861faa8a466161b2b5c08a5bacc79fe8500b913d65c8d3c52d1fd52d2ab2c9f52196e712455619c1cd3e0f391b274487944240e2ed8858dd0823c801094310024ae3fe4dd1cf5a2b6487b42cc5937bbafb193ee331d87e378258963d49b9da90899bbb4b88e79f78e866b0213f4719f67da7bcc2fce073c01e87c62ea3cdbcd589cfc41281f2f4c757c742d6d1e"
var hashFunc = sha3.NewLegacyKeccak256
var testKey Key
func init() {
testKey = swarm.MustParseHexAddress("8abf1502f557f15026716030fb6384792583daf39608a3cd02ff2f47e9bc6e49").Bytes()
}
func TestEncryptDataLongerThanPadding(t *testing.T) {
enc := New(testKey, 4095, uint32(0), hashFunc)
data := make([]byte, 4096)
expectedError := "data length longer than padding, data length 4096 padding 4095"
_, err := enc.Encrypt(data)
if err == nil || err.Error() != expectedError {
t.Fatalf("Expected error \"%v\" got \"%v\"", expectedError, err)
}
}
func TestEncryptDataZeroPadding(t *testing.T) {
enc := New(testKey, 0, uint32(0), hashFunc)
data := make([]byte, 2048)
encrypted, err := enc.Encrypt(data)
if err != nil {
t.Fatalf("Expected no error got %v", err)
}
if len(encrypted) != 2048 {
t.Fatalf("Encrypted data length expected \"%v\" got %v", 2048, len(encrypted))
}
}
func TestEncryptDataLengthEqualsPadding(t *testing.T) {
enc := New(testKey, 4096, uint32(0), hashFunc)
data := make([]byte, 4096)
encrypted, err := enc.Encrypt(data)
if err != nil {
t.Fatalf("Expected no error got %v", err)
}
encryptedHex := hex.EncodeToString(encrypted)
expectedTransformed, _ := hex.DecodeString(expectedTransformedHex)
if !bytes.Equal(encrypted, expectedTransformed) {
t.Fatalf("Expected %v got %v", expectedTransformedHex, encryptedHex)
}
}
func TestEncryptDataLengthSmallerThanPadding(t *testing.T) {
enc := New(testKey, 4096, uint32(0), hashFunc)
data := make([]byte, 4080)
encrypted, err := enc.Encrypt(data)
if err != nil {
t.Fatalf("Expected no error got %v", err)
}
if len(encrypted) != 4096 {
t.Fatalf("Encrypted data length expected %v got %v", 4096, len(encrypted))
}
}
func TestEncryptDataCounterNonZero(t *testing.T) {
// TODO
}
func TestDecryptDataLengthNotEqualsPadding(t *testing.T) {
enc := New(testKey, 4096, uint32(0), hashFunc)
data := make([]byte, 4097)
expectedError := "data length different than padding, data length 4097 padding 4096"
_, err := enc.Decrypt(data)
if err == nil || err.Error() != expectedError {
t.Fatalf("Expected error \"%v\" got \"%v\"", expectedError, err)
}
}
func TestEncryptDecryptIsIdentity(t *testing.T) {
testEncryptDecryptIsIdentity(t, 0, 2048, 2048, 32)
testEncryptDecryptIsIdentity(t, 0, 4096, 4096, 32)
testEncryptDecryptIsIdentity(t, 0, 4096, 1000, 32)
testEncryptDecryptIsIdentity(t, 10, 32, 32, 32)
}
func testEncryptDecryptIsIdentity(t *testing.T, initCtr uint32, padding, dataLength, keyLength int) {
key := GenerateRandomKey(keyLength)
enc := New(key, padding, initCtr, hashFunc)
data := RandomBytes(1, dataLength)
encrypted, err := enc.Encrypt(data)
if err != nil {
t.Fatalf("Expected no error got %v", err)
}
enc.Reset()
decrypted, err := enc.Decrypt(encrypted)
if err != nil {
t.Fatalf("Expected no error got %v", err)
}
if len(decrypted) != padding {
t.Fatalf("Expected decrypted data length %v got %v", padding, len(decrypted))
}
// we have to remove the extra bytes which were randomly added to fill until padding
if len(data) < padding {
decrypted = decrypted[:len(data)]
}
if !bytes.Equal(data, decrypted) {
t.Fatalf("Expected decrypted %v got %v", hex.EncodeToString(data), hex.EncodeToString(decrypted))
}
}
// TestEncryptSectioned tests that the cipherText is the same regardless of size of data input buffer
func TestEncryptSectioned(t *testing.T) {
data := make([]byte, 4096)
c, err := crand.Read(data)
if err != nil {
t.Fatal(err)
}
if c < 4096 {
t.Fatalf("short read %d", c)
}
key := make([]byte, KeyLength)
c, err = crand.Read(key)
if err != nil {
t.Fatal(err)
}
if c < KeyLength {
t.Fatalf("short read %d", c)
}
enc := New(key, 0, uint32(42), sha3.NewLegacyKeccak256)
whole, err := enc.Encrypt(data)
if err != nil {
t.Fatal(err)
}
enc.Reset()
for i := 0; i < 4096; i += KeyLength {
cipher, err := enc.Encrypt(data[i : i+KeyLength])
if err != nil {
t.Fatal(err)
}
wholeSection := whole[i : i+KeyLength]
if !bytes.Equal(cipher, wholeSection) {
t.Fatalf("index %d, expected %x, got %x", i/KeyLength, wholeSection, cipher)
}
}
}
// RandomBytes returns a pseudo-random but deterministic result,
// because test failures must be reproducible
func RandomBytes(seed, length int) []byte {
b := make([]byte, length)
reader := rand.New(rand.NewSource(int64(seed)))
for n := 0; n < length; {
read, err := reader.Read(b[n:])
if err != nil {
panic(err)
}
n += read
}
return b
}
@@ -24,7 +24,7 @@ var (
 // returning the length of the data which will be returned.
 // The caller can then read the data on the io.Reader that was provided.
 type Joiner interface {
-	Join(ctx context.Context, address swarm.Address) (dataOut io.ReadCloser, dataLength int64, err error)
+	Join(ctx context.Context, address swarm.Address, toDecrypt bool) (dataOut io.ReadCloser, dataLength int64, err error)
 	Size(ctx context.Context, address swarm.Address) (dataLength int64, err error)
 }
@@ -34,12 +34,12 @@ type Joiner interface {
 // If the dataLength parameter is 0, data is read until io.EOF is encountered.
 // When EOF is received and splitting is done, the resulting Swarm Address is returned.
 type Splitter interface {
-	Split(ctx context.Context, dataIn io.ReadCloser, dataLength int64) (addr swarm.Address, err error)
+	Split(ctx context.Context, dataIn io.ReadCloser, dataLength int64, toEncrypt bool) (addr swarm.Address, err error)
 }

 // JoinReadAll reads all output from the provided joiner.
-func JoinReadAll(j Joiner, addr swarm.Address, outFile io.Writer) (int64, error) {
-	r, l, err := j.Join(context.Background(), addr)
+func JoinReadAll(j Joiner, addr swarm.Address, outFile io.Writer, toDecrypt bool) (int64, error) {
+	r, l, err := j.Join(context.Background(), addr, toDecrypt)
 	if err != nil {
 		return 0, err
 	}
@@ -67,7 +67,7 @@ func JoinReadAll(j Joiner, addr swarm.Address, outFile io.Writer) (int64, error)
 }

 // SplitWriteAll writes all input from provided reader to the provided splitter
-func SplitWriteAll(ctx context.Context, s Splitter, r io.Reader, l int64) (swarm.Address, error) {
+func SplitWriteAll(ctx context.Context, s Splitter, r io.Reader, l int64, toEncrypt bool) (swarm.Address, error) {
 	chunkPipe := NewChunkPipe()
 	errC := make(chan error)
 	go func() {
@@ -86,7 +86,7 @@ func SplitWriteAll(ctx context.Context, s Splitter, r io.Reader, l int64) (swarm
 		close(errC)
 	}()
-	addr, err := s.Split(ctx, chunkPipe, l)
+	addr, err := s.Split(ctx, chunkPipe, l, toEncrypt)
 	if err != nil {
 		return swarm.ZeroAddress, err
 	}
...
@@ -53,13 +53,13 @@ func testSplitThenJoin(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	dataReader := file.NewSimpleReadCloser(data)
-	resultAddress, err := s.Split(ctx, dataReader, int64(len(data)))
+	resultAddress, err := s.Split(ctx, dataReader, int64(len(data)), false)
 	if err != nil {
 		t.Fatal(err)
 	}

 	// then join
-	r, l, err := j.Join(ctx, resultAddress)
+	r, l, err := j.Join(ctx, resultAddress, false)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -93,7 +93,7 @@ func TestJoinReadAll(t *testing.T) {
 	var dataLength int64 = swarm.ChunkSize + 2
 	j := newMockJoiner(dataLength)
 	buf := bytes.NewBuffer(nil)
-	c, err := file.JoinReadAll(j, swarm.ZeroAddress, buf)
+	c, err := file.JoinReadAll(j, swarm.ZeroAddress, buf, false)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -113,7 +113,7 @@ type mockJoiner struct {
 }

 // Join implements file.Joiner.
-func (j *mockJoiner) Join(ctx context.Context, address swarm.Address) (dataOut io.ReadCloser, dataLength int64, err error) {
+func (j *mockJoiner) Join(ctx context.Context, address swarm.Address, toDecrypt bool) (dataOut io.ReadCloser, dataLength int64, err error) {
 	data := make([]byte, j.l)
 	buf := bytes.NewBuffer(data)
 	readCloser := ioutil.NopCloser(buf)
...
@@ -13,10 +13,12 @@ import (
 	"io/ioutil"
 	"sync"

+	"github.com/ethersphere/bee/pkg/encryption"
 	"github.com/ethersphere/bee/pkg/file"
 	"github.com/ethersphere/bee/pkg/logging"
 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/swarm"
+	"golang.org/x/crypto/sha3"
 )

 // SimpleJoinerJob encapsulates a single joiner operation, providing the consumer
@@ -46,10 +48,11 @@ type SimpleJoinerJob struct {
 	closeDoneOnce sync.Once // make sure done channel is closed only once
 	err           error     // read by the main thread to capture error state of the job
 	logger        logging.Logger
+	toDecrypt     bool      // to decrypt the chunks or not
 }

 // NewSimpleJoinerJob creates a new simpleJoinerJob.
-func NewSimpleJoinerJob(ctx context.Context, getter storage.Getter, rootChunk swarm.Chunk) *SimpleJoinerJob {
+func NewSimpleJoinerJob(ctx context.Context, getter storage.Getter, rootChunk swarm.Chunk, toDecrypt bool) *SimpleJoinerJob {
 	spanLength := binary.LittleEndian.Uint64(rootChunk.Data()[:8])
 	levelCount := file.Levels(int64(spanLength), swarm.SectionSize, swarm.Branches)
@@ -60,6 +63,7 @@ func NewSimpleJoinerJob(ctx context.Context, getter storage.Getter, rootChunk sw
 		dataC:     make(chan []byte),
 		doneC:     make(chan struct{}),
 		logger:    logging.New(ioutil.Discard, 0),
+		toDecrypt: toDecrypt,
 	}

 	// startLevelIndex is the root chunk level
@@ -87,7 +91,6 @@ func NewSimpleJoinerJob(ctx context.Context, getter storage.Getter, rootChunk sw
 // start processes all chunk references of the root chunk that already has been retrieved.
 func (j *SimpleJoinerJob) start(level int) error {
-	// consume the reference at the current cursor position of the chunk level data
 	// and start recursive retrieval down to the underlying data chunks
 	for j.cursors[level] < len(j.data[level]) {
@@ -104,8 +107,15 @@ func (j *SimpleJoinerJob) start(level int) error {
 func (j *SimpleJoinerJob) nextReference(level int) error {
 	data := j.data[level]
 	cursor := j.cursors[level]
+	var encryptionKey encryption.Key
 	chunkAddress := swarm.NewAddress(data[cursor : cursor+swarm.SectionSize])
-	err := j.nextChunk(level-1, chunkAddress)
+	if j.toDecrypt {
+		encryptionKey = make([]byte, encryption.KeyLength)
+		copy(encryptionKey, data[cursor+swarm.SectionSize:cursor+swarm.SectionSize+encryption.KeyLength])
+	}
+	err := j.nextChunk(level-1, chunkAddress, encryptionKey)
 	if err != nil {
 		if err == io.EOF {
 			return err
@@ -124,6 +134,9 @@ func (j *SimpleJoinerJob) nextReference(level int) error {
 	// move the cursor to the next reference
 	j.cursors[level] += swarm.SectionSize
+	if j.toDecrypt {
+		j.cursors[level] += encryption.KeyLength
+	}
 	return nil
 }
@@ -132,22 +145,33 @@ func (j *SimpleJoinerJob) nextReference(level int) error {
 // the current chunk is an intermediate chunk.
 // When a data chunk is found it is passed on the dataC channel to be consumed by the
 // io.Reader consumer.
-func (j *SimpleJoinerJob) nextChunk(level int, address swarm.Address) error {
+func (j *SimpleJoinerJob) nextChunk(level int, address swarm.Address, key encryption.Key) error {
 	// attempt to retrieve the chunk
 	ch, err := j.getter.Get(j.ctx, storage.ModeGetRequest, address)
 	if err != nil {
 		return err
 	}

+	var chunkData []byte
+	if j.toDecrypt {
+		decryptedData, err := DecryptChunkData(ch.Data(), key)
+		if err != nil {
+			return fmt.Errorf("error decrypting chunk %v: %v", address, err)
+		}
+		chunkData = decryptedData[8:]
+	} else {
+		chunkData = ch.Data()[8:]
+	}
 	j.cursors[level] = 0
-	j.data[level] = ch.Data()[8:]
+	j.data[level] = chunkData

 	// any level higher than 0 means the chunk contains references
 	// which must be recursively processed
 	if level > 0 {
 		for j.cursors[level] < len(j.data[level]) {
 			if len(j.data[level]) == j.cursors[level] {
-				j.data[level] = ch.Data()[8:]
+				j.data[level] = chunkData
 				j.cursors[level] = 0
 			}
 			err = j.nextReference(level)
@@ -159,7 +183,7 @@ func (j *SimpleJoinerJob) nextChunk(level int, address swarm.Address) error {
 		// read data and pass to reader only if session is still active
 		// * context cancelled when client has disappeared, timeout etc
 		// * doneC receive when gracefully terminated through Close
-		data := ch.Data()[8:]
+		data := chunkData
 		err = j.sendChunkToReader(data)
 	}
 	return err
@@ -213,3 +237,50 @@ func (j *SimpleJoinerJob) closeDone() {
 		close(j.doneC)
 	})
 }
+
+func DecryptChunkData(chunkData []byte, encryptionKey encryption.Key) ([]byte, error) {
+	if len(chunkData) < 8 {
+		return nil, fmt.Errorf("invalid ChunkData, min length 8 got %v", len(chunkData))
+	}
+
+	decryptedSpan, decryptedData, err := decrypt(chunkData, encryptionKey)
+	if err != nil {
+		return nil, err
+	}
+
+	// removing extra bytes which were just added for padding
+	length := binary.LittleEndian.Uint64(decryptedSpan)
+	refSize := int64(swarm.HashSize + encryption.KeyLength)
+	for length > swarm.ChunkSize {
+		length = length + (swarm.ChunkSize - 1)
+		length = length / swarm.ChunkSize
+		length *= uint64(refSize)
+	}
+
+	c := make([]byte, length+8)
+	copy(c[:8], decryptedSpan)
+	copy(c[8:], decryptedData[:length])
+
+	return c, nil
+}
+
+func decrypt(chunkData []byte, key encryption.Key) ([]byte, []byte, error) {
+	encryptedSpan, err := newSpanEncryption(key).Encrypt(chunkData[:8])
+	if err != nil {
+		return nil, nil, err
+	}
+	encryptedData, err := newDataEncryption(key).Encrypt(chunkData[8:])
+	if err != nil {
+		return nil, nil, err
+	}
+	return encryptedSpan, encryptedData, nil
+}
+
+func newSpanEncryption(key encryption.Key) *encryption.Encryption {
+	refSize := int64(swarm.HashSize + encryption.KeyLength)
+	return encryption.New(key, 0, uint32(swarm.ChunkSize/refSize), sha3.NewLegacyKeccak256)
+}
+
+func newDataEncryption(key encryption.Key) *encryption.Encryption {
+	return encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256)
+}
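Note that decrypt above calls Encrypt on both the span and the payload: the cipher is a pure XOR keystream derived from the key and a counter, so applying the same transform twice is the identity, and no separate decryption routine is needed. A standalone illustration of that property, using only the package API from this commit (not bee code itself):

package main

import (
    "bytes"
    "fmt"

    "github.com/ethersphere/bee/pkg/encryption"
    "golang.org/x/crypto/sha3"
)

func main() {
    key := encryption.GenerateRandomKey(encryption.KeyLength)
    plain := []byte("span or chunk payload")

    // No padding (0), counter starting at 0, Keccak-256 keystream, as in the commit.
    once, err := encryption.New(key, 0, 0, sha3.NewLegacyKeccak256).Encrypt(plain)
    if err != nil {
        panic(err)
    }
    twice, err := encryption.New(key, 0, 0, sha3.NewLegacyKeccak256).Encrypt(once)
    if err != nil {
        panic(err)
    }

    fmt.Println(bytes.Equal(plain, twice)) // true: the transform is its own inverse
}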
@@ -48,7 +48,7 @@ func TestSimpleJoinerJobBlocksize(t *testing.T) {
 	}

 	// this buffer is too small
-	j := internal.NewSimpleJoinerJob(ctx, store, rootChunk)
+	j := internal.NewSimpleJoinerJob(ctx, store, rootChunk, false)
 	b := make([]byte, swarm.SectionSize)
 	_, err = j.Read(b)
 	if err == nil {
@@ -99,7 +99,7 @@ func TestSimpleJoinerJobOneLevel(t *testing.T) {
 		t.Fatal(err)
 	}

-	j := internal.NewSimpleJoinerJob(ctx, store, rootChunk)
+	j := internal.NewSimpleJoinerJob(ctx, store, rootChunk, false)

 	// verify first chunk content
 	outBuffer := make([]byte, 4096)
@@ -188,7 +188,7 @@ func TestSimpleJoinerJobTwoLevelsAcrossChunk(t *testing.T) {
 		t.Fatal(err)
 	}

-	j := internal.NewSimpleJoinerJob(ctx, store, rootChunk)
+	j := internal.NewSimpleJoinerJob(ctx, store, rootChunk, false)

 	// read back all the chunks and verify
 	b := make([]byte, swarm.ChunkSize)
...
@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"io"

+	"github.com/ethersphere/bee/pkg/encryption"
 	"github.com/ethersphere/bee/pkg/file"
 	"github.com/ethersphere/bee/pkg/file/joiner/internal"
 	"github.com/ethersphere/bee/pkg/storage"
@@ -49,21 +50,45 @@ func (s *simpleJoiner) Size(ctx context.Context, address swarm.Address) (dataSiz
 //
 // It uses a non-optimized internal component that only retrieves a data chunk
 // after the previous has been read.
-func (s *simpleJoiner) Join(ctx context.Context, address swarm.Address) (dataOut io.ReadCloser, dataSize int64, err error) {
+func (s *simpleJoiner) Join(ctx context.Context, address swarm.Address, toDecrypt bool) (dataOut io.ReadCloser, dataSize int64, err error) {
+	var addr []byte
+	var key encryption.Key
+	if toDecrypt {
+		addr = address.Bytes()[:swarm.HashSize]
+		key = address.Bytes()[swarm.HashSize : swarm.HashSize+encryption.KeyLength]
+	} else {
+		addr = address.Bytes()
+	}
+
 	// retrieve the root chunk to read the total data length to be retrieved
-	rootChunk, err := s.getter.Get(ctx, storage.ModeGetRequest, address)
+	rootChunk, err := s.getter.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(addr))
 	if err != nil {
 		return nil, 0, err
 	}

+	var chunkData []byte
+	if toDecrypt {
+		originalData, err := internal.DecryptChunkData(rootChunk.Data(), key)
+		if err != nil {
+			return nil, 0, err
+		}
+		chunkData = originalData
+	} else {
+		chunkData = rootChunk.Data()
+	}

 	// if this is a single chunk, short circuit to returning just that chunk
-	spanLength := binary.LittleEndian.Uint64(rootChunk.Data())
+	spanLength := binary.LittleEndian.Uint64(chunkData[:8])
+	chunkToSend := rootChunk
 	if spanLength <= swarm.ChunkSize {
-		data := rootChunk.Data()[8:]
+		data := chunkData[8:]
 		return file.NewSimpleReadCloser(data), int64(spanLength), nil
 	}

-	r := internal.NewSimpleJoinerJob(ctx, s.getter, rootChunk)
+	if toDecrypt {
+		chunkToSend = swarm.NewChunk(swarm.NewAddress(addr), chunkData)
+	}
+
+	r := internal.NewSimpleJoinerJob(ctx, s.getter, chunkToSend, toDecrypt)
 	return r, int64(spanLength), nil
 }
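Join's handling of encrypted references comes down to a fixed layout: the first swarm.HashSize bytes address the root chunk, and the remaining encryption.KeyLength bytes are the decryption key. A small sketch under the same constants; splitReference is a hypothetical helper written for illustration, not part of the commit:

package main

import (
    "fmt"

    "github.com/ethersphere/bee/pkg/encryption"
    "github.com/ethersphere/bee/pkg/swarm"
)

// splitReference mirrors the address/key split performed in simpleJoiner.Join.
func splitReference(ref swarm.Address) (addr swarm.Address, key encryption.Key, encrypted bool) {
    b := ref.Bytes()
    if len(b) != swarm.HashSize+encryption.KeyLength {
        return ref, nil, false // plain (unencrypted) 32-byte reference
    }
    return swarm.NewAddress(b[:swarm.HashSize]), encryption.Key(b[swarm.HashSize:]), true
}

func main() {
    // Hypothetical 64-byte reference: 32-byte hash followed by a 32-byte key.
    ref := swarm.NewAddress(make([]byte, swarm.HashSize+encryption.KeyLength))
    addr, key, enc := splitReference(ref)
    fmt.Println(enc, addr, len(key))
}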
...@@ -8,14 +8,18 @@ import ( ...@@ -8,14 +8,18 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"testing" "testing"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter"
filetest "github.com/ethersphere/bee/pkg/file/testing" filetest "github.com/ethersphere/bee/pkg/file/testing"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"gitlab.com/nolash/go-mockbytes"
) )
// TestJoiner verifies that a newly created joiner returns the data stored // TestJoiner verifies that a newly created joiner returns the data stored
...@@ -29,7 +33,7 @@ func TestJoinerSingleChunk(t *testing.T) { ...@@ -29,7 +33,7 @@ func TestJoinerSingleChunk(t *testing.T) {
defer cancel() defer cancel()
var err error var err error
_, _, err = joiner.Join(ctx, swarm.ZeroAddress) _, _, err = joiner.Join(ctx, swarm.ZeroAddress, false)
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
t.Fatalf("expected ErrNotFound for %x", swarm.ZeroAddress) t.Fatalf("expected ErrNotFound for %x", swarm.ZeroAddress)
} }
...@@ -47,7 +51,7 @@ func TestJoinerSingleChunk(t *testing.T) { ...@@ -47,7 +51,7 @@ func TestJoinerSingleChunk(t *testing.T) {
} }
// read back data and compare // read back data and compare
joinReader, l, err := joiner.Join(ctx, mockAddr) joinReader, l, err := joiner.Join(ctx, mockAddr, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -94,7 +98,7 @@ func TestJoinerWithReference(t *testing.T) { ...@@ -94,7 +98,7 @@ func TestJoinerWithReference(t *testing.T) {
} }
// read back data and compare // read back data and compare
joinReader, l, err := joiner.Join(ctx, rootChunk.Address()) joinReader, l, err := joiner.Join(ctx, rootChunk.Address(), false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -114,3 +118,63 @@ func TestJoinerWithReference(t *testing.T) { ...@@ -114,3 +118,63 @@ func TestJoinerWithReference(t *testing.T) {
t.Fatalf("expected resultbuffer %v, got %v", resultBuffer, firstChunk.Data()[:len(resultBuffer)]) t.Fatalf("expected resultbuffer %v, got %v", resultBuffer, firstChunk.Data()[:len(resultBuffer)])
} }
} }
func TestEncryptionAndDecryption(t *testing.T) {
var tests = []struct {
chunkLength int
}{
{10},
{100},
{1000},
{4095},
{4096},
{4097},
{15000},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Encrypt %d bytes", tt.chunkLength), func(t *testing.T) {
store := mock.NewStorer()
joinner := joiner.NewSimpleJoiner(store)
g := mockbytes.New(0, mockbytes.MockTypeStandard).WithModulus(255)
testData, err := g.SequentialBytes(tt.chunkLength)
if err != nil {
t.Fatal(err)
}
s := splitter.NewSimpleSplitter(store)
testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), true)
if err != nil {
t.Fatal(err)
}
reader, l, err := j.Join(context.Background(), resultAddress, true)
if err != nil {
t.Fatal(err)
}
if l != int64(len(testData)) {
t.Fatalf("expected join data length %d, got %d", len(testData), l)
}
totalGot := make([]byte, tt.chunkLength)
index := 0
resultBuffer := make([]byte, swarm.ChunkSize)
for index < tt.chunkLength {
n, err := reader.Read(resultBuffer)
if err != nil && err != io.EOF {
t.Fatal(err)
}
copy(totalGot[index:], resultBuffer[:n])
index += n
}
if !bytes.Equal(testData, totalGot) {
t.Fatal("input data and output data does not match")
}
})
}
}
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"hash" "hash"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -46,12 +47,19 @@ type SimpleSplitterJob struct { ...@@ -46,12 +47,19 @@ type SimpleSplitterJob struct {
cursors []int // section write position, indexed per level cursors []int // section write position, indexed per level
hasher bmt.Hash // underlying hasher used for hashing the tree hasher bmt.Hash // underlying hasher used for hashing the tree
buffer []byte // keeps data and hashes, indexed by cursors buffer []byte // keeps data and hashes, indexed by cursors
toEncrypt bool // whether to encrypt the chunks or not
refSize int64 // reference size: hash length, plus key length when encrypting
} }
// NewSimpleSplitterJob creates a new SimpleSplitterJob. // NewSimpleSplitterJob creates a new SimpleSplitterJob.
// //
// The spanLength is the length of the data that will be written. // The spanLength is the length of the data that will be written.
func NewSimpleSplitterJob(ctx context.Context, putter storage.Putter, spanLength int64) *SimpleSplitterJob { func NewSimpleSplitterJob(ctx context.Context, putter storage.Putter, spanLength int64, toEncrypt bool) *SimpleSplitterJob {
hashSize := swarm.HashSize
refSize := int64(hashSize)
if toEncrypt {
refSize += encryption.KeyLength
}
p := bmtlegacy.NewTreePool(hashFunc, swarm.Branches, bmtlegacy.PoolSize) p := bmtlegacy.NewTreePool(hashFunc, swarm.Branches, bmtlegacy.PoolSize)
return &SimpleSplitterJob{ return &SimpleSplitterJob{
ctx: ctx, ctx: ctx,
...@@ -61,6 +69,8 @@ func NewSimpleSplitterJob(ctx context.Context, putter storage.Putter, spanLength ...@@ -61,6 +69,8 @@ func NewSimpleSplitterJob(ctx context.Context, putter storage.Putter, spanLength
cursors: make([]int, levelBufferLimit), cursors: make([]int, levelBufferLimit),
hasher: bmtlegacy.New(p), hasher: bmtlegacy.New(p),
buffer: make([]byte, file.ChunkWithLengthSize*levelBufferLimit*2), // double size as temp workaround for weak calculation of needed buffer space buffer: make([]byte, file.ChunkWithLengthSize*levelBufferLimit*2), // double size as temp workaround for weak calculation of needed buffer space
toEncrypt: toEncrypt,
refSize: refSize,
} }
} }
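The refSize stored on the job determines the tree's fan-out: a plain reference is the 32-byte BMT hash, while an encrypted one appends a 32-byte key, so an intermediate chunk holds half as many references. A quick check of the arithmetic (editor's sketch, using the constants this commit relies on):

package example

import (
	"fmt"

	"github.com/ethersphere/bee/pkg/encryption"
	"github.com/ethersphere/bee/pkg/swarm"
)

func refArithmetic() {
	plainRef := int64(swarm.HashSize)         // 32 bytes
	encRef := plainRef + encryption.KeyLength // 32 + 32 = 64 bytes
	fmt.Println(swarm.ChunkSize / plainRef)   // 128 references per intermediate chunk
	fmt.Println(swarm.ChunkSize / encRef)     // 64 references per intermediate chunk
}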
...@@ -126,34 +136,48 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) { ...@@ -126,34 +136,48 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
s.sumCounts[lvl]++ s.sumCounts[lvl]++
spanSize := file.Spans[lvl] * swarm.ChunkSize spanSize := file.Spans[lvl] * swarm.ChunkSize
span := (s.length-1)%spanSize + 1 span := (s.length-1)%spanSize + 1
sizeToSum := s.cursors[lvl] - s.cursors[lvl+1] sizeToSum := s.cursors[lvl] - s.cursors[lvl+1]
// perform hashing // perform hashing
s.hasher.Reset() s.hasher.Reset()
err := s.hasher.SetSpan(span) err := s.hasher.SetSpan(span)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = s.hasher.Write(s.buffer[s.cursors[lvl+1] : s.cursors[lvl+1]+sizeToSum])
var ref encryption.Key
var chunkData []byte
data := s.buffer[s.cursors[lvl+1] : s.cursors[lvl+1]+sizeToSum]
_, err = s.hasher.Write(data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ref := s.hasher.Sum(nil) ref = s.hasher.Sum(nil)
// assemble chunk and put in store
addr := swarm.NewAddress(ref)
head := make([]byte, 8) head := make([]byte, 8)
binary.LittleEndian.PutUint64(head, uint64(span)) binary.LittleEndian.PutUint64(head, uint64(span))
tail := s.buffer[s.cursors[lvl+1]:s.cursors[lvl]] tail := s.buffer[s.cursors[lvl+1]:s.cursors[lvl]]
chunkData := append(head, tail...) chunkData = append(head, tail...)
ch := swarm.NewChunk(addr, chunkData)
// assemble chunk and put in store
addr := swarm.NewAddress(ref)
c := chunkData
var encryptionKey encryption.Key
if s.toEncrypt {
c, encryptionKey, err = s.encryptChunkData(chunkData)
if err != nil {
return nil, err
}
}
ch := swarm.NewChunk(addr, c)
_, err = s.putter.Put(s.ctx, storage.ModePutUpload, ch) _, err = s.putter.Put(s.ctx, storage.ModePutUpload, ch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ref, nil return append(ch.Address().Bytes(), encryptionKey...), nil
} }
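Note the changed return value: when encrypting, sumLevel hands back the chunk address with the per-chunk key appended, and it is these refSize-wide references that get written into, and hashed as part of, the parent level. A sketch of how such a reference decomposes (editor's illustration; the splitter itself only writes the bytes back into its buffer):

package example

import (
	"github.com/ethersphere/bee/pkg/encryption"
	"github.com/ethersphere/bee/pkg/swarm"
)

// splitRef separates a reference as returned by sumLevel into its
// address and, when present, its decryption key.
func splitRef(ref []byte) (swarm.Address, encryption.Key) {
	addr := swarm.NewAddress(ref[:swarm.HashSize])
	if len(ref) == swarm.HashSize {
		return addr, nil // plain reference, no key
	}
	return addr, encryption.Key(ref[swarm.HashSize:])
}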
// digest returns the calculated digest after a Sum call. // digest returns the calculated digest after a Sum call.
...@@ -164,7 +188,11 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) { ...@@ -164,7 +188,11 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
// The method does not check that the final hash actually has been written, so // The method does not check that the final hash actually has been written, so
// timing is the responsibility of the caller. // timing is the responsibility of the caller.
func (s *SimpleSplitterJob) digest() []byte { func (s *SimpleSplitterJob) digest() []byte {
return s.buffer[:swarm.SectionSize] if s.toEncrypt {
return s.buffer[:swarm.SectionSize*2]
}
return s.buffer[:swarm.SectionSize]
} }
// hashUnfinished hashes the remaining unhashed chunks at the end of each level if // hashUnfinished hashes the remaining unhashed chunks at the end of each level if
...@@ -229,3 +257,39 @@ func (s *SimpleSplitterJob) moveDanglingChunk() error { ...@@ -229,3 +257,39 @@ func (s *SimpleSplitterJob) moveDanglingChunk() error {
} }
return nil return nil
} }
func (s *SimpleSplitterJob) encryptChunkData(chunkData []byte) ([]byte, encryption.Key, error) {
if len(chunkData) < 8 {
return nil, nil, fmt.Errorf("invalid data: minimum length 8, got %d", len(chunkData))
}
key, encryptedSpan, encryptedData, err := s.encrypt(chunkData)
if err != nil {
return nil, nil, err
}
c := make([]byte, len(encryptedSpan)+len(encryptedData))
copy(c[:8], encryptedSpan)
copy(c[8:], encryptedData)
return c, key, nil
}
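encryptChunkData keeps the plain chunk layout intact: an 8-byte little-endian span followed by the payload, each encrypted separately under the same key. The joiner's reverse operation is not part of this hunk, but assuming the encryption package exposes a Decrypt mirroring Encrypt, it would look roughly like this (editor's sketch, in the same package as the splitter job):

// decryptChunkData is a hypothetical counterpart to encryptChunkData.
func (s *SimpleSplitterJob) decryptChunkData(c []byte, key encryption.Key) ([]byte, error) {
	if len(c) < 8 {
		return nil, fmt.Errorf("invalid data: minimum length 8, got %d", len(c))
	}
	span, err := s.newSpanEncryption(key).Decrypt(c[:8])
	if err != nil {
		return nil, err
	}
	data, err := s.newDataEncryption(key).Decrypt(c[8:])
	if err != nil {
		return nil, err
	}
	return append(span, data...), nil
}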
func (s *SimpleSplitterJob) encrypt(chunkData []byte) (encryption.Key, []byte, []byte, error) {
key := encryption.GenerateRandomKey(encryption.KeyLength)
encryptedSpan, err := s.newSpanEncryption(key).Encrypt(chunkData[:8])
if err != nil {
return nil, nil, nil, err
}
encryptedData, err := s.newDataEncryption(key).Encrypt(chunkData[8:])
if err != nil {
return nil, nil, nil, err
}
return key, encryptedSpan, encryptedData, nil
}
func (s *SimpleSplitterJob) newSpanEncryption(key encryption.Key) *encryption.Encryption {
return encryption.New(key, 0, uint32(swarm.ChunkSize/s.refSize), sha3.NewLegacyKeccak256)
}
func (s *SimpleSplitterJob) newDataEncryption(key encryption.Key) *encryption.Encryption {
return encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256)
}
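The two constructors differ only in the padding and initial-counter arguments to encryption.New: data encryption pads the payload up to a full 4096-byte chunk and starts its keystream counter at 0, while span encryption uses no padding and offsets its counter by ChunkSize/refSize, the number of references an intermediate chunk can hold. For encrypted references that offset works out as follows (editor's sketch):

package example

import (
	"github.com/ethersphere/bee/pkg/encryption"
	"github.com/ethersphere/bee/pkg/swarm"
)

// spanCounter reproduces the initial counter passed to the span
// encryption when references are encrypted.
func spanCounter() uint32 {
	refSize := int64(swarm.HashSize) + encryption.KeyLength // 32 + 32 = 64
	return uint32(swarm.ChunkSize / refSize)                // 4096 / 64 = 64
}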
...@@ -31,7 +31,7 @@ func TestSplitterJobPartialSingleChunk(t *testing.T) { ...@@ -31,7 +31,7 @@ func TestSplitterJobPartialSingleChunk(t *testing.T) {
defer cancel() defer cancel()
data := []byte("foo") data := []byte("foo")
j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data))) j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data)), false)
c, err := j.Write(data) c, err := j.Write(data)
if err != nil { if err != nil {
...@@ -74,7 +74,7 @@ func testSplitterJobVector(t *testing.T) { ...@@ -74,7 +74,7 @@ func testSplitterJobVector(t *testing.T) {
data, expect := test.GetVector(t, int(dataIdx)) data, expect := test.GetVector(t, int(dataIdx))
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data))) j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data)), false)
for i := 0; i < len(data); i += swarm.ChunkSize { for i := 0; i < len(data); i += swarm.ChunkSize {
l := swarm.ChunkSize l := swarm.ChunkSize
......
...@@ -34,9 +34,8 @@ func NewSimpleSplitter(putter storage.Putter) file.Splitter { ...@@ -34,9 +34,8 @@ func NewSimpleSplitter(putter storage.Putter) file.Splitter {
// multiple levels of hashing when building the file hash tree. // multiple levels of hashing when building the file hash tree.
// //
// It returns the Swarmhash of the data. // It returns the Swarmhash of the data.
func (s *simpleSplitter) Split(ctx context.Context, r io.ReadCloser, dataLength int64) (addr swarm.Address, err error) { func (s *simpleSplitter) Split(ctx context.Context, r io.ReadCloser, dataLength int64, toEncrypt bool) (addr swarm.Address, err error) {
j := internal.NewSimpleSplitterJob(ctx, s.putter, dataLength) j := internal.NewSimpleSplitterJob(ctx, s.putter, dataLength, toEncrypt)
var total int64 var total int64
data := make([]byte, swarm.ChunkSize) data := make([]byte, swarm.ChunkSize)
var eof bool var eof bool
...@@ -49,6 +48,7 @@ func (s *simpleSplitter) Split(ctx context.Context, r io.ReadCloser, dataLength ...@@ -49,6 +48,7 @@ func (s *simpleSplitter) Split(ctx context.Context, r io.ReadCloser, dataLength
return swarm.ZeroAddress, fmt.Errorf("splitter only received %d bytes of data, expected %d bytes", total+int64(c), dataLength) return swarm.ZeroAddress, fmt.Errorf("splitter only received %d bytes of data, expected %d bytes", total+int64(c), dataLength)
} }
eof = true eof = true
continue
} else { } else {
return swarm.ZeroAddress, err return swarm.ZeroAddress, err
} }
......
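The lone added continue closes a subtle hole in Split's read loop: when the reader signals io.EOF after all dataLength bytes have arrived, the loop must record eof and return to the loop condition instead of falling through to process the empty read. Abridged, the loop has roughly this shape (editor's reconstruction of the unchanged surrounding lines, not a verbatim quote of the file):

for !eof {
	c, err := r.Read(data)
	if err != nil {
		if err == io.EOF {
			if total+int64(c) < dataLength {
				return swarm.ZeroAddress, fmt.Errorf("splitter only received %d bytes of data, expected %d bytes", total+int64(c), dataLength)
			}
			eof = true
			continue // done reading; skip the write below
		}
		return swarm.ZeroAddress, err
	}
	// ... write data[:c] into the splitter job, check the count,
	// and advance total ...
	total += int64(c)
}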
...@@ -26,7 +26,7 @@ func TestSplitIncomplete(t *testing.T) { ...@@ -26,7 +26,7 @@ func TestSplitIncomplete(t *testing.T) {
s := splitter.NewSimpleSplitter(store) s := splitter.NewSimpleSplitter(store)
testDataReader := file.NewSimpleReadCloser(testData) testDataReader := file.NewSimpleReadCloser(testData)
_, err := s.Split(context.Background(), testDataReader, 41) _, err := s.Split(context.Background(), testDataReader, 41, false)
if err == nil { if err == nil {
t.Fatalf("expected error on EOF before full length write") t.Fatalf("expected error on EOF before full length write")
} }
...@@ -45,7 +45,7 @@ func TestSplitSingleChunk(t *testing.T) { ...@@ -45,7 +45,7 @@ func TestSplitSingleChunk(t *testing.T) {
s := splitter.NewSimpleSplitter(store) s := splitter.NewSimpleSplitter(store)
testDataReader := file.NewSimpleReadCloser(testData) testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData))) resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -68,7 +68,7 @@ func TestSplitSingleChunk(t *testing.T) { ...@@ -68,7 +68,7 @@ func TestSplitSingleChunk(t *testing.T) {
func TestSplitThreeLevels(t *testing.T) { func TestSplitThreeLevels(t *testing.T) {
// edge case selected from internal/job_test.go // edge case selected from internal/job_test.go
g := mockbytes.New(0, mockbytes.MockTypeStandard).WithModulus(255) g := mockbytes.New(0, mockbytes.MockTypeStandard).WithModulus(255)
testData, err := g.SequentialBytes(swarm.ChunkSize * swarm.Branches) testData, err := g.SequentialBytes(swarm.ChunkSize * 128)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -77,7 +77,7 @@ func TestSplitThreeLevels(t *testing.T) { ...@@ -77,7 +77,7 @@ func TestSplitThreeLevels(t *testing.T) {
s := splitter.NewSimpleSplitter(store) s := splitter.NewSimpleSplitter(store)
testDataReader := file.NewSimpleReadCloser(testData) testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData))) resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -136,7 +136,7 @@ func TestUnalignedSplit(t *testing.T) { ...@@ -136,7 +136,7 @@ func TestUnalignedSplit(t *testing.T) {
doneC := make(chan swarm.Address) doneC := make(chan swarm.Address)
errC := make(chan error) errC := make(chan error)
go func() { go func() {
addr, err := sp.Split(ctx, chunkPipe, dataLen) addr, err := sp.Split(ctx, chunkPipe, dataLen, false)
if err != nil { if err != nil {
errC <- err errC <- err
} else { } else {
...@@ -180,5 +180,4 @@ func TestUnalignedSplit(t *testing.T) { ...@@ -180,5 +180,4 @@ func TestUnalignedSplit(t *testing.T) {
case <-timer.C: case <-timer.C:
t.Fatal("timeout") t.Fatal("timeout")
} }
} }