Commit 3cdcdc62 authored by Zahoor Mohamed's avatar Zahoor Mohamed Committed by GitHub

Fix file upload/download during encryption (#472)

* Fix file upload/download during encryption
parent 2b82b44a
...@@ -6,6 +6,7 @@ package api_test ...@@ -6,6 +6,7 @@ package api_test
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"mime" "mime"
...@@ -51,15 +52,21 @@ func TestFiles(t *testing.T) { ...@@ -51,15 +52,21 @@ func TestFiles(t *testing.T) {
t.Run("encrypt-decrypt", func(t *testing.T) { t.Run("encrypt-decrypt", func(t *testing.T) {
fileName := "my-pictures.jpeg" fileName := "my-pictures.jpeg"
rootHash := "f2e761160deda91c1fbfab065a5abf530b0766b3e102b51fbd626ba37c3bc581"
headers := make(http.Header) headers := make(http.Header)
headers.Add("EncryptHeader", "True") headers.Add(api.EncryptHeader, "True")
headers.Add("Content-Type", "image/jpeg; charset=utf-8") headers.Add("Content-Type", "image/jpeg; charset=utf-8")
_ = jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, client, http.MethodPost, fileUploadResource+"?name="+fileName, bytes.NewReader(simpleData), http.StatusOK, api.FileUploadResponse{ _, respBytes := jsonhttptest.ResponseDirectSendHeadersAndDontCheckResponse(t, client, http.MethodPost, fileUploadResource+"?name="+fileName, bytes.NewReader(simpleData), http.StatusOK, headers)
Reference: swarm.MustParseHexAddress(rootHash), read := bytes.NewReader(respBytes)
}, headers)
// get the reference as everytime it will change because of random encryption key
var resp api.FileUploadResponse
err := json.NewDecoder(read).Decode(&resp)
if err != nil {
t.Fatal(err)
}
rootHash := resp.Reference.String()
rcvdHeader := jsonhttptest.ResponseDirectCheckBinaryResponse(t, client, http.MethodGet, fileDownloadResource(rootHash), nil, http.StatusOK, simpleData, nil) rcvdHeader := jsonhttptest.ResponseDirectCheckBinaryResponse(t, client, http.MethodGet, fileDownloadResource(rootHash), nil, http.StatusOK, simpleData, nil)
cd := rcvdHeader.Get("Content-Disposition") cd := rcvdHeader.Get("Content-Disposition")
_, params, err := mime.ParseMediaType(cd) _, params, err := mime.ParseMediaType(cd)
......
...@@ -12,8 +12,9 @@ import ( ...@@ -12,8 +12,9 @@ import (
) )
var ( var (
_ = collection.Entry(&Entry{}) _ = collection.Entry(&Entry{})
serializedDataSize = swarm.SectionSize * 2 serializedDataSize = swarm.SectionSize * 2
encryptedSerializedDataSize = swarm.EncryptedReferenceSize * 2
) )
// Entry provides addition of metadata to a data reference. // Entry provides addition of metadata to a data reference.
...@@ -51,10 +52,15 @@ func (e *Entry) MarshalBinary() ([]byte, error) { ...@@ -51,10 +52,15 @@ func (e *Entry) MarshalBinary() ([]byte, error) {
// UnmarshalBinary implements encoding.BinaryUnmarshaler // UnmarshalBinary implements encoding.BinaryUnmarshaler
func (e *Entry) UnmarshalBinary(b []byte) error { func (e *Entry) UnmarshalBinary(b []byte) error {
if len(b) != serializedDataSize { var size int
if len(b) == serializedDataSize {
size = serializedDataSize
} else if len(b) == encryptedSerializedDataSize {
size = encryptedSerializedDataSize
} else {
return errors.New("invalid data length") return errors.New("invalid data length")
} }
e.reference = swarm.NewAddress(b[:swarm.SectionSize]) e.reference = swarm.NewAddress(b[:size/2])
e.metadata = swarm.NewAddress(b[swarm.SectionSize:]) e.metadata = swarm.NewAddress(b[size/2:])
return nil return nil
} }
This diff is collapsed.
...@@ -31,6 +31,17 @@ func NewSimpleJoiner(getter storage.Getter) file.Joiner { ...@@ -31,6 +31,17 @@ func NewSimpleJoiner(getter storage.Getter) file.Joiner {
} }
func (s *simpleJoiner) Size(ctx context.Context, address swarm.Address) (dataSize int64, err error) { func (s *simpleJoiner) Size(ctx context.Context, address swarm.Address) (dataSize int64, err error) {
// Handle size based on whether the root chunk is encrypted or not
toDecrypt := len(address.Bytes()) == swarm.EncryptedReferenceSize
var key encryption.Key
addrBytes := address.Bytes()
if toDecrypt {
addrBytes = address.Bytes()[:swarm.HashSize]
key = address.Bytes()[swarm.HashSize : swarm.HashSize+encryption.KeyLength]
}
address = swarm.NewAddress(addrBytes)
// retrieve the root chunk to read the total data length the be retrieved // retrieve the root chunk to read the total data length the be retrieved
rootChunk, err := s.getter.Get(ctx, storage.ModeGetRequest, address) rootChunk, err := s.getter.Get(ctx, storage.ModeGetRequest, address)
if err != nil { if err != nil {
...@@ -42,7 +53,15 @@ func (s *simpleJoiner) Size(ctx context.Context, address swarm.Address) (dataSiz ...@@ -42,7 +53,15 @@ func (s *simpleJoiner) Size(ctx context.Context, address swarm.Address) (dataSiz
return 0, fmt.Errorf("invalid chunk content of %d bytes", chunkLength) return 0, fmt.Errorf("invalid chunk content of %d bytes", chunkLength)
} }
dataLength := binary.LittleEndian.Uint64(rootChunk.Data()) chunkData := rootChunk.Data()
if toDecrypt {
originalData, err := internal.DecryptChunkData(rootChunk.Data(), key)
if err != nil {
return 0, err
}
chunkData = originalData
}
dataLength := binary.LittleEndian.Uint64(chunkData[:8])
return int64(dataLength), nil return int64(dataLength), nil
} }
......
...@@ -136,41 +136,37 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) { ...@@ -136,41 +136,37 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
s.sumCounts[lvl]++ s.sumCounts[lvl]++
spanSize := file.Spans[lvl] * swarm.ChunkSize spanSize := file.Spans[lvl] * swarm.ChunkSize
span := (s.length-1)%spanSize + 1 span := (s.length-1)%spanSize + 1
sizeToSum := s.cursors[lvl] - s.cursors[lvl+1]
//perform hashing
s.hasher.Reset()
err := s.hasher.SetSpan(span)
if err != nil {
return nil, err
}
var ref encryption.Key
var chunkData []byte var chunkData []byte
data := s.buffer[s.cursors[lvl+1] : s.cursors[lvl+1]+sizeToSum] var addr swarm.Address
_, err = s.hasher.Write(data)
if err != nil {
return nil, err
}
ref = s.hasher.Sum(nil)
head := make([]byte, 8) head := make([]byte, 8)
binary.LittleEndian.PutUint64(head, uint64(span)) binary.LittleEndian.PutUint64(head, uint64(span))
tail := s.buffer[s.cursors[lvl+1]:s.cursors[lvl]] tail := s.buffer[s.cursors[lvl+1]:s.cursors[lvl]]
chunkData = append(head, tail...) chunkData = append(head, tail...)
// assemble chunk and put in store
addr := swarm.NewAddress(ref)
c := chunkData c := chunkData
var encryptionKey encryption.Key var encryptionKey encryption.Key
if s.toEncrypt { if s.toEncrypt {
var err error
c, encryptionKey, err = s.encryptChunkData(chunkData) c, encryptionKey, err = s.encryptChunkData(chunkData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
s.hasher.Reset()
err := s.hasher.SetSpanBytes(c[:8])
if err != nil {
return nil, err
}
_, err = s.hasher.Write(c[8:])
if err != nil {
return nil, err
}
ref := s.hasher.Sum(nil)
addr = swarm.NewAddress(ref)
ch := swarm.NewChunk(addr, c) ch := swarm.NewChunk(addr, c)
_, err = s.putter.Put(s.ctx, storage.ModePutUpload, ch) _, err = s.putter.Put(s.ctx, storage.ModePutUpload, ch)
if err != nil { if err != nil {
......
...@@ -137,6 +137,23 @@ func ResponseDirectSendHeadersAndReceiveHeaders(t *testing.T, client *http.Clien ...@@ -137,6 +137,23 @@ func ResponseDirectSendHeadersAndReceiveHeaders(t *testing.T, client *http.Clien
return resp.Header return resp.Header
} }
// ResponseDirectSendHeadersAndDontCheckResponse sends a request with the given headers and does not check for the returned reference.
// this is useful in tests which does not know the return reference, for ex: when encryption flag is set
func ResponseDirectSendHeadersAndDontCheckResponse(t *testing.T, client *http.Client, method, url string, body io.Reader, responseCode int, headers http.Header) (http.Header, []byte) {
t.Helper()
resp := request(t, client, method, url, body, responseCode, headers)
defer resp.Body.Close()
got, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
got = bytes.TrimSpace(got)
return resp.Header, got
}
func ResponseUnmarshal(t *testing.T, client *http.Client, method, url string, body io.Reader, responseCode int, response interface{}) { func ResponseUnmarshal(t *testing.T, client *http.Client, method, url string, body io.Reader, responseCode int, response interface{}) {
t.Helper() t.Helper()
......
...@@ -11,18 +11,20 @@ import ( ...@@ -11,18 +11,20 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/encryption"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
) )
const ( const (
SpanSize = 8 SpanSize = 8
SectionSize = 32 SectionSize = 32
Branches = 128 Branches = 128
ChunkSize = SectionSize * Branches ChunkSize = SectionSize * Branches
HashSize = 32 HashSize = 32
MaxPO uint8 = 15 EncryptedReferenceSize = HashSize + encryption.KeyLength
MaxBins = MaxPO + 1 MaxPO uint8 = 15
ChunkWithSpanSize = ChunkSize + SpanSize MaxBins = MaxPO + 1
ChunkWithSpanSize = ChunkSize + SpanSize
) )
var ( var (
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment