Commit f0844278 authored by acud's avatar acud Committed by GitHub

api, debugapi, file, localstore, storage: add pin on upload (#565)

* api, debugapi, file, localstore, storage: allow atomic commit on upload
parent db4a844a
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
"github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter" "github.com/ethersphere/bee/pkg/file/splitter"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
...@@ -160,7 +161,7 @@ func putEntry(cmd *cobra.Command, args []string) (err error) { ...@@ -160,7 +161,7 @@ func putEntry(cmd *cobra.Command, args []string) (err error) {
logger.Debugf("metadata contents: %s", metadataBytes) logger.Debugf("metadata contents: %s", metadataBytes)
// set up splitter to process the metadata // set up splitter to process the metadata
s := splitter.NewSimpleSplitter(stores) s := splitter.NewSimpleSplitter(stores, storage.ModePutUpload)
ctx := context.Background() ctx := context.Background()
// first add metadata // first add metadata
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
cmdfile "github.com/ethersphere/bee/cmd/internal/file" cmdfile "github.com/ethersphere/bee/cmd/internal/file"
"github.com/ethersphere/bee/pkg/file/splitter" "github.com/ethersphere/bee/pkg/file/splitter"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
...@@ -93,7 +94,7 @@ func Split(cmd *cobra.Command, args []string) (err error) { ...@@ -93,7 +94,7 @@ func Split(cmd *cobra.Command, args []string) (err error) {
} }
// split and rule // split and rule
s := splitter.NewSimpleSplitter(stores) s := splitter.NewSimpleSplitter(stores, storage.ModePutUpload)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
addr, err := s.Split(ctx, infile, inputLength, false) addr, err := s.Split(ctx, infile, inputLength, false)
......
...@@ -6,6 +6,7 @@ package api ...@@ -6,6 +6,7 @@ package api
import ( import (
"net/http" "net/http"
"strings"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
m "github.com/ethersphere/bee/pkg/metrics" m "github.com/ethersphere/bee/pkg/metrics"
...@@ -48,3 +49,16 @@ func New(tags *tags.Tags, storer storage.Storer, corsAllowedOrigins []string, lo ...@@ -48,3 +49,16 @@ func New(tags *tags.Tags, storer storage.Storer, corsAllowedOrigins []string, lo
return s return s
} }
const (
	// SwarmPinHeader, when set to "true" on an upload request, instructs the
	// node to also pin the uploaded content locally.
	SwarmPinHeader = "Swarm-Pin"
	// TagHeaderUid carries the uid of the tag associated with the upload.
	TagHeaderUid = "swarm-tag-uid"
)

// requestModePut returns the desired storage.ModePut for this request based on the request headers.
func requestModePut(r *http.Request) storage.ModePut {
	mode := storage.ModePutUpload
	// only the exact (case-insensitive) value "true" upgrades the mode to pin
	if strings.ToLower(r.Header.Get(SwarmPinHeader)) == "true" {
		mode = storage.ModePutUploadPin
	}
	return mode
}
...@@ -36,7 +36,7 @@ func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -36,7 +36,7 @@ func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
toEncrypt := strings.ToLower(r.Header.Get(EncryptHeader)) == "true" toEncrypt := strings.ToLower(r.Header.Get(EncryptHeader)) == "true"
sp := splitter.NewSimpleSplitter(s.Storer) sp := splitter.NewSimpleSplitter(s.Storer, requestModePut(r))
address, err := file.SplitWriteAll(ctx, sp, r.Body, r.ContentLength, toEncrypt) address, err := file.SplitWriteAll(ctx, sp, r.Body, r.ContentLength, toEncrypt)
if err != nil { if err != nil {
s.Logger.Debugf("bytes upload: %v", err) s.Logger.Debugf("bytes upload: %v", err)
......
...@@ -23,6 +23,7 @@ import ( ...@@ -23,6 +23,7 @@ import (
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/manifest/jsonmanifest" "github.com/ethersphere/bee/pkg/manifest/jsonmanifest"
"github.com/ethersphere/bee/pkg/storage"
smock "github.com/ethersphere/bee/pkg/storage/mock" smock "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
...@@ -32,7 +33,7 @@ func TestBzz(t *testing.T) { ...@@ -32,7 +33,7 @@ func TestBzz(t *testing.T) {
var ( var (
bzzDownloadResource = func(addr, path string) string { return "/bzz/" + addr + "/" + path } bzzDownloadResource = func(addr, path string) string { return "/bzz/" + addr + "/" + path }
storer = smock.NewStorer() storer = smock.NewStorer()
sp = splitter.NewSimpleSplitter(storer) sp = splitter.NewSimpleSplitter(storer, storage.ModePutUpload)
client = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: storer, Storer: storer,
Tags: tags.NewTags(), Tags: tags.NewTags(),
......
...@@ -13,7 +13,6 @@ import ( ...@@ -13,7 +13,6 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/sctx" "github.com/ethersphere/bee/pkg/sctx"
...@@ -23,12 +22,6 @@ import ( ...@@ -23,12 +22,6 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
// Presence of this header means that it needs to be tagged using the uid
const TagHeaderUid = "swarm-tag-uid"
// Presence of this header in the HTTP request indicates the chunk needs to be pinned.
const PinHeaderName = "swarm-pin"
func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
addr := mux.Vars(r)["addr"] addr := mux.Vars(r)["addr"]
address, err := swarm.ParseHexAddress(addr) address, err := swarm.ParseHexAddress(addr)
...@@ -62,7 +55,7 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -62,7 +55,7 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
seen, err := s.Storer.Put(ctx, storage.ModePutUpload, swarm.NewChunk(address, data)) seen, err := s.Storer.Put(ctx, requestModePut(r), swarm.NewChunk(address, data))
if err != nil { if err != nil {
s.Logger.Debugf("chunk upload: chunk write error: %v, addr %s", err, address) s.Logger.Debugf("chunk upload: chunk write error: %v, addr %s", err, address)
s.Logger.Error("chunk upload: chunk write error") s.Logger.Error("chunk upload: chunk write error")
...@@ -75,18 +68,6 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -75,18 +68,6 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
// Indicate that the chunk is stored // Indicate that the chunk is stored
tag.Inc(tags.StateStored) tag.Inc(tags.StateStored)
// Check if this chunk needs to pinned and pin it
pinHeaderValues := r.Header.Get(PinHeaderName)
if pinHeaderValues != "" && strings.ToLower(pinHeaderValues) == "true" {
err = s.Storer.Set(ctx, storage.ModeSetPin, address)
if err != nil {
s.Logger.Debugf("chunk upload: chunk pinning error: %v, addr %s", err, address)
s.Logger.Error("chunk upload: chunk pinning error")
jsonhttp.InternalServerError(w, "cannot pin chunk")
return
}
}
tag.DoneSplit(address) tag.DoneSplit(address)
w.Header().Set(TagHeaderUid, fmt.Sprint(tag.Uid)) w.Header().Set(TagHeaderUid, fmt.Sprint(tag.Uid))
......
...@@ -36,7 +36,7 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -36,7 +36,7 @@ func TestChunkUploadDownload(t *testing.T) {
invalidContent = []byte("bbaattss") invalidContent = []byte("bbaattss")
mockValidator = validator.NewMockValidator(validHash, validContent) mockValidator = validator.NewMockValidator(validHash, validContent)
tag = tags.NewTags() tag = tags.NewTags()
mockValidatingStorer = mock.NewValidatingStorer(mockValidator, tag) mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
client = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
Tags: tag, Tags: tag,
...@@ -83,7 +83,7 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -83,7 +83,7 @@ func TestChunkUploadDownload(t *testing.T) {
t.Run("pin-invalid-value", func(t *testing.T) { t.Run("pin-invalid-value", func(t *testing.T) {
headers := make(map[string][]string) headers := make(map[string][]string)
headers[api.PinHeaderName] = []string{"hdgdh"} headers[api.SwarmPinHeader] = []string{"hdgdh"}
jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, client, http.MethodPost, resource(validHash), bytes.NewReader(validContent), http.StatusOK, jsonhttp.StatusResponse{ jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, client, http.MethodPost, resource(validHash), bytes.NewReader(validContent), http.StatusOK, jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
...@@ -108,14 +108,14 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -108,14 +108,14 @@ func TestChunkUploadDownload(t *testing.T) {
}) })
t.Run("pin-ok", func(t *testing.T) { t.Run("pin-ok", func(t *testing.T) {
headers := make(map[string][]string) headers := make(map[string][]string)
headers[api.PinHeaderName] = []string{"True"} headers[api.SwarmPinHeader] = []string{"True"}
jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, client, http.MethodPost, resource(validHash), bytes.NewReader(validContent), http.StatusOK, jsonhttp.StatusResponse{ jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, client, http.MethodPost, resource(validHash), bytes.NewReader(validContent), http.StatusOK, jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
}, headers) }, headers)
// Also check if the chunk is pinned // Also check if the chunk is pinned
if mockValidatingStorer.GetModeSet(validHash) != storage.ModeSetPin { if mockValidatingStorer.GetModePut(validHash) != storage.ModePutUploadPin {
t.Fatal("chunk is not pinned") t.Fatal("chunk is not pinned")
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"archive/tar" "archive/tar"
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
...@@ -16,6 +17,9 @@ import ( ...@@ -16,6 +17,9 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/splitter"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/manifest/jsonmanifest" "github.com/ethersphere/bee/pkg/manifest/jsonmanifest"
...@@ -40,7 +44,7 @@ func (s *server) dirUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -40,7 +44,7 @@ func (s *server) dirUploadHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
reference, err := storeDir(ctx, r.Body, s.Storer, s.Logger) reference, err := storeDir(ctx, r.Body, s.Storer, requestModePut(r), s.Logger)
if err != nil { if err != nil {
s.Logger.Errorf("dir upload, store dir") s.Logger.Errorf("dir upload, store dir")
s.Logger.Debugf("dir upload, store dir err: %v", err) s.Logger.Debugf("dir upload, store dir err: %v", err)
...@@ -74,7 +78,7 @@ func validateRequest(r *http.Request) (context.Context, error) { ...@@ -74,7 +78,7 @@ func validateRequest(r *http.Request) (context.Context, error) {
// storeDir stores all files recursively contained in the directory given as a tar // storeDir stores all files recursively contained in the directory given as a tar
// it returns the hash for the uploaded manifest corresponding to the uploaded dir // it returns the hash for the uploaded manifest corresponding to the uploaded dir
func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logger logging.Logger) (swarm.Address, error) { func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, mode storage.ModePut, logger logging.Logger) (swarm.Address, error) {
dirManifest := jsonmanifest.NewManifest() dirManifest := jsonmanifest.NewManifest()
// set up HTTP body reader // set up HTTP body reader
...@@ -87,7 +91,7 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logge ...@@ -87,7 +91,7 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logge
if err == io.EOF { if err == io.EOF {
break break
} else if err != nil { } else if err != nil {
return swarm.ZeroAddress, fmt.Errorf("read tar stream error: %w", err) return swarm.ZeroAddress, fmt.Errorf("read tar stream: %w", err)
} }
filePath := fileHeader.Name filePath := fileHeader.Name
...@@ -108,9 +112,9 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logge ...@@ -108,9 +112,9 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logge
contentType: contentType, contentType: contentType,
reader: tarReader, reader: tarReader,
} }
fileReference, err := storeFile(ctx, fileInfo, s) fileReference, err := storeFile(ctx, fileInfo, s, mode)
if err != nil { if err != nil {
return swarm.ZeroAddress, fmt.Errorf("store dir file error: %w", err) return swarm.ZeroAddress, fmt.Errorf("store dir file: %w", err)
} }
logger.Tracef("uploaded dir file %v with reference %v", filePath, fileReference) logger.Tracef("uploaded dir file %v with reference %v", filePath, fileReference)
...@@ -132,7 +136,7 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logge ...@@ -132,7 +136,7 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logge
// first, serialize into byte array // first, serialize into byte array
b, err := dirManifest.MarshalBinary() b, err := dirManifest.MarshalBinary()
if err != nil { if err != nil {
return swarm.ZeroAddress, fmt.Errorf("manifest serialize error: %w", err) return swarm.ZeroAddress, fmt.Errorf("manifest serialize: %w", err)
} }
// set up reader for manifest file upload // set up reader for manifest file upload
...@@ -144,10 +148,57 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logge ...@@ -144,10 +148,57 @@ func storeDir(ctx context.Context, reader io.ReadCloser, s storage.Storer, logge
contentType: ManifestContentType, contentType: ManifestContentType,
reader: r, reader: r,
} }
manifestReference, err := storeFile(ctx, manifestFileInfo, s) manifestReference, err := storeFile(ctx, manifestFileInfo, s, mode)
if err != nil { if err != nil {
return swarm.ZeroAddress, fmt.Errorf("store manifest error: %w", err) return swarm.ZeroAddress, fmt.Errorf("store manifest: %w", err)
} }
return manifestReference, nil return manifestReference, nil
} }
// storeFile uploads the given file and returns its reference.
// It stores three linked objects: the file content, its metadata, and an
// entry joining the two; the entry reference is what callers get back.
// This function was extracted from `fileUploadHandler` and should eventually
// replace its current code.
func storeFile(ctx context.Context, fileInfo *fileUploadInfo, s storage.Storer, mode storage.ModePut) (swarm.Address, error) {
	// encryption is requested through the context; absence means plaintext
	toEncrypt, _ := ctx.Value(toEncryptContextKey{}).(bool) // default is false

	// first store the file and get its reference
	fr, err := file.SplitWriteAll(ctx, splitter.NewSimpleSplitter(s, mode), fileInfo.reader, fileInfo.size, toEncrypt)
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("split file: %w", err)
	}

	// if filename is still empty, use the file hash as the filename
	if fileInfo.name == "" {
		fileInfo.name = fr.String()
	}

	// then store the metadata (name + content type) and get its reference
	m := entry.NewMetadata(fileInfo.name)
	m.MimeType = fileInfo.contentType
	metadataBytes, err := json.Marshal(m)
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("metadata marshal: %w", err)
	}
	mr, err := file.SplitWriteAll(ctx, splitter.NewSimpleSplitter(s, mode), bytes.NewReader(metadataBytes), int64(len(metadataBytes)), toEncrypt)
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("split metadata: %w", err)
	}

	// now join both references (mr, fr) to create an entry and store it
	fileEntryBytes, err := entry.New(fr, mr).MarshalBinary()
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("entry marshal: %w", err)
	}
	reference, err := file.SplitWriteAll(ctx, splitter.NewSimpleSplitter(s, mode), bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)), toEncrypt)
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("split entry: %w", err)
	}
	return reference, nil
}
...@@ -51,8 +51,15 @@ type fileUploadResponse struct { ...@@ -51,8 +51,15 @@ type fileUploadResponse struct {
// - multipart http message // - multipart http message
// - other content types as complete file body // - other content types as complete file body
func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) { func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
toEncrypt := strings.ToLower(r.Header.Get(EncryptHeader)) == "true" var (
contentType := r.Header.Get("Content-Type") reader io.Reader
fileName, contentLength string
fileSize uint64
mode = requestModePut(r)
toEncrypt = strings.ToLower(r.Header.Get(EncryptHeader)) == "true"
contentType = r.Header.Get("Content-Type")
)
mediaType, params, err := mime.ParseMediaType(contentType) mediaType, params, err := mime.ParseMediaType(contentType)
if err != nil { if err != nil {
s.Logger.Debugf("file upload: parse content type header %q: %v", contentType, err) s.Logger.Debugf("file upload: parse content type header %q: %v", contentType, err)
...@@ -61,10 +68,6 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -61,10 +68,6 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
var reader io.Reader
var fileName, contentLength string
var fileSize uint64
ta := s.createTag(w, r) ta := s.createTag(w, r)
if ta == nil { if ta == nil {
return return
...@@ -154,7 +157,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -154,7 +157,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
} }
// first store the file and get its reference // first store the file and get its reference
sp := splitter.NewSimpleSplitter(s.Storer) sp := splitter.NewSimpleSplitter(s.Storer, mode)
fr, err := file.SplitWriteAll(ctx, sp, reader, int64(fileSize), toEncrypt) fr, err := file.SplitWriteAll(ctx, sp, reader, int64(fileSize), toEncrypt)
if err != nil { if err != nil {
s.Logger.Debugf("file upload: file store, file %q: %v", fileName, err) s.Logger.Debugf("file upload: file store, file %q: %v", fileName, err)
...@@ -178,7 +181,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -178,7 +181,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
jsonhttp.InternalServerError(w, "metadata marshal error") jsonhttp.InternalServerError(w, "metadata marshal error")
return return
} }
sp = splitter.NewSimpleSplitter(s.Storer) sp = splitter.NewSimpleSplitter(s.Storer, mode)
mr, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(metadataBytes), int64(len(metadataBytes)), toEncrypt) mr, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(metadataBytes), int64(len(metadataBytes)), toEncrypt)
if err != nil { if err != nil {
s.Logger.Debugf("file upload: metadata store, file %q: %v", fileName, err) s.Logger.Debugf("file upload: metadata store, file %q: %v", fileName, err)
...@@ -196,7 +199,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -196,7 +199,7 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
jsonhttp.InternalServerError(w, "entry marshal error") jsonhttp.InternalServerError(w, "entry marshal error")
return return
} }
sp = splitter.NewSimpleSplitter(s.Storer) sp = splitter.NewSimpleSplitter(s.Storer, mode)
reference, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)), toEncrypt) reference, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)), toEncrypt)
if err != nil { if err != nil {
s.Logger.Debugf("file upload: entry store, file %q: %v", fileName, err) s.Logger.Debugf("file upload: entry store, file %q: %v", fileName, err)
...@@ -223,53 +226,6 @@ type fileUploadInfo struct { ...@@ -223,53 +226,6 @@ type fileUploadInfo struct {
reader io.Reader reader io.Reader
} }
// storeFile uploads the given file and returns its reference.
// It stores three linked objects: the file content, its metadata
// (filename and content type), and an entry joining the two; the
// entry's reference is returned to the caller.
// this function was extracted from `fileUploadHandler` and should eventually replace its current code
func storeFile(ctx context.Context, fileInfo *fileUploadInfo, s storage.Storer) (swarm.Address, error) {
	// encryption is requested through the context; absence means plaintext
	v := ctx.Value(toEncryptContextKey{})
	toEncrypt, _ := v.(bool) // default is false
	// first store the file and get its reference
	sp := splitter.NewSimpleSplitter(s)
	fr, err := file.SplitWriteAll(ctx, sp, fileInfo.reader, fileInfo.size, toEncrypt)
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("split file error: %w", err)
	}
	// if filename is still empty, use the file hash as the filename
	if fileInfo.name == "" {
		fileInfo.name = fr.String()
	}
	// then store the metadata and get its reference
	m := entry.NewMetadata(fileInfo.name)
	m.MimeType = fileInfo.contentType
	metadataBytes, err := json.Marshal(m)
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("metadata marshal error: %w", err)
	}
	// a fresh splitter is needed per upload pass
	sp = splitter.NewSimpleSplitter(s)
	mr, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(metadataBytes), int64(len(metadataBytes)), toEncrypt)
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("split metadata error: %w", err)
	}
	// now join both references (mr, fr) to create an entry and store it
	e := entry.New(fr, mr)
	fileEntryBytes, err := e.MarshalBinary()
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("entry marshal error: %w", err)
	}
	sp = splitter.NewSimpleSplitter(s)
	reference, err := file.SplitWriteAll(ctx, sp, bytes.NewReader(fileEntryBytes), int64(len(fileEntryBytes)), toEncrypt)
	if err != nil {
		return swarm.ZeroAddress, fmt.Errorf("split entry error: %w", err)
	}
	return reference, nil
}
// fileDownloadHandler downloads the file given the entry's reference. // fileDownloadHandler downloads the file given the entry's reference.
func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) { func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
addr := mux.Vars(r)["addr"] addr := mux.Vars(r)["addr"]
......
...@@ -23,21 +23,25 @@ import ( ...@@ -23,21 +23,25 @@ import (
// invalid chunk address case etc. This test case has to be run in sequence and // invalid chunk address case etc. This test case has to be run in sequence and
// it assumes some state of the DB before another case is run. // it assumes some state of the DB before another case is run.
func TestPinChunkHandler(t *testing.T) { func TestPinChunkHandler(t *testing.T) {
resource := func(addr swarm.Address) string { return "/chunks/" + addr.String() } var (
hash := swarm.MustParseHexAddress("aabbcc") resource = func(addr swarm.Address) string { return "/chunks/" + addr.String() }
data := []byte("bbaatt") hash = swarm.MustParseHexAddress("aabbcc")
mockValidator := validator.NewMockValidator(hash, data) data = []byte("bbaatt")
tag := tags.NewTags() mockValidator = validator.NewMockValidator(hash, data)
mockValidatingStorer := mock.NewValidatingStorer(mockValidator, tag) mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
debugTestServer := newTestServer(t, testServerOptions{ tag = tags.NewTags()
Storer: mockValidatingStorer,
Tags: tag, debugTestServer = newTestServer(t, testServerOptions{
}) Storer: mockValidatingStorer,
// This server is used to store chunks Tags: tag,
bzzTestServer := newBZZTestServer(t, testServerOptions{ })
Storer: mockValidatingStorer,
Tags: tag, // This server is used to store chunks
}) bzzTestServer = newBZZTestServer(t, testServerOptions{
Storer: mockValidatingStorer,
Tags: tag,
})
)
// bad chunk address // bad chunk address
t.Run("pin-bad-address", func(t *testing.T) { t.Run("pin-bad-address", func(t *testing.T) {
......
...@@ -16,7 +16,6 @@ import ( ...@@ -16,7 +16,6 @@ import (
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
mp "github.com/ethersphere/bee/pkg/pusher/mock" mp "github.com/ethersphere/bee/pkg/pusher/mock"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/storage/mock/validator"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"gitlab.com/nolash/go-mockbytes" "gitlab.com/nolash/go-mockbytes"
...@@ -34,17 +33,16 @@ func TestTags(t *testing.T) { ...@@ -34,17 +33,16 @@ func TestTags(t *testing.T) {
tagResourceUUid = func(uuid uint64) string { return "/tags/" + strconv.FormatUint(uuid, 10) } tagResourceUUid = func(uuid uint64) string { return "/tags/" + strconv.FormatUint(uuid, 10) }
validHash = swarm.MustParseHexAddress("aabbcc") validHash = swarm.MustParseHexAddress("aabbcc")
validContent = []byte("bbaatt") validContent = []byte("bbaatt")
mockValidator = validator.NewMockValidator(validHash, validContent)
tag = tags.NewTags() tag = tags.NewTags()
mockValidatingStorer = mock.NewValidatingStorer(mockValidator, tag) mockStorer = mock.NewStorer()
mockPusher = mp.NewMockPusher(tag) mockPusher = mp.NewMockPusher(tag)
ts = newTestServer(t, testServerOptions{ ts = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockStorer,
Tags: tag, Tags: tag,
}) })
// This server is used to store chunks // This server is used to store chunks
apiClient = newBZZTestServer(t, testServerOptions{ apiClient = newBZZTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockStorer,
Tags: tag, Tags: tag,
}) })
) )
...@@ -106,11 +104,8 @@ func TestTags(t *testing.T) { ...@@ -106,11 +104,8 @@ func TestTags(t *testing.T) {
isTagFoundInResponse(t, rcvdHeaders, &ta) isTagFoundInResponse(t, rcvdHeaders, &ta)
// Add a second valid content to the validator
secondValidHash := swarm.MustParseHexAddress("deadbeaf") secondValidHash := swarm.MustParseHexAddress("deadbeaf")
secondValidContent := []byte("123456") secondValidContent := []byte("123456")
mockValidator.AddPair(secondValidHash, secondValidContent)
sentHheaders = make(http.Header) sentHheaders = make(http.Header)
sentHheaders.Set(api.TagHeaderUid, strconv.FormatUint(uint64(ta.Uid), 10)) sentHheaders.Set(api.TagHeaderUid, strconv.FormatUint(uint64(ta.Uid), 10))
rcvdHeaders = jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, apiClient, http.MethodPost, resource(secondValidHash), bytes.NewReader(secondValidContent), http.StatusOK, jsonhttp.StatusResponse{ rcvdHeaders = jsonhttptest.ResponseDirectSendHeadersAndReceiveHeaders(t, apiClient, http.MethodPost, resource(secondValidHash), bytes.NewReader(secondValidContent), http.StatusOK, jsonhttp.StatusResponse{
...@@ -226,16 +221,8 @@ func TestTags(t *testing.T) { ...@@ -226,16 +221,8 @@ func TestTags(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
chunkAddress := swarm.MustParseHexAddress("c10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef")
rootBytes := swarm.MustParseHexAddress("c10090961e7682a10890c334d759a28426647141213abda93b096b892824d2ef").Bytes()
rootChunk := make([]byte, 64)
copy(rootChunk[:32], rootBytes)
copy(rootChunk[32:], rootBytes)
rootAddress := swarm.MustParseHexAddress("5e2a21902f51438be1adbd0e29e1bd34c53a21d3120aefa3c7275129f2f88de9") rootAddress := swarm.MustParseHexAddress("5e2a21902f51438be1adbd0e29e1bd34c53a21d3120aefa3c7275129f2f88de9")
mockValidator.AddPair(chunkAddress, dataChunk)
mockValidator.AddPair(rootAddress, rootChunk)
content := make([]byte, swarm.ChunkSize*2) content := make([]byte, swarm.ChunkSize*2)
copy(content[swarm.ChunkSize:], dataChunk) copy(content[swarm.ChunkSize:], dataChunk)
copy(content[:swarm.ChunkSize], dataChunk) copy(content[:swarm.ChunkSize], dataChunk)
...@@ -260,8 +247,8 @@ func TestTags(t *testing.T) { ...@@ -260,8 +247,8 @@ func TestTags(t *testing.T) {
if finalTag.Total != 3 { if finalTag.Total != 3 {
t.Errorf("tag total count mismatch. got %d want %d", finalTag.Total, 3) t.Errorf("tag total count mismatch. got %d want %d", finalTag.Total, 3)
} }
if finalTag.Seen != 3 { if finalTag.Seen != 1 {
t.Errorf("tag seen count mismatch. got %d want %d", finalTag.Seen, 3) t.Errorf("tag seen count mismatch. got %d want %d", finalTag.Seen, 1)
} }
if finalTag.Stored != 3 { if finalTag.Stored != 3 {
t.Errorf("tag stored count mismatch. got %d want %d", finalTag.Stored, 3) t.Errorf("tag stored count mismatch. got %d want %d", finalTag.Stored, 3)
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
"github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter" "github.com/ethersphere/bee/pkg/file/splitter"
test "github.com/ethersphere/bee/pkg/file/testing" test "github.com/ethersphere/bee/pkg/file/testing"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -44,7 +45,7 @@ func testSplitThenJoin(t *testing.T) { ...@@ -44,7 +45,7 @@ func testSplitThenJoin(t *testing.T) {
paramstring = strings.Split(t.Name(), "/") paramstring = strings.Split(t.Name(), "/")
dataIdx, _ = strconv.ParseInt(paramstring[1], 10, 0) dataIdx, _ = strconv.ParseInt(paramstring[1], 10, 0)
store = mock.NewStorer() store = mock.NewStorer()
s = splitter.NewSimpleSplitter(store) s = splitter.NewSimpleSplitter(store, storage.ModePutUpload)
j = joiner.NewSimpleJoiner(store) j = joiner.NewSimpleJoiner(store)
data, _ = test.GetVector(t, int(dataIdx)) data, _ = test.GetVector(t, int(dataIdx))
) )
......
...@@ -143,7 +143,7 @@ func TestEncryptionAndDecryption(t *testing.T) { ...@@ -143,7 +143,7 @@ func TestEncryptionAndDecryption(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
s := splitter.NewSimpleSplitter(store) s := splitter.NewSimpleSplitter(store, storage.ModePutUpload)
testDataReader := file.NewSimpleReadCloser(testData) testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), true) resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), true)
if err != nil { if err != nil {
......
...@@ -13,7 +13,6 @@ import ( ...@@ -13,7 +13,6 @@ import (
"github.com/ethersphere/bee/pkg/encryption" "github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bmt" "github.com/ethersphere/bmt"
...@@ -21,6 +20,10 @@ import ( ...@@ -21,6 +20,10 @@ import (
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
) )
type Putter interface {
Put(context.Context, swarm.Chunk) ([]bool, error)
}
// maximum amount of file tree levels this file hasher component can handle // maximum amount of file tree levels this file hasher component can handle
// (128 ^ (9 - 1)) * 4096 = 295147905179352825856 bytes // (128 ^ (9 - 1)) * 4096 = 295147905179352825856 bytes
const levelBufferLimit = 9 const levelBufferLimit = 9
...@@ -41,7 +44,7 @@ func hashFunc() hash.Hash { ...@@ -41,7 +44,7 @@ func hashFunc() hash.Hash {
// error and will may result in undefined result. // error and will may result in undefined result.
type SimpleSplitterJob struct { type SimpleSplitterJob struct {
ctx context.Context ctx context.Context
putter storage.Putter putter Putter
spanLength int64 // target length of data spanLength int64 // target length of data
length int64 // number of bytes written to the data level of the hasher length int64 // number of bytes written to the data level of the hasher
sumCounts []int // number of sums performed, indexed per level sumCounts []int // number of sums performed, indexed per level
...@@ -56,7 +59,7 @@ type SimpleSplitterJob struct { ...@@ -56,7 +59,7 @@ type SimpleSplitterJob struct {
// NewSimpleSplitterJob creates a new SimpleSplitterJob. // NewSimpleSplitterJob creates a new SimpleSplitterJob.
// //
// The spanLength is the length of the data that will be written. // The spanLength is the length of the data that will be written.
func NewSimpleSplitterJob(ctx context.Context, putter storage.Putter, spanLength int64, toEncrypt bool) *SimpleSplitterJob { func NewSimpleSplitterJob(ctx context.Context, putter Putter, spanLength int64, toEncrypt bool) *SimpleSplitterJob {
hashSize := swarm.HashSize hashSize := swarm.HashSize
refSize := int64(hashSize) refSize := int64(hashSize)
if toEncrypt { if toEncrypt {
...@@ -184,7 +187,7 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) { ...@@ -184,7 +187,7 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
ch = swarm.NewChunk(addr, c) ch = swarm.NewChunk(addr, c)
} }
seen, err := s.putter.Put(s.ctx, storage.ModePutUpload, ch) seen, err := s.putter.Put(s.ctx, ch)
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(seen) > 0 && seen[0] { } else if len(seen) > 0 && seen[0] {
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/ethersphere/bee/pkg/file/splitter/internal" "github.com/ethersphere/bee/pkg/file/splitter/internal"
test "github.com/ethersphere/bee/pkg/file/testing" test "github.com/ethersphere/bee/pkg/file/testing"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -21,17 +22,30 @@ var ( ...@@ -21,17 +22,30 @@ var (
end = test.GetVectorCount() end = test.GetVectorCount()
) )
// putWrapper adapts a plain function to the Putter interface so tests
// can inject arbitrary chunk-store behavior (e.g. a mock storer with a
// fixed put mode).
type putWrapper struct {
	// putter is invoked for every chunk the splitter job stores.
	putter func(context.Context, swarm.Chunk) ([]bool, error)
}

// Put satisfies the Putter interface by delegating to the wrapped function.
func (p putWrapper) Put(ctx context.Context, ch swarm.Chunk) ([]bool, error) {
	return p.putter(ctx, ch)
}
// TestSplitterJobPartialSingleChunk passes sub-chunk length data to the splitter, // TestSplitterJobPartialSingleChunk passes sub-chunk length data to the splitter,
// verifies the correct hash is returned, and that write after Sum/complete Write // verifies the correct hash is returned, and that write after Sum/complete Write
// returns error. // returns error.
func TestSplitterJobPartialSingleChunk(t *testing.T) { func TestSplitterJobPartialSingleChunk(t *testing.T) {
store := mock.NewStorer() store := mock.NewStorer()
putter := putWrapper{
putter: func(ctx context.Context, ch swarm.Chunk) ([]bool, error) {
return store.Put(ctx, storage.ModePutUpload, ch)
},
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
data := []byte("foo") data := []byte("foo")
j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data)), false) j := internal.NewSimpleSplitterJob(ctx, putter, int64(len(data)), false)
c, err := j.Write(data) c, err := j.Write(data)
if err != nil { if err != nil {
...@@ -69,12 +83,17 @@ func testSplitterJobVector(t *testing.T) { ...@@ -69,12 +83,17 @@ func testSplitterJobVector(t *testing.T) {
paramstring = strings.Split(t.Name(), "/") paramstring = strings.Split(t.Name(), "/")
dataIdx, _ = strconv.ParseInt(paramstring[1], 10, 0) dataIdx, _ = strconv.ParseInt(paramstring[1], 10, 0)
store = mock.NewStorer() store = mock.NewStorer()
putter = putWrapper{
putter: func(ctx context.Context, ch swarm.Chunk) ([]bool, error) {
return store.Put(ctx, storage.ModePutUpload, ch)
},
}
) )
data, expect := test.GetVector(t, int(dataIdx)) data, expect := test.GetVector(t, int(dataIdx))
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
j := internal.NewSimpleSplitterJob(ctx, store, int64(len(data)), false) j := internal.NewSimpleSplitterJob(ctx, putter, int64(len(data)), false)
for i := 0; i < len(data); i += swarm.ChunkSize { for i := 0; i < len(data); i += swarm.ChunkSize {
l := swarm.ChunkSize l := swarm.ChunkSize
......
...@@ -16,15 +16,27 @@ import ( ...@@ -16,15 +16,27 @@ import (
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
// putWrapper turns a bare function into an internal.Putter, allowing the
// splitter to bind a storage.Putter together with a fixed put mode.
type putWrapper struct {
	// putter is called once per chunk produced by the split.
	putter func(context.Context, swarm.Chunk) ([]bool, error)
}

// Put implements internal.Putter by forwarding to the wrapped function.
func (p putWrapper) Put(ctx context.Context, ch swarm.Chunk) ([]bool, error) {
	return p.putter(ctx, ch)
}
// simpleSplitter wraps a non-optimized implementation of file.Splitter // simpleSplitter wraps a non-optimized implementation of file.Splitter
type simpleSplitter struct { type simpleSplitter struct {
putter storage.Putter putter internal.Putter
} }
// NewSimpleSplitter creates a new SimpleSplitter // NewSimpleSplitter creates a new SimpleSplitter
func NewSimpleSplitter(putter storage.Putter) file.Splitter { func NewSimpleSplitter(storePutter storage.Putter, mode storage.ModePut) file.Splitter {
return &simpleSplitter{ return &simpleSplitter{
putter: putter, putter: putWrapper{
putter: func(ctx context.Context, ch swarm.Chunk) ([]bool, error) {
return storePutter.Put(ctx, mode, ch)
},
},
} }
} }
......
...@@ -23,7 +23,7 @@ import ( ...@@ -23,7 +23,7 @@ import (
func TestSplitIncomplete(t *testing.T) { func TestSplitIncomplete(t *testing.T) {
testData := make([]byte, 42) testData := make([]byte, 42)
store := mock.NewStorer() store := mock.NewStorer()
s := splitter.NewSimpleSplitter(store) s := splitter.NewSimpleSplitter(store, storage.ModePutUpload)
testDataReader := file.NewSimpleReadCloser(testData) testDataReader := file.NewSimpleReadCloser(testData)
_, err := s.Split(context.Background(), testDataReader, 41, false) _, err := s.Split(context.Background(), testDataReader, 41, false)
...@@ -42,7 +42,7 @@ func TestSplitSingleChunk(t *testing.T) { ...@@ -42,7 +42,7 @@ func TestSplitSingleChunk(t *testing.T) {
} }
store := mock.NewStorer() store := mock.NewStorer()
s := splitter.NewSimpleSplitter(store) s := splitter.NewSimpleSplitter(store, storage.ModePutUpload)
testDataReader := file.NewSimpleReadCloser(testData) testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), false) resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), false)
...@@ -74,7 +74,7 @@ func TestSplitThreeLevels(t *testing.T) { ...@@ -74,7 +74,7 @@ func TestSplitThreeLevels(t *testing.T) {
} }
store := mock.NewStorer() store := mock.NewStorer()
s := splitter.NewSimpleSplitter(store) s := splitter.NewSimpleSplitter(store, storage.ModePutUpload)
testDataReader := file.NewSimpleReadCloser(testData) testDataReader := file.NewSimpleReadCloser(testData)
resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), false) resultAddress, err := s.Split(context.Background(), testDataReader, int64(len(testData)), false)
...@@ -131,7 +131,7 @@ func TestUnalignedSplit(t *testing.T) { ...@@ -131,7 +131,7 @@ func TestUnalignedSplit(t *testing.T) {
} }
// perform the split in a separate thread // perform the split in a separate thread
sp := splitter.NewSimpleSplitter(storer) sp := splitter.NewSimpleSplitter(storer, storage.ModePutUpload)
ctx := context.Background() ctx := context.Background()
doneC := make(chan swarm.Address) doneC := make(chan swarm.Address)
errC := make(chan error) errC := make(chan error)
......
...@@ -342,6 +342,24 @@ func newGCIndexTest(db *DB, chunk swarm.Chunk, storeTimestamp, accessTimestamp i ...@@ -342,6 +342,24 @@ func newGCIndexTest(db *DB, chunk swarm.Chunk, storeTimestamp, accessTimestamp i
} }
} }
// newPinIndexTest returns a test function asserting the pin-index state
// for chunk: the lookup must yield wantError (nil meaning the chunk is
// expected to be pinned), and on success the stored item is validated.
func newPinIndexTest(db *DB, chunk swarm.Chunk, wantError error) func(t *testing.T) {
	return func(t *testing.T) {
		t.Helper()

		addr := chunk.Address().Bytes()
		item, err := db.pinIndex.Get(shed.Item{Address: addr})
		if !errors.Is(err, wantError) {
			t.Errorf("got error %v, want %v", err, wantError)
		}
		if err == nil {
			validateItem(t, item, addr, nil, 0, 0)
		}
	}
}
// newItemsCountTest returns a test function that validates if // newItemsCountTest returns a test function that validates if
// an index contains expected number of key/value pairs. // an index contains expected number of key/value pairs.
func newItemsCountTest(i shed.Index, want int) func(t *testing.T) { func newItemsCountTest(i shed.Index, want int) func(t *testing.T) {
......
...@@ -87,7 +87,7 @@ func (db *DB) put(mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err e ...@@ -87,7 +87,7 @@ func (db *DB) put(mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err e
gcSizeChange += c gcSizeChange += c
} }
case storage.ModePutUpload: case storage.ModePutUpload, storage.ModePutUploadPin:
for i, ch := range chs { for i, ch := range chs {
if containsChunk(ch.Address(), chs[:i]...) { if containsChunk(ch.Address(), chs[:i]...) {
exist[i] = true exist[i] = true
...@@ -105,6 +105,12 @@ func (db *DB) put(mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err e ...@@ -105,6 +105,12 @@ func (db *DB) put(mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err e
triggerPushFeed = true triggerPushFeed = true
} }
gcSizeChange += c gcSizeChange += c
if mode == storage.ModePutUploadPin {
err = db.setPin(batch, ch.Address())
if err != nil {
return nil, err
}
}
} }
case storage.ModePutSync: case storage.ModePutSync:
......
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/syndtr/goleveldb/leveldb"
) )
// TestModePutRequest validates ModePutRequest index values on the provided DB. // TestModePutRequest validates ModePutRequest index values on the provided DB.
...@@ -108,6 +109,7 @@ func TestModePutSync(t *testing.T) { ...@@ -108,6 +109,7 @@ func TestModePutSync(t *testing.T) {
newRetrieveIndexesTest(db, ch, wantTimestamp, 0)(t) newRetrieveIndexesTest(db, ch, wantTimestamp, 0)(t)
newPullIndexTest(db, ch, binIDs[po], nil)(t) newPullIndexTest(db, ch, binIDs[po], nil)(t)
newPinIndexTest(db, ch, leveldb.ErrNotFound)(t)
} }
}) })
} }
...@@ -140,6 +142,40 @@ func TestModePutUpload(t *testing.T) { ...@@ -140,6 +142,40 @@ func TestModePutUpload(t *testing.T) {
newRetrieveIndexesTest(db, ch, wantTimestamp, 0)(t) newRetrieveIndexesTest(db, ch, wantTimestamp, 0)(t)
newPullIndexTest(db, ch, binIDs[po], nil)(t) newPullIndexTest(db, ch, binIDs[po], nil)(t)
newPushIndexTest(db, ch, wantTimestamp, nil)(t) newPushIndexTest(db, ch, wantTimestamp, nil)(t)
newPinIndexTest(db, ch, leveldb.ErrNotFound)(t)
}
})
}
}
// TestModePutUploadPin validates ModePutUploadPin index values on the provided DB.
func TestModePutUploadPin(t *testing.T) {
for _, tc := range multiChunkTestCases {
t.Run(tc.name, func(t *testing.T) {
db := newTestDB(t, nil)
wantTimestamp := time.Now().UTC().UnixNano()
defer setNow(func() (t int64) {
return wantTimestamp
})()
chunks := generateTestRandomChunks(tc.count)
_, err := db.Put(context.Background(), storage.ModePutUploadPin, chunks...)
if err != nil {
t.Fatal(err)
}
binIDs := make(map[uint8]uint64)
for _, ch := range chunks {
po := db.po(ch.Address())
binIDs[po]++
newRetrieveIndexesTest(db, ch, wantTimestamp, 0)(t)
newPullIndexTest(db, ch, binIDs[po], nil)(t)
newPushIndexTest(db, ch, wantTimestamp, nil)(t)
newPinIndexTest(db, ch, nil)(t)
} }
}) })
} }
......
...@@ -64,13 +64,13 @@ func (db *DB) PinnedChunks(ctx context.Context, cursor swarm.Address) (pinnedChu ...@@ -64,13 +64,13 @@ func (db *DB) PinnedChunks(ctx context.Context, cursor swarm.Address) (pinnedChu
return pinnedChunks, err return pinnedChunks, err
} }
// Pinner returns the pin counter given a swarm address, provided that the // PinInfo returns the pin counter for a given swarm address, provided that the
// address has to be pinned already. // address has been pinned.
func (db *DB) PinInfo(address swarm.Address) (uint64, error) { func (db *DB) PinInfo(address swarm.Address) (uint64, error) {
it := shed.Item{ out, err := db.pinIndex.Get(shed.Item{
Address: address.Bytes(), Address: address.Bytes(),
} })
out, err := db.pinIndex.Get(it)
if err != nil { if err != nil {
if errors.Is(err, leveldb.ErrNotFound) { if errors.Is(err, leveldb.ErrNotFound) {
return 0, storage.ErrNotFound return 0, storage.ErrNotFound
......
...@@ -11,22 +11,19 @@ import ( ...@@ -11,22 +11,19 @@ import (
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags"
) )
var _ storage.Storer = (*MockStorer)(nil) var _ storage.Storer = (*MockStorer)(nil)
type MockStorer struct { type MockStorer struct {
store map[string][]byte store map[string][]byte
modePut map[string]storage.ModePut
modeSet map[string]storage.ModeSet modeSet map[string]storage.ModeSet
modeSetMu sync.Mutex
pinnedAddress []swarm.Address // Stores the pinned address pinnedAddress []swarm.Address // Stores the pinned address
pinnedCounter []uint64 // and its respective counter. These are stored as slices to preserve the order. pinnedCounter []uint64 // and its respective counter. These are stored as slices to preserve the order.
pinSetMu sync.Mutex
subpull []storage.Descriptor subpull []storage.Descriptor
partialInterval bool partialInterval bool
validator swarm.Validator validator swarm.Validator
tags *tags.Tags
morePull chan struct{} morePull chan struct{}
mtx sync.Mutex mtx sync.Mutex
quit chan struct{} quit chan struct{}
...@@ -49,9 +46,9 @@ func WithBaseAddress(a swarm.Address) Option { ...@@ -49,9 +46,9 @@ func WithBaseAddress(a swarm.Address) Option {
}) })
} }
func WithTags(t *tags.Tags) Option { func WithValidator(v swarm.Validator) Option {
return optionFunc(func(m *MockStorer) { return optionFunc(func(m *MockStorer) {
m.tags = t m.validator = v
}) })
} }
...@@ -63,12 +60,12 @@ func WithPartialInterval(v bool) Option { ...@@ -63,12 +60,12 @@ func WithPartialInterval(v bool) Option {
func NewStorer(opts ...Option) *MockStorer { func NewStorer(opts ...Option) *MockStorer {
s := &MockStorer{ s := &MockStorer{
store: make(map[string][]byte), store: make(map[string][]byte),
modeSet: make(map[string]storage.ModeSet), modePut: make(map[string]storage.ModePut),
modeSetMu: sync.Mutex{}, modeSet: make(map[string]storage.ModeSet),
morePull: make(chan struct{}), morePull: make(chan struct{}),
quit: make(chan struct{}), quit: make(chan struct{}),
bins: make([]uint64, swarm.MaxBins), bins: make([]uint64, swarm.MaxBins),
} }
for _, v := range opts { for _, v := range opts {
...@@ -78,27 +75,6 @@ func NewStorer(opts ...Option) *MockStorer { ...@@ -78,27 +75,6 @@ func NewStorer(opts ...Option) *MockStorer {
return s return s
} }
func NewValidatingStorer(v swarm.Validator, tags *tags.Tags) *MockStorer {
return &MockStorer{
store: make(map[string][]byte),
modeSet: make(map[string]storage.ModeSet),
modeSetMu: sync.Mutex{},
pinSetMu: sync.Mutex{},
validator: v,
tags: tags,
}
}
func NewTagsStorer(tags *tags.Tags) *MockStorer {
return &MockStorer{
store: make(map[string][]byte),
modeSet: make(map[string]storage.ModeSet),
modeSetMu: sync.Mutex{},
pinSetMu: sync.Mutex{},
tags: tags,
}
}
func (m *MockStorer) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) { func (m *MockStorer) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) {
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock() defer m.mtx.Unlock()
...@@ -114,26 +90,23 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm ...@@ -114,26 +90,23 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock() defer m.mtx.Unlock()
for _, ch := range chs { exist = make([]bool, len(chs))
for i, ch := range chs {
if m.validator != nil { if m.validator != nil {
if !m.validator.Validate(ch) { if !m.validator.Validate(ch) {
return nil, storage.ErrInvalidChunk return nil, storage.ErrInvalidChunk
} }
} }
m.store[ch.Address().String()] = ch.Data() exist[i], err = m.has(ctx, ch.Address())
yes, err := m.has(ctx, ch.Address())
if err != nil { if err != nil {
exist = append(exist, false) return exist, err
continue
} }
if yes { if !exist[i] {
exist = append(exist, true)
} else {
po := swarm.Proximity(ch.Address().Bytes(), m.baseAddress) po := swarm.Proximity(ch.Address().Bytes(), m.baseAddress)
m.bins[po]++ m.bins[po]++
exist = append(exist, false)
} }
m.store[ch.Address().String()] = ch.Data()
m.modePut[ch.Address().String()] = mode
} }
return exist, nil return exist, nil
} }
...@@ -158,10 +131,8 @@ func (m *MockStorer) HasMulti(ctx context.Context, addrs ...swarm.Address) (yes ...@@ -158,10 +131,8 @@ func (m *MockStorer) HasMulti(ctx context.Context, addrs ...swarm.Address) (yes
} }
func (m *MockStorer) Set(ctx context.Context, mode storage.ModeSet, addrs ...swarm.Address) (err error) { func (m *MockStorer) Set(ctx context.Context, mode storage.ModeSet, addrs ...swarm.Address) (err error) {
m.modeSetMu.Lock() m.mtx.Lock()
m.pinSetMu.Lock() defer m.mtx.Unlock()
defer m.modeSetMu.Unlock()
defer m.pinSetMu.Unlock()
for _, addr := range addrs { for _, addr := range addrs {
m.modeSet[addr.String()] = mode m.modeSet[addr.String()] = mode
switch mode { switch mode {
...@@ -202,10 +173,18 @@ func (m *MockStorer) Set(ctx context.Context, mode storage.ModeSet, addrs ...swa ...@@ -202,10 +173,18 @@ func (m *MockStorer) Set(ctx context.Context, mode storage.ModeSet, addrs ...swa
} }
return nil return nil
} }
// GetModePut returns the storage.ModePut recorded for addr by the most
// recent Put call, or the zero ModePut if the address was never stored.
func (m *MockStorer) GetModePut(addr swarm.Address) (mode storage.ModePut) {
	m.mtx.Lock()
	defer m.mtx.Unlock()

	// a missing key yields the zero value, matching the comma-ok form
	mode = m.modePut[addr.String()]
	return mode
}
func (m *MockStorer) GetModeSet(addr swarm.Address) (mode storage.ModeSet) { func (m *MockStorer) GetModeSet(addr swarm.Address) (mode storage.ModeSet) {
m.modeSetMu.Lock() m.mtx.Lock()
defer m.modeSetMu.Unlock() defer m.mtx.Unlock()
if mode, ok := m.modeSet[addr.String()]; ok { if mode, ok := m.modeSet[addr.String()]; ok {
return mode return mode
} }
...@@ -289,8 +268,8 @@ func (m *MockStorer) SubscribePush(ctx context.Context) (c <-chan swarm.Chunk, s ...@@ -289,8 +268,8 @@ func (m *MockStorer) SubscribePush(ctx context.Context) (c <-chan swarm.Chunk, s
} }
func (m *MockStorer) PinnedChunks(ctx context.Context, cursor swarm.Address) (pinnedChunks []*storage.Pinner, err error) { func (m *MockStorer) PinnedChunks(ctx context.Context, cursor swarm.Address) (pinnedChunks []*storage.Pinner, err error) {
m.pinSetMu.Lock() m.mtx.Lock()
defer m.pinSetMu.Unlock() defer m.mtx.Unlock()
if len(m.pinnedAddress) == 0 { if len(m.pinnedAddress) == 0 {
return pinnedChunks, nil return pinnedChunks, nil
} }
...@@ -308,8 +287,8 @@ func (m *MockStorer) PinnedChunks(ctx context.Context, cursor swarm.Address) (pi ...@@ -308,8 +287,8 @@ func (m *MockStorer) PinnedChunks(ctx context.Context, cursor swarm.Address) (pi
} }
func (m *MockStorer) PinInfo(address swarm.Address) (uint64, error) { func (m *MockStorer) PinInfo(address swarm.Address) (uint64, error) {
m.pinSetMu.Lock() m.mtx.Lock()
defer m.pinSetMu.Unlock() defer m.mtx.Unlock()
for i, addr := range m.pinnedAddress { for i, addr := range m.pinnedAddress {
if addr.String() == address.String() { if addr.String() == address.String() {
return m.pinnedCounter[i], nil return m.pinnedCounter[i], nil
......
...@@ -5,8 +5,6 @@ import ( ...@@ -5,8 +5,6 @@ import (
"context" "context"
"testing" "testing"
"github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/storage/mock/validator" "github.com/ethersphere/bee/pkg/storage/mock/validator"
...@@ -65,7 +63,7 @@ func TestMockValidatingStorer(t *testing.T) { ...@@ -65,7 +63,7 @@ func TestMockValidatingStorer(t *testing.T) {
validContent := []byte("bbaatt") validContent := []byte("bbaatt")
invalidContent := []byte("bbaattss") invalidContent := []byte("bbaattss")
s := mock.NewValidatingStorer(validator.NewMockValidator(validAddress, validContent), tags.NewTags()) s := mock.NewStorer(mock.WithValidator(validator.NewMockValidator(validAddress, validContent)))
ctx := context.Background() ctx := context.Background()
......
...@@ -59,6 +59,8 @@ func (m ModePut) String() string { ...@@ -59,6 +59,8 @@ func (m ModePut) String() string {
return "Sync" return "Sync"
case ModePutUpload: case ModePutUpload:
return "Upload" return "Upload"
case ModePutUploadPin:
return "UploadPin"
default: default:
return "Unknown" return "Unknown"
} }
...@@ -72,6 +74,8 @@ const ( ...@@ -72,6 +74,8 @@ const (
ModePutSync ModePutSync
// ModePutUpload: when a chunk is created by local upload // ModePutUpload: when a chunk is created by local upload
ModePutUpload ModePutUpload
// ModePutUploadPin: the same as ModePutUpload but also pin the chunk atomically with the put
ModePutUploadPin
) )
// ModeSet enumerates different Setter modes. // ModeSet enumerates different Setter modes.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment