Commit daecae2b authored by acud's avatar acud Committed by GitHub

simplify validators (#1043)

parent 57e45fea
...@@ -21,6 +21,7 @@ import ( ...@@ -21,6 +21,7 @@ import (
statestore "github.com/ethersphere/bee/pkg/statestore/mock" statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
testingc "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
) )
...@@ -43,18 +44,16 @@ func TestApiStore(t *testing.T) { ...@@ -43,18 +44,16 @@ func TestApiStore(t *testing.T) {
} }
a := cmdfile.NewApiStore(host, port, false) a := cmdfile.NewApiStore(host, port, false)
chunkAddr := swarm.MustParseHexAddress(hashOfFoo) ch := testingc.GenerateTestRandomChunk()
chunkData := []byte("foo")
ch := swarm.NewChunk(chunkAddr, chunkData)
_, err = a.Put(ctx, storage.ModePutUpload, ch) _, err = a.Put(ctx, storage.ModePutUpload, ch)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, err = storer.Get(ctx, storage.ModeGetRequest, chunkAddr) _, err = storer.Get(ctx, storage.ModeGetRequest, ch.Address())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
chResult, err := a.Get(ctx, storage.ModeGetRequest, chunkAddr) chResult, err := a.Get(ctx, storage.ModeGetRequest, ch.Address())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -12,7 +12,9 @@ import ( ...@@ -12,7 +12,9 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"github.com/ethersphere/bee/pkg/content"
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/sctx" "github.com/ethersphere/bee/pkg/sctx"
...@@ -64,7 +66,17 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -64,7 +66,17 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
seen, err := s.Storer.Put(ctx, requestModePut(r), swarm.NewChunk(address, data)) chunk := swarm.NewChunk(address, data)
if !content.Valid(chunk) {
if !soc.Valid(chunk) {
s.Logger.Debugf("chunk upload: invalid chunk: %s", address)
s.Logger.Error("chunk upload: invalid chunk")
jsonhttp.BadRequest(w, nil)
return
}
}
seen, err := s.Storer.Put(ctx, requestModePut(r), chunk)
if err != nil { if err != nil {
s.Logger.Debugf("chunk upload: chunk write error: %v, addr %s", err, address) s.Logger.Debugf("chunk upload: chunk write error: %v, addr %s", err, address)
s.Logger.Error("chunk upload: chunk write error") s.Logger.Error("chunk upload: chunk write error")
......
...@@ -21,7 +21,7 @@ import ( ...@@ -21,7 +21,7 @@ import (
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/storage/mock/validator" testingc "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -33,50 +33,46 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -33,50 +33,46 @@ func TestChunkUploadDownload(t *testing.T) {
targets = "0x222" targets = "0x222"
resource = func(addr swarm.Address) string { return "/chunks/" + addr.String() } resource = func(addr swarm.Address) string { return "/chunks/" + addr.String() }
resourceTargets = func(addr swarm.Address) string { return "/chunks/" + addr.String() + "?targets=" + targets } resourceTargets = func(addr swarm.Address) string { return "/chunks/" + addr.String() + "?targets=" + targets }
validHash = swarm.MustParseHexAddress("aabbcc") someHash = swarm.MustParseHexAddress("aabbcc")
invalidHash = swarm.MustParseHexAddress("bbccdd") chunk = testingc.GenerateTestRandomChunk()
validContent = []byte("bbaatt")
invalidContent = []byte("bbaattss")
mockValidator = validator.NewMockValidator(validHash, validContent)
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger) tag = tags.NewTags(mockStatestore, logger)
mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator)) mockStorer = mock.NewStorer()
client, _, _ = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockStorer,
Tags: tag, Tags: tag,
}) })
) )
t.Run("invalid hash", func(t *testing.T) { t.Run("invalid chunk", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, resource(invalidHash), http.StatusBadRequest, jsonhttptest.Request(t, client, http.MethodPost, resource(someHash), http.StatusBadRequest,
jsonhttptest.WithRequestBody(bytes.NewReader(validContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "chunk write error", Message: http.StatusText(http.StatusBadRequest),
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
}), }),
) )
// make sure chunk is not retrievable // make sure chunk is not retrievable
_ = request(t, client, http.MethodGet, resource(invalidHash), nil, http.StatusNotFound) _ = request(t, client, http.MethodGet, resource(someHash), nil, http.StatusNotFound)
}) })
t.Run("invalid content", func(t *testing.T) { t.Run("empty chunk", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, resource(invalidHash), http.StatusBadRequest, jsonhttptest.Request(t, client, http.MethodPost, resource(someHash), http.StatusBadRequest,
jsonhttptest.WithRequestBody(bytes.NewReader(invalidContent)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "chunk write error", Message: http.StatusText(http.StatusBadRequest),
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
}), }),
) )
// make sure not retrievable // make sure chunk is not retrievable
_ = request(t, client, http.MethodGet, resource(validHash), nil, http.StatusNotFound) _ = request(t, client, http.MethodGet, resource(someHash), nil, http.StatusNotFound)
}) })
t.Run("ok", func(t *testing.T) { t.Run("ok", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, resource(validHash), http.StatusOK, jsonhttptest.Request(t, client, http.MethodPost, resource(chunk.Address()), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(validContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
...@@ -84,20 +80,20 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -84,20 +80,20 @@ func TestChunkUploadDownload(t *testing.T) {
) )
// try to fetch the same chunk // try to fetch the same chunk
resp := request(t, client, http.MethodGet, resource(validHash), nil, http.StatusOK) resp := request(t, client, http.MethodGet, resource(chunk.Address()), nil, http.StatusOK)
data, err := ioutil.ReadAll(resp.Body) data, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(validContent, data) { if !bytes.Equal(chunk.Data(), data) {
t.Fatal("data retrieved doesnt match uploaded content") t.Fatal("data retrieved doesnt match uploaded content")
} }
}) })
t.Run("pin-invalid-value", func(t *testing.T) { t.Run("pin-invalid-value", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, resource(validHash), http.StatusOK, jsonhttptest.Request(t, client, http.MethodPost, resource(chunk.Address()), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(validContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
...@@ -106,13 +102,13 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -106,13 +102,13 @@ func TestChunkUploadDownload(t *testing.T) {
) )
// Also check if the chunk is NOT pinned // Also check if the chunk is NOT pinned
if mockValidatingStorer.GetModeSet(validHash) == storage.ModeSetPin { if mockStorer.GetModeSet(chunk.Address()) == storage.ModeSetPin {
t.Fatal("chunk should not be pinned") t.Fatal("chunk should not be pinned")
} }
}) })
t.Run("pin-header-missing", func(t *testing.T) { t.Run("pin-header-missing", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, resource(validHash), http.StatusOK, jsonhttptest.Request(t, client, http.MethodPost, resource(chunk.Address()), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(validContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
...@@ -120,13 +116,13 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -120,13 +116,13 @@ func TestChunkUploadDownload(t *testing.T) {
) )
// Also check if the chunk is NOT pinned // Also check if the chunk is NOT pinned
if mockValidatingStorer.GetModeSet(validHash) == storage.ModeSetPin { if mockStorer.GetModeSet(chunk.Address()) == storage.ModeSetPin {
t.Fatal("chunk should not be pinned") t.Fatal("chunk should not be pinned")
} }
}) })
t.Run("pin-ok", func(t *testing.T) { t.Run("pin-ok", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, resource(validHash), http.StatusOK, jsonhttptest.Request(t, client, http.MethodPost, resource(chunk.Address()), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(validContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
...@@ -135,13 +131,13 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -135,13 +131,13 @@ func TestChunkUploadDownload(t *testing.T) {
) )
// Also check if the chunk is pinned // Also check if the chunk is pinned
if mockValidatingStorer.GetModePut(validHash) != storage.ModePutUploadPin { if mockStorer.GetModePut(chunk.Address()) != storage.ModePutUploadPin {
t.Fatal("chunk is not pinned") t.Fatal("chunk is not pinned")
} }
}) })
t.Run("retrieve-targets", func(t *testing.T) { t.Run("retrieve-targets", func(t *testing.T) {
resp := request(t, client, http.MethodGet, resourceTargets(validHash), nil, http.StatusOK) resp := request(t, client, http.MethodGet, resourceTargets(chunk.Address()), nil, http.StatusOK)
// Check if the target is obtained correctly // Check if the target is obtained correctly
if resp.Header.Get(api.TargetsRecoveryHeader) != targets { if resp.Header.Get(api.TargetsRecoveryHeader) != targets {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package api_test package api_test
import ( import (
"bytes"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"testing" "testing"
...@@ -15,11 +16,13 @@ import ( ...@@ -15,11 +16,13 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock" statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
testingc "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
) )
func TestGatewayMode(t *testing.T) { func TestGatewayMode(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
chunk := testingc.GenerateTestRandomChunk()
client, _, _ := newTestServer(t, testServerOptions{ client, _, _ := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(statestore.NewStateStore(), logger), Tags: tags.NewTags(statestore.NewStateStore(), logger),
...@@ -61,7 +64,11 @@ func TestGatewayMode(t *testing.T) { ...@@ -61,7 +64,11 @@ func TestGatewayMode(t *testing.T) {
Code: http.StatusForbidden, Code: http.StatusForbidden,
}) })
jsonhttptest.Request(t, client, http.MethodPost, "/chunks/0773a91efd6547c754fc1d95fb1c62c7d1b47f959c2caa685dfec8736da95c1c", http.StatusOK) // should work without pinning // should work without pinning
jsonhttptest.Request(t, client, http.MethodPost, "/chunks/"+chunk.Address().String(), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
)
jsonhttptest.Request(t, client, http.MethodPost, "/chunks/0773a91efd6547c754fc1d95fb1c62c7d1b47f959c2caa685dfec8736da95c1c", http.StatusForbidden, forbiddenResponseOption, headerOption) jsonhttptest.Request(t, client, http.MethodPost, "/chunks/0773a91efd6547c754fc1d95fb1c62c7d1b47f959c2caa685dfec8736da95c1c", http.StatusForbidden, forbiddenResponseOption, headerOption)
jsonhttptest.Request(t, client, http.MethodPost, "/bytes", http.StatusOK) // should work without pinning jsonhttptest.Request(t, client, http.MethodPost, "/bytes", http.StatusOK) // should work without pinning
......
This diff is collapsed.
...@@ -22,6 +22,7 @@ import ( ...@@ -22,6 +22,7 @@ import (
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
mp "github.com/ethersphere/bee/pkg/pusher/mock" mp "github.com/ethersphere/bee/pkg/pusher/mock"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
testingc "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/swarm/test" "github.com/ethersphere/bee/pkg/swarm/test"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
...@@ -41,8 +42,7 @@ func TestTags(t *testing.T) { ...@@ -41,8 +42,7 @@ func TestTags(t *testing.T) {
bytesResource = "/bytes" bytesResource = "/bytes"
chunksResource = func(addr swarm.Address) string { return "/chunks/" + addr.String() } chunksResource = func(addr swarm.Address) string { return "/chunks/" + addr.String() }
tagsResource = "/tags" tagsResource = "/tags"
someHash = swarm.MustParseHexAddress("aabbcc") chunk = testingc.GenerateTestRandomChunk()
someContent = []byte("bbaatt")
someTagName = "file.jpg" someTagName = "file.jpg"
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
...@@ -81,8 +81,8 @@ func TestTags(t *testing.T) { ...@@ -81,8 +81,8 @@ func TestTags(t *testing.T) {
}) })
t.Run("create tag with invalid id", func(t *testing.T) { t.Run("create tag with invalid id", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, chunksResource(someHash), http.StatusInternalServerError, jsonhttptest.Request(t, client, http.MethodPost, chunksResource(chunk.Address()), http.StatusInternalServerError,
jsonhttptest.WithRequestBody(bytes.NewReader(someContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "cannot get or create tag", Message: "cannot get or create tag",
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
...@@ -110,8 +110,8 @@ func TestTags(t *testing.T) { ...@@ -110,8 +110,8 @@ func TestTags(t *testing.T) {
}) })
t.Run("tag id in chunk upload", func(t *testing.T) { t.Run("tag id in chunk upload", func(t *testing.T) {
rcvdHeaders := jsonhttptest.Request(t, client, http.MethodPost, chunksResource(someHash), http.StatusOK, rcvdHeaders := jsonhttptest.Request(t, client, http.MethodPost, chunksResource(chunk.Address()), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(someContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
...@@ -136,8 +136,8 @@ func TestTags(t *testing.T) { ...@@ -136,8 +136,8 @@ func TestTags(t *testing.T) {
} }
// now upload a chunk and see if we receive a tag with the same id // now upload a chunk and see if we receive a tag with the same id
rcvdHeaders := jsonhttptest.Request(t, client, http.MethodPost, chunksResource(someHash), http.StatusOK, rcvdHeaders := jsonhttptest.Request(t, client, http.MethodPost, chunksResource(chunk.Address()), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(someContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
...@@ -150,8 +150,8 @@ func TestTags(t *testing.T) { ...@@ -150,8 +150,8 @@ func TestTags(t *testing.T) {
}) })
t.Run("tag counters", func(t *testing.T) { t.Run("tag counters", func(t *testing.T) {
rcvdHeaders := jsonhttptest.Request(t, client, http.MethodPost, chunksResource(someHash), http.StatusOK, rcvdHeaders := jsonhttptest.Request(t, client, http.MethodPost, chunksResource(chunk.Address()), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(someContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: http.StatusText(http.StatusOK), Message: http.StatusText(http.StatusOK),
Code: http.StatusOK, Code: http.StatusOK,
...@@ -252,8 +252,8 @@ func TestTags(t *testing.T) { ...@@ -252,8 +252,8 @@ func TestTags(t *testing.T) {
addr := test.RandomAddress() addr := test.RandomAddress()
// upload content with tag // upload content with tag
jsonhttptest.Request(t, client, http.MethodPost, chunksResource(someHash), http.StatusOK, jsonhttptest.Request(t, client, http.MethodPost, chunksResource(chunk.Address()), http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(someContent)), jsonhttptest.WithRequestBody(bytes.NewReader(chunk.Data())),
jsonhttptest.WithRequestHeader(api.SwarmTagUidHeader, fmt.Sprint(tagId)), jsonhttptest.WithRequestHeader(api.SwarmTagUidHeader, fmt.Sprint(tagId)),
) )
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package contains convenience methods and validator for content-addressed chunks
package content
import (
"encoding/binary"
"errors"
"fmt"
"github.com/ethersphere/bee/pkg/bmtpool"
"github.com/ethersphere/bee/pkg/swarm"
)
// NewChunk creates a new content-addressed single-span chunk.
// The length of the chunk data is set as the span.
func NewChunk(data []byte) (swarm.Chunk, error) {
return NewChunkWithSpan(data, int64(len(data)))
}
// NewChunkWithSpan creates a new content-addressed chunk from given data and span.
func NewChunkWithSpan(data []byte, span int64) (swarm.Chunk, error) {
if len(data) > swarm.ChunkSize {
return nil, errors.New("max chunk size exceeded")
}
if span < swarm.ChunkSize && span != int64(len(data)) {
return nil, fmt.Errorf("single-span chunk size mismatch; span is %d, chunk data length %d", span, len(data))
}
hasher := bmtpool.Get()
defer bmtpool.Put(hasher)
// execute hash, compare and return result
spanBytes := make([]byte, 8)
binary.LittleEndian.PutUint64(spanBytes, uint64(span))
err := hasher.SetSpanBytes(spanBytes)
if err != nil {
return nil, err
}
_, err = hasher.Write(data)
if err != nil {
return nil, err
}
s := hasher.Sum(nil)
payload := append(spanBytes, data...)
address := swarm.NewAddress(s)
return swarm.NewChunk(address, payload), nil
}
// NewChunkWithSpanBytes deserializes a content-addressed chunk from separate
// data and span byte slices.
func NewChunkWithSpanBytes(data, spanBytes []byte) (swarm.Chunk, error) {
hasher := bmtpool.Get()
defer bmtpool.Put(hasher)
// execute hash, compare and return result
err := hasher.SetSpanBytes(spanBytes)
if err != nil {
return nil, err
}
_, err = hasher.Write(data)
if err != nil {
return nil, err
}
s := hasher.Sum(nil)
payload := append(spanBytes, data...)
address := swarm.NewAddress(s)
return swarm.NewChunk(address, payload), nil
}
// contentChunkFromBytes deserializes a content-addressed chunk.
func contentChunkFromBytes(chunkData []byte) (swarm.Chunk, error) {
if len(chunkData) < swarm.SpanSize {
return nil, errors.New("shorter than minimum length")
}
return NewChunkWithSpanBytes(chunkData[8:], chunkData[:8])
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package content_test
import (
"encoding/binary"
"testing"
"github.com/ethersphere/bee/pkg/content"
"github.com/ethersphere/bee/pkg/swarm"
)
// TestChunkWithSpan verifies creation of content addressed chunk from
// byte data.
func TestChunk(t *testing.T) {
bmtHashOfFoo := "2387e8e7d8a48c2a9339c97c1dc3461a9a7aa07e994c5cb8b38fd7c1b3e6ea48"
address := swarm.MustParseHexAddress(bmtHashOfFoo)
c, err := content.NewChunk([]byte("foo"))
if err != nil {
t.Fatal(err)
}
if !address.Equal(c.Address()) {
t.Fatal("address mismatch")
}
}
// TestChunkWithSpan verifies creation of content addressed chunk from
// payload data and span in integer form.
func TestChunkWithSpan(t *testing.T) {
bmtHashOfFoo := "2387e8e7d8a48c2a9339c97c1dc3461a9a7aa07e994c5cb8b38fd7c1b3e6ea48"
address := swarm.MustParseHexAddress(bmtHashOfFoo)
data := []byte("foo")
c, err := content.NewChunkWithSpan(data, int64(len(data)))
if err != nil {
t.Fatal(err)
}
if !address.Equal(c.Address()) {
t.Fatal("address mismatch")
}
}
// TestChunkWithSpanBytes verifies creation of content addressed chunk from
// payload data and span in byte form.
func TestChunkWithSpanBytes(t *testing.T) {
bmtHashOfFoo := "2387e8e7d8a48c2a9339c97c1dc3461a9a7aa07e994c5cb8b38fd7c1b3e6ea48"
address := swarm.MustParseHexAddress(bmtHashOfFoo)
data := []byte("foo")
span := len(data)
spanBytes := make([]byte, 8)
binary.LittleEndian.PutUint64(spanBytes, uint64(span))
c, err := content.NewChunkWithSpanBytes(data, spanBytes)
if err != nil {
t.Fatal(err)
}
if !address.Equal(c.Address()) {
t.Fatal("address mismatch")
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/swarm"
)
var _ swarm.Validator = (*Validator)(nil)
type Validator struct {
rv bool
}
// NewValidator constructs a new Validator
func NewValidator(rv bool) swarm.Validator {
return &Validator{rv: rv}
}
// Validate returns rv from mock struct
func (v *Validator) Validate(ch swarm.Chunk) (valid bool) {
return v.rv
}
...@@ -5,29 +5,35 @@ ...@@ -5,29 +5,35 @@
package content package content
import ( import (
"bytes"
"github.com/ethersphere/bee/pkg/bmtpool"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
var _ swarm.Validator = (*Validator)(nil) // Valid checks whether the given chunk is a valid content-addressed chunk.
func Valid(c swarm.Chunk) bool {
data := c.Data()
if len(data) < swarm.SpanSize {
return false
}
// Validator validates that the address of a given chunk span := data[:swarm.SpanSize]
// is the content address of its contents. content := data[swarm.SpanSize:]
type Validator struct {
}
// NewValidator constructs a new Validator hasher := bmtpool.Get()
func NewValidator() swarm.Validator { defer bmtpool.Put(hasher)
return &Validator{}
}
// Validate performs the validation check. // execute hash, compare and return result
func (v *Validator) Validate(ch swarm.Chunk) (valid bool) { err := hasher.SetSpanBytes(span)
chunkData := ch.Data() if err != nil {
rch, err := contentChunkFromBytes(chunkData) return false
}
_, err = hasher.Write(content)
if err != nil { if err != nil {
return false return false
} }
s := hasher.Sum(nil)
address := ch.Address() return bytes.Equal(s, c.Address().Bytes())
return address.Equal(rch.Address())
} }
...@@ -14,10 +14,6 @@ import ( ...@@ -14,10 +14,6 @@ import (
// TestValidator checks that the validator evaluates correctly // TestValidator checks that the validator evaluates correctly
// on valid and invalid input // on valid and invalid input
func TestValidator(t *testing.T) { func TestValidator(t *testing.T) {
// instantiate validator
validator := content.NewValidator()
// generate address from pre-generated hex of 'foo' from legacy bmt // generate address from pre-generated hex of 'foo' from legacy bmt
bmtHashOfFoo := "2387e8e7d8a48c2a9339c97c1dc3461a9a7aa07e994c5cb8b38fd7c1b3e6ea48" bmtHashOfFoo := "2387e8e7d8a48c2a9339c97c1dc3461a9a7aa07e994c5cb8b38fd7c1b3e6ea48"
address := swarm.MustParseHexAddress(bmtHashOfFoo) address := swarm.MustParseHexAddress(bmtHashOfFoo)
...@@ -30,13 +26,13 @@ func TestValidator(t *testing.T) { ...@@ -30,13 +26,13 @@ func TestValidator(t *testing.T) {
binary.LittleEndian.PutUint64(fooBytes, uint64(fooLength)) binary.LittleEndian.PutUint64(fooBytes, uint64(fooLength))
copy(fooBytes[8:], foo) copy(fooBytes[8:], foo)
ch := swarm.NewChunk(address, fooBytes) ch := swarm.NewChunk(address, fooBytes)
if !validator.Validate(ch) { if !content.Valid(ch) {
t.Fatalf("data '%s' should have validated to hash '%s'", ch.Data(), ch.Address()) t.Fatalf("data '%s' should have validated to hash '%s'", ch.Data(), ch.Address())
} }
// now test with incorrect data // now test with incorrect data
ch = swarm.NewChunk(address, fooBytes[:len(fooBytes)-1]) ch = swarm.NewChunk(address, fooBytes[:len(fooBytes)-1])
if validator.Validate(ch) { if content.Valid(ch) {
t.Fatalf("data '%s' should not have validated to hash '%s'", ch.Data(), ch.Address()) t.Fatalf("data '%s' should not have validated to hash '%s'", ch.Data(), ch.Address())
} }
} }
...@@ -222,7 +222,7 @@ func TestGenerateTestRandomChunk(t *testing.T) { ...@@ -222,7 +222,7 @@ func TestGenerateTestRandomChunk(t *testing.T) {
t.Errorf("first chunk address length %v, want %v", addrLen, 32) t.Errorf("first chunk address length %v, want %v", addrLen, 32)
} }
dataLen := len(c1.Data()) dataLen := len(c1.Data())
if dataLen != swarm.ChunkSize { if dataLen != swarm.ChunkSize+swarm.SpanSize {
t.Errorf("first chunk data length %v, want %v", dataLen, swarm.ChunkSize) t.Errorf("first chunk data length %v, want %v", dataLen, swarm.ChunkSize)
} }
addrLen = len(c2.Address().Bytes()) addrLen = len(c2.Address().Bytes())
...@@ -230,7 +230,7 @@ func TestGenerateTestRandomChunk(t *testing.T) { ...@@ -230,7 +230,7 @@ func TestGenerateTestRandomChunk(t *testing.T) {
t.Errorf("second chunk address length %v, want %v", addrLen, 32) t.Errorf("second chunk address length %v, want %v", addrLen, 32)
} }
dataLen = len(c2.Data()) dataLen = len(c2.Data())
if dataLen != swarm.ChunkSize { if dataLen != swarm.ChunkSize+swarm.SpanSize {
t.Errorf("second chunk data length %v, want %v", dataLen, swarm.ChunkSize) t.Errorf("second chunk data length %v, want %v", dataLen, swarm.ChunkSize)
} }
if c1.Address().Equal(c2.Address()) { if c1.Address().Equal(c2.Address()) {
......
...@@ -21,7 +21,6 @@ import ( ...@@ -21,7 +21,6 @@ import (
"github.com/ethersphere/bee/pkg/accounting" "github.com/ethersphere/bee/pkg/accounting"
"github.com/ethersphere/bee/pkg/addressbook" "github.com/ethersphere/bee/pkg/addressbook"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/content"
"github.com/ethersphere/bee/pkg/crypto" "github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/debugapi" "github.com/ethersphere/bee/pkg/debugapi"
"github.com/ethersphere/bee/pkg/hive" "github.com/ethersphere/bee/pkg/hive"
...@@ -48,7 +47,6 @@ import ( ...@@ -48,7 +47,6 @@ import (
"github.com/ethersphere/bee/pkg/settlement/swap/chequebook" "github.com/ethersphere/bee/pkg/settlement/swap/chequebook"
"github.com/ethersphere/bee/pkg/settlement/swap/swapprotocol" "github.com/ethersphere/bee/pkg/settlement/swap/swapprotocol"
"github.com/ethersphere/bee/pkg/settlement/swap/transaction" "github.com/ethersphere/bee/pkg/settlement/swap/transaction"
"github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/statestore/leveldb" "github.com/ethersphere/bee/pkg/statestore/leveldb"
mockinmem "github.com/ethersphere/bee/pkg/statestore/mock" mockinmem "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -329,9 +327,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -329,9 +327,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
} }
b.localstoreCloser = storer b.localstoreCloser = storer
chunkvalidator := swarm.NewMultiValidator([]swarm.Validator{content.NewValidator(), soc.NewValidator()}) retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator, tracer)
tagService := tags.NewTags(stateStore, logger) tagService := tags.NewTags(stateStore, logger)
b.tagsCloser = tagService b.tagsCloser = tagService
...@@ -353,9 +349,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -353,9 +349,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
ns = netstore.New(storer, nil, retrieve, logger) ns = netstore.New(storer, nil, retrieve, logger)
} }
chunkvalidatorWithCallback := swarm.NewMultiValidator([]swarm.Validator{content.NewValidator(), soc.NewValidator()}, pssService.TryUnwrap) pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, pssService.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, chunkvalidatorWithCallback, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
// set the pushSyncer in the PSS // set the pushSyncer in the PSS
pssService.SetPushSyncer(pushSyncProtocol) pssService.SetPushSyncer(pushSyncProtocol)
...@@ -375,7 +369,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -375,7 +369,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
pullStorage := pullstorage.New(storer) pullStorage := pullstorage.New(storer)
pullSync := pullsync.New(p2ps, pullStorage, chunkvalidator, logger) pullSync := pullsync.New(p2ps, pullStorage, pssService.TryUnwrap, logger)
b.pullSyncCloser = pullSync b.pullSyncCloser = pullSync
if err = p2ps.AddProtocol(pullSync.Protocol()); err != nil { if err = p2ps.AddProtocol(pullSync.Protocol()); err != nil {
......
...@@ -15,11 +15,13 @@ import ( ...@@ -15,11 +15,13 @@ import (
"time" "time"
"github.com/ethersphere/bee/pkg/bitvector" "github.com/ethersphere/bee/pkg/bitvector"
"github.com/ethersphere/bee/pkg/content"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pullsync/pb" "github.com/ethersphere/bee/pkg/pullsync/pb"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage" "github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -62,7 +64,7 @@ type Syncer struct { ...@@ -62,7 +64,7 @@ type Syncer struct {
storage pullstorage.Storer storage pullstorage.Storer
quit chan struct{} quit chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
validator swarm.ValidatorWithCallback unwrap func(swarm.Chunk)
ruidMtx sync.Mutex ruidMtx sync.Mutex
ruidCtx map[uint32]func() ruidCtx map[uint32]func()
...@@ -71,12 +73,12 @@ type Syncer struct { ...@@ -71,12 +73,12 @@ type Syncer struct {
io.Closer io.Closer
} }
func New(streamer p2p.Streamer, storage pullstorage.Storer, validator swarm.ValidatorWithCallback, logger logging.Logger) *Syncer { func New(streamer p2p.Streamer, storage pullstorage.Storer, unwrap func(swarm.Chunk), logger logging.Logger) *Syncer {
return &Syncer{ return &Syncer{
streamer: streamer, streamer: streamer,
storage: storage, storage: storage,
metrics: newMetrics(), metrics: newMetrics(),
validator: validator, unwrap: unwrap,
logger: logger, logger: logger,
ruidCtx: make(map[uint32]func()), ruidCtx: make(map[uint32]func()),
wg: sync.WaitGroup{}, wg: sync.WaitGroup{},
...@@ -215,16 +217,15 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8 ...@@ -215,16 +217,15 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8
s.metrics.DeliveryCounter.Inc() s.metrics.DeliveryCounter.Inc()
chunk := swarm.NewChunk(addr, delivery.Data) chunk := swarm.NewChunk(addr, delivery.Data)
valid, callback := s.validator.ValidWithCallback(chunk) if content.Valid(chunk) {
if !valid { go s.unwrap(chunk)
} else if !soc.Valid(chunk) {
return 0, ru.Ruid, swarm.ErrInvalidChunk return 0, ru.Ruid, swarm.ErrInvalidChunk
} }
if err = s.storage.Put(ctx, storage.ModePutSync, chunk); err != nil { if err = s.storage.Put(ctx, storage.ModePutSync, chunk); err != nil {
return 0, ru.Ruid, fmt.Errorf("delivery put: %w", err) return 0, ru.Ruid, fmt.Errorf("delivery put: %w", err)
} }
if callback != nil {
go callback()
}
} }
return offer.Topmost, ru.Ruid, nil return offer.Topmost, ru.Ruid, nil
} }
......
...@@ -6,7 +6,6 @@ package pullsync_test ...@@ -6,7 +6,6 @@ package pullsync_test
import ( import (
"context" "context"
"crypto/rand"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -17,17 +16,12 @@ import ( ...@@ -17,17 +16,12 @@ import (
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pullsync" "github.com/ethersphere/bee/pkg/pullsync"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage/mock" "github.com/ethersphere/bee/pkg/pullsync/pullstorage/mock"
testingc "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
var ( var (
addrs = []swarm.Address{ addrs []swarm.Address
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000001"),
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000002"),
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000003"),
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000004"),
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000005"),
}
chunks []swarm.Chunk chunks []swarm.Chunk
) )
...@@ -39,11 +33,12 @@ func someChunks(i ...int) (c []swarm.Chunk) { ...@@ -39,11 +33,12 @@ func someChunks(i ...int) (c []swarm.Chunk) {
} }
func init() { func init() {
chunks = make([]swarm.Chunk, 5) n := 5
for i := 0; i < 5; i++ { chunks = make([]swarm.Chunk, n)
data := make([]byte, swarm.ChunkSize) addrs = make([]swarm.Address, n)
_, _ = rand.Read(data) for i := 0; i < n; i++ {
chunks[i] = swarm.NewChunk(addrs[i], data) chunks[i] = testingc.GenerateTestRandomChunk()
addrs[i] = chunks[i].Address()
} }
} }
...@@ -217,20 +212,7 @@ func haveChunks(t *testing.T, s *mock.PullStorage, addrs ...swarm.Address) { ...@@ -217,20 +212,7 @@ func haveChunks(t *testing.T, s *mock.PullStorage, addrs ...swarm.Address) {
func newPullSync(s p2p.Streamer, o ...mock.Option) (*pullsync.Syncer, *mock.PullStorage) { func newPullSync(s p2p.Streamer, o ...mock.Option) (*pullsync.Syncer, *mock.PullStorage) {
storage := mock.NewPullStorage(o...) storage := mock.NewPullStorage(o...)
c := make(chan swarm.Chunk)
validator := &mockValidator{c}
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
return pullsync.New(s, storage, validator, logger), storage unwrap := func(swarm.Chunk) {}
} return pullsync.New(s, storage, unwrap, logger), storage
type mockValidator struct {
c chan swarm.Chunk
}
func (*mockValidator) Validate(swarm.Chunk) bool {
return true
}
func (mv *mockValidator) ValidWithCallback(c swarm.Chunk) (bool, func()) {
return true, func() { mv.c <- c }
} }
...@@ -11,10 +11,12 @@ import ( ...@@ -11,10 +11,12 @@ import (
"time" "time"
"github.com/ethersphere/bee/pkg/accounting" "github.com/ethersphere/bee/pkg/accounting"
"github.com/ethersphere/bee/pkg/content"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pushsync/pb" "github.com/ethersphere/bee/pkg/pushsync/pb"
"github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
...@@ -46,7 +48,7 @@ type PushSync struct { ...@@ -46,7 +48,7 @@ type PushSync struct {
storer storage.Putter storer storage.Putter
peerSuggester topology.ClosestPeerer peerSuggester topology.ClosestPeerer
tagger *tags.Tags tagger *tags.Tags
validator swarm.ValidatorWithCallback unwrap func(swarm.Chunk)
logger logging.Logger logger logging.Logger
accounting accounting.Interface accounting accounting.Interface
pricer accounting.Pricer pricer accounting.Pricer
...@@ -56,13 +58,13 @@ type PushSync struct { ...@@ -56,13 +58,13 @@ type PushSync struct {
var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk
func New(streamer p2p.StreamerDisconnecter, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, validator swarm.ValidatorWithCallback, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync { func New(streamer p2p.StreamerDisconnecter, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, unwrap func(swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync {
ps := &PushSync{ ps := &PushSync{
streamer: streamer, streamer: streamer,
storer: storer, storer: storer,
peerSuggester: closestPeerer, peerSuggester: closestPeerer,
tagger: tagger, tagger: tagger,
validator: validator, unwrap: unwrap,
logger: logger, logger: logger,
accounting: accounting, accounting: accounting,
pricer: pricer, pricer: pricer,
...@@ -106,11 +108,15 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -106,11 +108,15 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
chunk := swarm.NewChunk(swarm.NewAddress(ch.Address), ch.Data) chunk := swarm.NewChunk(swarm.NewAddress(ch.Address), ch.Data)
// validate the chunk and returns the delivery callback for the validator if content.Valid(chunk) {
valid, callback := ps.validator.ValidWithCallback(chunk) if ps.unwrap != nil {
if !valid { go ps.unwrap(chunk)
}
} else {
if !soc.Valid(chunk) {
return swarm.ErrInvalidChunk return swarm.ErrInvalidChunk
} }
}
span, _, ctx := ps.tracer.StartSpanFromContext(ctx, "pushsync-handler", ps.logger, opentracing.Tag{Key: "address", Value: chunk.Address().String()}) span, _, ctx := ps.tracer.StartSpanFromContext(ctx, "pushsync-handler", ps.logger, opentracing.Tag{Key: "address", Value: chunk.Address().String()})
defer span.Finish() defer span.Finish()
...@@ -120,9 +126,6 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -120,9 +126,6 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
if err != nil { if err != nil {
// If i am the closest peer then store the chunk and send receipt // If i am the closest peer then store the chunk and send receipt
if errors.Is(err, topology.ErrWantSelf) { if errors.Is(err, topology.ErrWantSelf) {
if callback != nil {
go callback()
}
return ps.handleDeliveryResponse(ctx, w, p, chunk) return ps.handleDeliveryResponse(ctx, w, p, chunk)
} }
return err return err
......
This diff is collapsed.
...@@ -225,11 +225,11 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store ...@@ -225,11 +225,11 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
return nil return nil
}} }}
server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps, logger, serverMockAccounting, nil, nil, nil) server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps, logger, serverMockAccounting, nil, nil)
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
) )
retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil, nil) retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil)
ns := netstore.New(storer, recoveryFunc, retrieve, logger) ns := netstore.New(storer, recoveryFunc, retrieve, logger)
return ns return ns
} }
......
...@@ -11,10 +11,12 @@ import ( ...@@ -11,10 +11,12 @@ import (
"time" "time"
"github.com/ethersphere/bee/pkg/accounting" "github.com/ethersphere/bee/pkg/accounting"
"github.com/ethersphere/bee/pkg/content"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
pb "github.com/ethersphere/bee/pkg/retrieval/pb" pb "github.com/ethersphere/bee/pkg/retrieval/pb"
"github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/topology"
...@@ -46,11 +48,10 @@ type Service struct { ...@@ -46,11 +48,10 @@ type Service struct {
logger logging.Logger logger logging.Logger
accounting accounting.Interface accounting accounting.Interface
pricer accounting.Pricer pricer accounting.Pricer
validator swarm.Validator
tracer *tracing.Tracer tracer *tracing.Tracer
} }
func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, validator swarm.Validator, tracer *tracing.Tracer) *Service { func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *Service {
return &Service{ return &Service{
addr: addr, addr: addr,
streamer: streamer, streamer: streamer,
...@@ -59,7 +60,6 @@ func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunk ...@@ -59,7 +60,6 @@ func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunk
logger: logger, logger: logger,
accounting: accounting, accounting: accounting,
pricer: pricer, pricer: pricer,
validator: validator,
tracer: tracer, tracer: tracer,
} }
} }
...@@ -141,7 +141,6 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -141,7 +141,6 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
} }
ctx, cancel := context.WithTimeout(ctx, retrieveChunkTimeout) ctx, cancel := context.WithTimeout(ctx, retrieveChunkTimeout)
defer cancel() defer cancel()
peer, err = s.closestPeer(addr, skipPeers, allowUpstream) peer, err = s.closestPeer(addr, skipPeers, allowUpstream)
if err != nil { if err != nil {
return nil, peer, fmt.Errorf("get closest for address %s, allow upstream %v: %w", addr.String(), allowUpstream, err) return nil, peer, fmt.Errorf("get closest for address %s, allow upstream %v: %w", addr.String(), allowUpstream, err)
...@@ -169,7 +168,6 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -169,7 +168,6 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
}() }()
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsgWithContext(ctx, &pb.Request{ if err := w.WriteMsgWithContext(ctx, &pb.Request{
Addr: addr.Bytes(), Addr: addr.Bytes(),
}); err != nil { }); err != nil {
...@@ -181,12 +179,14 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -181,12 +179,14 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
return nil, peer, fmt.Errorf("read delivery: %w peer %s", err, peer.String()) return nil, peer, fmt.Errorf("read delivery: %w peer %s", err, peer.String())
} }
// credit the peer after successful delivery
chunk = swarm.NewChunk(addr, d.Data) chunk = swarm.NewChunk(addr, d.Data)
if !s.validator.Validate(chunk) { if !content.Valid(chunk) {
return nil, peer, err if !soc.Valid(chunk) {
return nil, peer, swarm.ErrInvalidChunk
}
} }
// credit the peer after successful delivery
err = s.accounting.Credit(peer, chunkPrice) err = s.accounting.Credit(peer, chunkPrice)
if err != nil { if err != nil {
return nil, peer, err return nil, peer, err
...@@ -236,7 +236,6 @@ func (s *Service) closestPeer(addr swarm.Address, skipPeers []swarm.Address, all ...@@ -236,7 +236,6 @@ func (s *Service) closestPeer(addr swarm.Address, skipPeers []swarm.Address, all
if closest.IsZero() { if closest.IsZero() {
return swarm.Address{}, topology.ErrNotFound return swarm.Address{}, topology.ErrNotFound
} }
if allowUpstream { if allowUpstream {
return closest, nil return closest, nil
} }
......
...@@ -9,12 +9,11 @@ import ( ...@@ -9,12 +9,11 @@ import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors" "errors"
"io/ioutil" "os"
"testing" "testing"
"time" "time"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock" accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
...@@ -22,6 +21,7 @@ import ( ...@@ -22,6 +21,7 @@ import (
pb "github.com/ethersphere/bee/pkg/retrieval/pb" pb "github.com/ethersphere/bee/pkg/retrieval/pb"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
storemock "github.com/ethersphere/bee/pkg/storage/mock" storemock "github.com/ethersphere/bee/pkg/storage/mock"
testingc "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/topology"
) )
...@@ -30,17 +30,12 @@ var testTimeout = 5 * time.Second ...@@ -30,17 +30,12 @@ var testTimeout = 5 * time.Second
// TestDelivery tests that a naive request -> delivery flow works. // TestDelivery tests that a naive request -> delivery flow works.
func TestDelivery(t *testing.T) { func TestDelivery(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(os.Stdout, 5)
mockValidator := mock.NewValidator(true)
mockStorer := storemock.NewStorer() mockStorer := storemock.NewStorer()
reqAddr, err := swarm.ParseHexAddress("00112233") chunk := testingc.FixtureChunk("0033")
if err != nil {
t.Fatal(err)
}
reqData := []byte("data data data")
// put testdata in the mock store of the server // put testdata in the mock store of the server
_, err = mockStorer.Put(context.Background(), storage.ModePutUpload, swarm.NewChunk(reqAddr, reqData)) _, err := mockStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -51,7 +46,7 @@ func TestDelivery(t *testing.T) { ...@@ -51,7 +46,7 @@ func TestDelivery(t *testing.T) {
pricerMock := accountingmock.NewPricer(price, price) pricerMock := accountingmock.NewPricer(price, price)
// create the server that will handle the request and will serve the response // create the server that will handle the request and will serve the response
server := retrieval.New(swarm.MustParseHexAddress("00112234"), mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, mockValidator, nil) server := retrieval.New(swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, nil)
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
) )
...@@ -60,7 +55,7 @@ func TestDelivery(t *testing.T) { ...@@ -60,7 +55,7 @@ func TestDelivery(t *testing.T) {
// client mock storer does not store any data at this point // client mock storer does not store any data at this point
// but should be checked at at the end of the test for the // but should be checked at at the end of the test for the
// presence of the reqAddr key and value to ensure delivery // presence of the chunk address key and value to ensure delivery
// was successful // was successful
clientMockStorer := storemock.NewStorer() clientMockStorer := storemock.NewStorer()
...@@ -69,15 +64,15 @@ func TestDelivery(t *testing.T) { ...@@ -69,15 +64,15 @@ func TestDelivery(t *testing.T) {
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
return nil return nil
}} }}
client := retrieval.New(swarm.MustParseHexAddress("9ee7add8"), clientMockStorer, recorder, ps, logger, clientMockAccounting, pricerMock, mockValidator, nil) client := retrieval.New(swarm.MustParseHexAddress("9ee7add8"), clientMockStorer, recorder, ps, logger, clientMockAccounting, pricerMock, nil)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel() defer cancel()
v, err := client.RetrieveChunk(ctx, reqAddr) v, err := client.RetrieveChunk(ctx, chunk.Address())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(v.Data(), reqData) { if !bytes.Equal(v.Data(), chunk.Data()) {
t.Fatalf("request and response data not equal. got %s want %s", v, reqData) t.Fatalf("request and response data not equal. got %s want %s", v, chunk.Data())
} }
records, err := recorder.Records(peerID, "retrieval", "1.0.0", "retrieval") records, err := recorder.Records(peerID, "retrieval", "1.0.0", "retrieval")
if err != nil { if err != nil {
...@@ -133,35 +128,32 @@ func TestDelivery(t *testing.T) { ...@@ -133,35 +128,32 @@ func TestDelivery(t *testing.T) {
} }
func TestRetrieveChunk(t *testing.T) { func TestRetrieveChunk(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(os.Stdout, 5)
mockValidator := mock.NewValidator(true)
pricer := accountingmock.NewPricer(1, 1) pricer := accountingmock.NewPricer(1, 1)
// requesting a chunk from downstream peer is expected // requesting a chunk from downstream peer is expected
t.Run("downstream", func(t *testing.T) { t.Run("downstream", func(t *testing.T) {
serverAddress := swarm.MustParseHexAddress("03") serverAddress := swarm.MustParseHexAddress("03")
chunkAddress := swarm.MustParseHexAddress("02")
clientAddress := swarm.MustParseHexAddress("01") clientAddress := swarm.MustParseHexAddress("01")
chunk := testingc.FixtureChunk("02c2")
serverStorer := storemock.NewStorer() serverStorer := storemock.NewStorer()
chunk := swarm.NewChunk(chunkAddress, []byte("some data"))
_, err := serverStorer.Put(context.Background(), storage.ModePutUpload, chunk) _, err := serverStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
server := retrieval.New(serverAddress, serverStorer, nil, nil, logger, accountingmock.NewAccounting(), pricer, mockValidator, nil) server := retrieval.New(serverAddress, serverStorer, nil, nil, logger, accountingmock.NewAccounting(), pricer, nil)
recorder := streamtest.New(streamtest.WithProtocols(server.Protocol())) recorder := streamtest.New(streamtest.WithProtocols(server.Protocol()))
clientSuggester := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error { clientSuggester := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(serverAddress, 0) _, _, _ = f(serverAddress, 0)
return nil return nil
}} }}
client := retrieval.New(clientAddress, nil, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, mockValidator, nil) client := retrieval.New(clientAddress, nil, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, nil)
got, err := client.RetrieveChunk(context.Background(), chunkAddress) got, err := client.RetrieveChunk(context.Background(), chunk.Address())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -171,13 +163,13 @@ func TestRetrieveChunk(t *testing.T) { ...@@ -171,13 +163,13 @@ func TestRetrieveChunk(t *testing.T) {
}) })
t.Run("forward", func(t *testing.T) { t.Run("forward", func(t *testing.T) {
chunkAddress := swarm.MustParseHexAddress("00") chunk := testingc.FixtureChunk("0025")
serverAddress := swarm.MustParseHexAddress("01")
forwarderAddress := swarm.MustParseHexAddress("02") serverAddress := swarm.MustParseHexAddress("0100000000000000000000000000000000000000000000000000000000000000")
clientAddress := swarm.MustParseHexAddress("03") forwarderAddress := swarm.MustParseHexAddress("0200000000000000000000000000000000000000000000000000000000000000")
clientAddress := swarm.MustParseHexAddress("030000000000000000000000000000000000000000000000000000000000000000")
serverStorer := storemock.NewStorer() serverStorer := storemock.NewStorer()
chunk := swarm.NewChunk(chunkAddress, []byte("some data"))
_, err := serverStorer.Put(context.Background(), storage.ModePutUpload, chunk) _, err := serverStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -185,13 +177,12 @@ func TestRetrieveChunk(t *testing.T) { ...@@ -185,13 +177,12 @@ func TestRetrieveChunk(t *testing.T) {
server := retrieval.New( server := retrieval.New(
serverAddress, serverAddress,
serverStorer, // chunk is in sever's store serverStorer, // chunk is in server's store
nil, nil,
nil, nil,
logger, logger,
accountingmock.NewAccounting(), accountingmock.NewAccounting(),
pricer, pricer,
mockValidator,
nil, nil,
) )
...@@ -206,7 +197,6 @@ func TestRetrieveChunk(t *testing.T) { ...@@ -206,7 +197,6 @@ func TestRetrieveChunk(t *testing.T) {
logger, logger,
accountingmock.NewAccounting(), accountingmock.NewAccounting(),
pricer, pricer,
mockValidator,
nil, nil,
) )
...@@ -221,11 +211,10 @@ func TestRetrieveChunk(t *testing.T) { ...@@ -221,11 +211,10 @@ func TestRetrieveChunk(t *testing.T) {
logger, logger,
accountingmock.NewAccounting(), accountingmock.NewAccounting(),
pricer, pricer,
mockValidator,
nil, nil,
) )
got, err := client.RetrieveChunk(context.Background(), chunkAddress) got, err := client.RetrieveChunk(context.Background(), chunk.Address())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -7,4 +7,5 @@ package soc ...@@ -7,4 +7,5 @@ package soc
var ( var (
ToSignDigest = toSignDigest ToSignDigest = toSignDigest
RecoverAddress = recoverAddress RecoverAddress = recoverAddress
ContentAddressedChunk = contentAddressedChunk
) )
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/content" "github.com/ethersphere/bee/pkg/bmtpool"
"github.com/ethersphere/bee/pkg/crypto" "github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -130,7 +130,7 @@ func FromChunk(sch swarm.Chunk) (*Soc, error) { ...@@ -130,7 +130,7 @@ func FromChunk(sch swarm.Chunk) (*Soc, error) {
spanBytes := chunkData[cursor : cursor+swarm.SpanSize] spanBytes := chunkData[cursor : cursor+swarm.SpanSize]
cursor += swarm.SpanSize cursor += swarm.SpanSize
ch, err := content.NewChunkWithSpanBytes(chunkData[cursor:], spanBytes) ch, err := contentAddressedChunk(chunkData[cursor:], spanBytes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -231,3 +231,23 @@ func recoverAddress(signature, digest []byte) ([]byte, error) { ...@@ -231,3 +231,23 @@ func recoverAddress(signature, digest []byte) ([]byte, error) {
} }
return recoveredEthereumAddress, nil return recoveredEthereumAddress, nil
} }
func contentAddressedChunk(data, spanBytes []byte) (swarm.Chunk, error) {
hasher := bmtpool.Get()
defer bmtpool.Put(hasher)
// execute hash, compare and return result
err := hasher.SetSpanBytes(spanBytes)
if err != nil {
return nil, err
}
_, err = hasher.Write(data)
if err != nil {
return nil, err
}
s := hasher.Sum(nil)
payload := append(spanBytes, data...)
address := swarm.NewAddress(s)
return swarm.NewChunk(address, payload), nil
}
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"encoding/binary" "encoding/binary"
"testing" "testing"
"github.com/ethersphere/bee/pkg/content"
"github.com/ethersphere/bee/pkg/crypto" "github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -27,7 +26,7 @@ func TestToChunk(t *testing.T) { ...@@ -27,7 +26,7 @@ func TestToChunk(t *testing.T) {
id := make([]byte, 32) id := make([]byte, 32)
payload := []byte("foo") payload := []byte("foo")
ch, err := content.NewChunk(payload) ch, err := chunk(payload)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -94,7 +93,7 @@ func TestFromChunk(t *testing.T) { ...@@ -94,7 +93,7 @@ func TestFromChunk(t *testing.T) {
id := make([]byte, 32) id := make([]byte, 32)
payload := []byte("foo") payload := []byte("foo")
ch, err := content.NewChunk(payload) ch, err := chunk(payload)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -122,3 +121,9 @@ func TestFromChunk(t *testing.T) { ...@@ -122,3 +121,9 @@ func TestFromChunk(t *testing.T) {
t.Fatalf("owner address mismatch %x %x", ownerEthereumAddress, u2.OwnerAddress()) t.Fatalf("owner address mismatch %x %x", ownerEthereumAddress, u2.OwnerAddress())
} }
} }
func chunk(data []byte) (swarm.Chunk, error) {
span := make([]byte, swarm.SpanSize)
binary.LittleEndian.PutUint64(span, uint64(len(data)))
return soc.ContentAddressedChunk(data, span)
}
...@@ -8,20 +8,8 @@ import ( ...@@ -8,20 +8,8 @@ import (
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
var _ swarm.Validator = (*Validator)(nil) // Valid checks if the chunk is a valid single-owner chunk.
func Valid(ch swarm.Chunk) bool {
// Validator validates that the address of a given chunk
// is a single-owner chunk.
type Validator struct {
}
// NewValidator creates a new Validator.
func NewValidator() swarm.Validator {
return &Validator{}
}
// Validate performs the validation check.
func (v *Validator) Validate(ch swarm.Chunk) (valid bool) {
s, err := FromChunk(ch) s, err := FromChunk(ch)
if err != nil { if err != nil {
return false return false
......
...@@ -39,14 +39,13 @@ func TestValidator(t *testing.T) { ...@@ -39,14 +39,13 @@ func TestValidator(t *testing.T) {
} }
// check valid chunk // check valid chunk
v := soc.NewValidator() if !soc.Valid(sch) {
if !v.Validate(sch) {
t.Fatal("valid chunk evaluates to invalid") t.Fatal("valid chunk evaluates to invalid")
} }
// check invalid data // check invalid data
sch.Data()[0] = 0x01 sch.Data()[0] = 0x01
if v.Validate(sch) { if soc.Valid(sch) {
t.Fatal("chunk with invalid data evaluates to valid") t.Fatal("chunk with invalid data evaluates to valid")
} }
...@@ -56,7 +55,7 @@ func TestValidator(t *testing.T) { ...@@ -56,7 +55,7 @@ func TestValidator(t *testing.T) {
wrongAddressBytes[0] = 255 - wrongAddressBytes[0] wrongAddressBytes[0] = 255 - wrongAddressBytes[0]
wrongAddress := swarm.NewAddress(wrongAddressBytes) wrongAddress := swarm.NewAddress(wrongAddressBytes)
sch = swarm.NewChunk(wrongAddress, sch.Data()) sch = swarm.NewChunk(wrongAddress, sch.Data())
if v.Validate(sch) { if soc.Valid(sch) {
t.Fatal("chunk with invalid address evaluates to valid") t.Fatal("chunk with invalid address evaluates to valid")
} }
} }
...@@ -23,7 +23,6 @@ type MockStorer struct { ...@@ -23,7 +23,6 @@ type MockStorer struct {
pinnedCounter []uint64 // and its respective counter. These are stored as slices to preserve the order. pinnedCounter []uint64 // and its respective counter. These are stored as slices to preserve the order.
subpull []storage.Descriptor subpull []storage.Descriptor
partialInterval bool partialInterval bool
validator swarm.Validator
morePull chan struct{} morePull chan struct{}
mtx sync.Mutex mtx sync.Mutex
quit chan struct{} quit chan struct{}
...@@ -46,12 +45,6 @@ func WithBaseAddress(a swarm.Address) Option { ...@@ -46,12 +45,6 @@ func WithBaseAddress(a swarm.Address) Option {
}) })
} }
func WithValidator(v swarm.Validator) Option {
return optionFunc(func(m *MockStorer) {
m.validator = v
})
}
func WithPartialInterval(v bool) Option { func WithPartialInterval(v bool) Option {
return optionFunc(func(m *MockStorer) { return optionFunc(func(m *MockStorer) {
m.partialInterval = v m.partialInterval = v
...@@ -92,11 +85,6 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm ...@@ -92,11 +85,6 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm
exist = make([]bool, len(chs)) exist = make([]bool, len(chs))
for i, ch := range chs { for i, ch := range chs {
if m.validator != nil {
if !m.validator.Validate(ch) {
return nil, storage.ErrInvalidChunk
}
}
exist[i], err = m.has(ctx, ch.Address()) exist[i], err = m.has(ctx, ch.Address())
if err != nil { if err != nil {
return exist, err return exist, err
......
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/storage/mock/validator"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -52,41 +51,3 @@ func TestMockStorer(t *testing.T) { ...@@ -52,41 +51,3 @@ func TestMockStorer(t *testing.T) {
t.Fatal("expected mock store to have key") t.Fatal("expected mock store to have key")
} }
} }
func TestMockValidatingStorer(t *testing.T) {
validAddressHex := "aabbcc"
invalidAddressHex := "bbccdd"
validAddress := swarm.MustParseHexAddress(validAddressHex)
invalidAddress := swarm.MustParseHexAddress(invalidAddressHex)
validContent := []byte("bbaatt")
invalidContent := []byte("bbaattss")
s := mock.NewStorer(mock.WithValidator(validator.NewMockValidator(validAddress, validContent)))
ctx := context.Background()
if _, err := s.Put(ctx, storage.ModePutUpload, swarm.NewChunk(validAddress, validContent)); err != nil {
t.Fatalf("expected not error but got: %v", err)
}
if _, err := s.Put(ctx, storage.ModePutUpload, swarm.NewChunk(invalidAddress, validContent)); err == nil {
t.Fatalf("expected error but got none")
}
if _, err := s.Put(ctx, storage.ModePutUpload, swarm.NewChunk(invalidAddress, invalidContent)); err == nil {
t.Fatalf("expected error but got none")
}
if chunk, err := s.Get(ctx, storage.ModeGetRequest, validAddress); err != nil {
t.Fatalf("got error on get but expected none: %v", err)
} else if !bytes.Equal(chunk.Data(), validContent) {
t.Fatal("stored content not identical to input data")
}
if _, err := s.Get(ctx, storage.ModeGetRequest, invalidAddress); err == nil {
t.Fatal("got no error on get but expected one")
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package validator
import (
"bytes"
"github.com/ethersphere/bee/pkg/swarm"
)
// MockValidator returns true if the data and address passed in the Validate method
// are a byte-wise match to the data and address passed to the constructor
type MockValidator struct {
addressDataPair map[string][]byte // Make validator accept more than one address/data pair
}
// NewMockValidator constructs a new MockValidator
func NewMockValidator(address swarm.Address, data []byte) *MockValidator {
mp := &MockValidator{
addressDataPair: make(map[string][]byte),
}
mp.addressDataPair[address.String()] = data
return mp
}
// Add a new address/data pair which can be validated
func (v *MockValidator) AddPair(address swarm.Address, data []byte) {
v.addressDataPair[address.String()] = data
}
// Validate checks the passed chunk for validity
func (v *MockValidator) Validate(ch swarm.Chunk) (valid bool) {
if data, ok := v.addressDataPair[ch.Address().String()]; ok {
if bytes.Equal(data, ch.Data()) {
return true
} else if len(ch.Data()) > 8 && bytes.Equal(data, ch.Data()[8:]) {
return true
}
}
return false
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package validator_test
import (
"testing"
"github.com/ethersphere/bee/pkg/storage/mock/validator"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestMockValidator(t *testing.T) {
validAddr := swarm.NewAddress([]byte("foo"))
invalidAddr := swarm.NewAddress([]byte("bar"))
validContent := []byte("xyzzy")
invalidContent := []byte("yzzyx")
validator := validator.NewMockValidator(validAddr, validContent)
ch := swarm.NewChunk(validAddr, validContent)
if !validator.Validate(ch) {
t.Fatalf("chunk '%v' should be valid", ch)
}
ch = swarm.NewChunk(invalidAddr, validContent)
if validator.Validate(ch) {
t.Fatalf("chunk '%v' should be invalid", ch)
}
ch = swarm.NewChunk(validAddr, invalidContent)
if validator.Validate(ch) {
t.Fatalf("chunk '%v' should be invalid", ch)
}
ch = swarm.NewChunk(invalidAddr, invalidContent)
if validator.Validate(ch) {
t.Fatalf("chunk '%v' should be invalid", ch)
}
}
...@@ -17,25 +17,67 @@ ...@@ -17,25 +17,67 @@
package testing package testing
import ( import (
"encoding/binary"
"math/rand" "math/rand"
"time" "time"
"github.com/ethersphere/bee/pkg/bmtpool"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
// fixtreuChunks are pregenerated content-addressed chunks necessary for explicit
// test scenarios where random generated chunks are not good enough.
var fixtureChunks = map[string]swarm.Chunk{
"0025": swarm.NewChunk(
swarm.MustParseHexAddress("0025737be11979e91654dffd2be817ac1e52a2dadb08c97a7cef12f937e707bc"),
[]byte{72, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 149, 179, 31, 244, 146, 247, 129, 123, 132, 248, 215, 77, 44, 47, 91, 248, 229, 215, 89, 156, 210, 243, 3, 110, 204, 74, 101, 119, 53, 53, 145, 188, 193, 153, 130, 197, 83, 152, 36, 140, 150, 209, 191, 214, 193, 4, 144, 121, 32, 45, 205, 220, 59, 227, 28, 43, 161, 51, 108, 14, 106, 180, 135, 2},
),
"0033": swarm.NewChunk(
swarm.MustParseHexAddress("0033153ac8cfb0c343db1795f578c15ed8ef827f3e68ed3c58329900bf0d7276"),
[]byte{72, 0, 0, 0, 0, 0, 0, 0, 170, 117, 0, 0, 0, 0, 0, 0, 21, 157, 63, 86, 45, 17, 166, 184, 47, 126, 58, 172, 242, 77, 153, 249, 97, 5, 107, 244, 23, 153, 220, 255, 254, 47, 209, 24, 63, 58, 126, 142, 41, 79, 201, 182, 178, 227, 235, 223, 63, 11, 220, 155, 40, 181, 56, 204, 91, 44, 51, 185, 95, 155, 245, 235, 187, 250, 103, 49, 139, 184, 46, 199},
),
"02c2": swarm.NewChunk(
swarm.MustParseHexAddress("02c2bd0db71efb7d245eafcc1c126189c1f598feb80e8f14e7ecef913c6a2ef5"),
[]byte{72, 0, 0, 0, 0, 0, 0, 0, 226, 0, 0, 0, 0, 0, 0, 0, 67, 234, 252, 231, 229, 11, 121, 163, 131, 171, 41, 107, 57, 191, 221, 32, 62, 204, 159, 124, 116, 87, 30, 244, 99, 137, 121, 248, 119, 56, 74, 102, 140, 73, 178, 7, 151, 22, 47, 126, 173, 30, 43, 7, 61, 187, 13, 236, 59, 194, 245, 18, 25, 237, 106, 125, 78, 241, 35, 34, 116, 154, 105, 205},
),
"7000": swarm.NewChunk(
swarm.MustParseHexAddress("70002115a015d40a1f5ef68c29d072f06fae58854934c1cb399fcb63cf336127"),
[]byte{72, 0, 0, 0, 0, 0, 0, 0, 124, 59, 0, 0, 0, 0, 0, 0, 44, 67, 19, 101, 42, 213, 4, 209, 212, 189, 107, 244, 111, 22, 230, 24, 245, 103, 227, 165, 88, 74, 50, 11, 143, 197, 220, 118, 175, 24, 169, 193, 15, 40, 225, 196, 246, 151, 1, 45, 86, 7, 36, 99, 156, 86, 83, 29, 46, 207, 115, 112, 126, 88, 101, 128, 153, 113, 30, 27, 50, 232, 77, 215},
),
}
func init() { func init() {
// needed for GenerateTestRandomChunk // needed for GenerateTestRandomChunk
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }
// GenerateTestRandomChunk generates a Chunk that is not // GenerateTestRandomChunk generates a valid content addressed chunk.
// valid, but it contains a random key and a random value.
// This function is faster then storage.GenerateRandomChunk
// which generates a valid chunk.
// Some tests in do not need valid chunks, just
// random data, and their execution time can be decreased
// using this function.
func GenerateTestRandomChunk() swarm.Chunk { func GenerateTestRandomChunk() swarm.Chunk {
data := make([]byte, swarm.ChunkSize)
_, _ = rand.Read(data)
span := make([]byte, swarm.SpanSize)
binary.LittleEndian.PutUint64(span, uint64(len(data)))
data = append(span, data...)
hasher := bmtpool.Get()
defer bmtpool.Put(hasher)
err := hasher.SetSpanBytes(data[:swarm.SpanSize])
if err != nil {
panic(err)
}
_, err = hasher.Write(data[swarm.SpanSize:])
if err != nil {
panic(err)
}
ref := hasher.Sum(nil)
return swarm.NewChunk(swarm.NewAddress(ref), data)
}
// GenerateTestRandomInvalidChunk generates a random, however invalid, content
// addressed chunk.
func GenerateTestRandomInvalidChunk() swarm.Chunk {
data := make([]byte, swarm.ChunkSize) data := make([]byte, swarm.ChunkSize)
_, _ = rand.Read(data) _, _ = rand.Read(data)
key := make([]byte, swarm.SectionSize) key := make([]byte, swarm.SectionSize)
...@@ -48,7 +90,17 @@ func GenerateTestRandomChunk() swarm.Chunk { ...@@ -48,7 +90,17 @@ func GenerateTestRandomChunk() swarm.Chunk {
func GenerateTestRandomChunks(count int) []swarm.Chunk { func GenerateTestRandomChunks(count int) []swarm.Chunk {
chunks := make([]swarm.Chunk, count) chunks := make([]swarm.Chunk, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
chunks[i] = GenerateTestRandomChunk() chunks[i] = GenerateTestRandomInvalidChunk()
} }
return chunks return chunks
} }
// FixtureChunk gets a pregenerated content-addressed chunk and
// panics if one is not found.
func FixtureChunk(prefix string) swarm.Chunk {
c, ok := fixtureChunks[prefix]
if !ok {
panic("no fixture found")
}
return c
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package swarm
type Validator interface {
Validate(ch Chunk) (valid bool)
}
type ValidatorWithCallback interface {
ValidWithCallback(ch Chunk) (valid bool, callback func())
Validator
}
var _ Validator = (*validatorWithCallback)(nil)
type validatorWithCallback struct {
v Validator
callback func(Chunk)
}
func (v *validatorWithCallback) Validate(ch Chunk) bool {
valid := v.v.Validate(ch)
if valid {
go v.callback(ch)
}
return valid
}
var _ ValidatorWithCallback = (*multiValidator)(nil)
type multiValidator struct {
validators []Validator
callbacks []func(Chunk)
}
func NewMultiValidator(validators []Validator, callbacks ...func(Chunk)) ValidatorWithCallback {
return &multiValidator{validators, callbacks}
}
func (mv *multiValidator) Validate(ch Chunk) bool {
for _, v := range mv.validators {
if v.Validate(ch) {
return true
}
}
return false
}
func (mv *multiValidator) ValidWithCallback(ch Chunk) (bool, func()) {
for i, v := range mv.validators {
if v.Validate(ch) {
if i < len(mv.callbacks) {
return true, func() { mv.callbacks[i](ch) }
}
return true, nil
}
}
return false, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment