Commit 09507fc3 authored by Viktor Trón, committed by GitHub

validators/callbacks reorg to enable pss mailboxing (#942)

parent 828d95cd
@@ -65,7 +65,7 @@ func TestPssWebsocketSingleHandler(t *testing.T) {
 		t.Fatal(err)
 	}
-	p.TryUnwrap(context.Background(), tc)
+	p.TryUnwrap(tc)
 	waitMessage(t, msgContent, payload, &mtx)
 }
@@ -103,7 +103,7 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
 		t.Fatal(err)
 	}
-	p.TryUnwrap(context.Background(), tc)
+	p.TryUnwrap(tc)
 	waitMessage(t, msgContent, nil, &mtx)
 }
@@ -146,7 +146,7 @@ func TestPssWebsocketMultiHandler(t *testing.T) {
 		t.Fatal(err)
 	}
-	p.TryUnwrap(context.Background(), tc)
+	p.TryUnwrap(tc)
 	waitMessage(t, msgContent, nil, &mtx)
 	waitMessage(t, msgContent2, nil, &mtx)
@@ -275,7 +275,7 @@ func TestPssPingPong(t *testing.T) {
 	time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
-	p.TryUnwrap(context.Background(), tc)
+	p.TryUnwrap(tc)
 	waitMessage(t, msgContent, nil, &mtx)
 }
@@ -399,7 +399,7 @@ func (m *mpss) Register(_ pss.Topic, _ pss.Handler) func() {
 }
 
 // TryUnwrap tries to unwrap a wrapped trojan message.
-func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) {
+func (m *mpss) TryUnwrap(_ swarm.Chunk) {
 	panic("not implemented") // TODO: Implement
 }
...
@@ -20,9 +20,8 @@ import (
 type store struct {
 	storage.Storer
 	retrieval retrieval.Interface
-	validator swarm.Validator
 	logger logging.Logger
-	recoveryCallback recovery.RecoveryHook // this is the callback to be executed when a chunk fails to be retrieved
+	recoveryCallback recovery.Callback // this is the callback to be executed when a chunk fails to be retrieved
 }
 
 var (
@@ -30,9 +29,8 @@ var (
 )
 
 // New returns a new NetStore that wraps a given Storer.
-func New(s storage.Storer, rcb recovery.RecoveryHook, r retrieval.Interface, logger logging.Logger,
-	validator swarm.Validator) storage.Storer {
-	return &store{Storer: s, recoveryCallback: rcb, retrieval: r, logger: logger, validator: validator}
+func New(s storage.Storer, rcb recovery.Callback, r retrieval.Interface, logger logging.Logger) storage.Storer {
+	return &store{Storer: s, recoveryCallback: rcb, retrieval: r, logger: logger}
 }
 // Get retrieves a given chunk address.
@@ -44,20 +42,12 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
 		// request from network
 		ch, err = s.retrieval.RetrieveChunk(ctx, addr)
 		if err != nil {
-			if s.recoveryCallback == nil {
-				return nil, err
-			}
 			targets := sctx.GetTargets(ctx)
-			if targets != nil {
-				go func() {
-					err := s.recoveryCallback(addr, targets)
-					if err != nil {
-						s.logger.Debugf("netstore: error while recovering chunk: %v", err)
-					}
-				}()
-				return nil, ErrRecoveryAttempt
+			if targets == nil || s.recoveryCallback == nil {
+				return nil, err
 			}
-			return nil, fmt.Errorf("netstore retrieve chunk: %w", err)
+			go s.recoveryCallback(addr, targets)
+			return nil, ErrRecoveryAttempt
 		}
 
 		_, err = s.Storer.Put(ctx, storage.ModePutRequest, ch)
@@ -70,15 +60,3 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
 	}
 	return ch, nil
 }
-
-// Put stores a given chunk in the local storage.
-// returns a storage.ErrInvalidChunk error when
-// encountering an invalid chunk.
-func (s *store) Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err error) {
-	for _, ch := range chs {
-		if !s.validator.Validate(ch) {
-			return nil, storage.ErrInvalidChunk
-		}
-	}
-	return s.Storer.Put(ctx, mode, chs...)
-}
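
The netstore thus loses its Put-time validation (validation moves into the validator types introduced later in this commit) and treats recovery as fire-and-forget. A minimal wiring sketch of the new constructor, under assumed storer, retriever and logger values (illustrative, not part of the commit):

	// recovery.Callback no longer returns an error; Get launches it in a
	// goroutine and reports ErrRecoveryAttempt when targets are present.
	recoverFn := func(addr swarm.Address, targets pss.Targets) {
		// e.g. publish a recovery request over pss; errors are dropped by design
	}
	ns := netstore.New(storer, recoverFn, retriever, logger) // validator parameter removed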
@@ -14,10 +14,10 @@ import (
 	"testing"
 	"time"
 
-	validatormock "github.com/ethersphere/bee/pkg/content/mock"
 	"github.com/ethersphere/bee/pkg/logging"
 	"github.com/ethersphere/bee/pkg/netstore"
 	"github.com/ethersphere/bee/pkg/pss"
+	"github.com/ethersphere/bee/pkg/recovery"
 	"github.com/ethersphere/bee/pkg/sctx"
 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/storage/mock"
@@ -98,12 +98,12 @@ func TestNetstoreNoRetrieval(t *testing.T) {
 }
 
 func TestRecovery(t *testing.T) {
-	hookWasCalled := make(chan bool, 1)
+	callbackWasCalled := make(chan bool, 1)
 	rec := &mockRecovery{
-		hookC: hookWasCalled,
+		callbackC: callbackWasCalled,
 	}
-	retrieve, _, nstore := newRetrievingNetstore(rec)
+	retrieve, _, nstore := newRetrievingNetstore(rec.recovery)
 	addr := swarm.MustParseHexAddress("deadbeef")
 	retrieve.failure = true
 	ctx := context.Background()
@@ -115,10 +115,10 @@ func TestRecovery(t *testing.T) {
 	}
 
 	select {
-	case <-hookWasCalled:
+	case <-callbackWasCalled:
 		break
 	case <-time.After(100 * time.Millisecond):
-		t.Fatal("recovery hook was not called")
+		t.Fatal("recovery callback was not called")
 	}
 }
@@ -136,19 +136,11 @@ func TestInvalidRecoveryFunction(t *testing.T) {
 }
 
 // returns a mock retrieval protocol, a mock local storage and a netstore
-func newRetrievingNetstore(rec *mockRecovery) (ret *retrievalMock, mockStore, ns storage.Storer) {
+func newRetrievingNetstore(rec recovery.Callback) (ret *retrievalMock, mockStore, ns storage.Storer) {
 	retrieve := &retrievalMock{}
 	store := mock.NewStorer()
 	logger := logging.New(ioutil.Discard, 0)
-	validator := swarm.NewChunkValidator(validatormock.NewValidator(true))
-	var nstore storage.Storer
-	if rec != nil {
-		nstore = netstore.New(store, rec.recovery, retrieve, logger, validator)
-	} else {
-		nstore = netstore.New(store, nil, retrieve, logger, validator)
-	}
-	return retrieve, store, nstore
+	return retrieve, store, netstore.New(store, rec, retrieve, logger)
 }
 type retrievalMock struct {
@@ -169,13 +161,12 @@ func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (
 }
 
 type mockRecovery struct {
-	hookC chan bool
+	callbackC chan bool
 }
 
 // Send mocks the pss Send function
-func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets pss.Targets) error {
-	mr.hookC <- true
-	return nil
+func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets pss.Targets) {
+	mr.callbackC <- true
 }
 
 func (r *mockRecovery) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
...
@@ -334,7 +334,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
 	}
 	b.localstoreCloser = storer
 
-	chunkvalidator := swarm.NewChunkValidator(content.NewValidator(), soc.NewValidator())
+	chunkvalidator := swarm.NewMultiValidator([]swarm.Validator{content.NewValidator(), soc.NewValidator()})
 
 	retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator, tracer)
 	tagService := tags.NewTags(stateStore, logger)
@@ -358,13 +358,15 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
 	var ns storage.Storer
 	if o.GlobalPinningEnabled {
 		// create recovery callback for content repair
-		recoverFunc := recovery.NewRecoveryHook(pssService)
-		ns = netstore.New(storer, recoverFunc, retrieve, logger, chunkvalidator)
+		recoverFunc := recovery.NewCallback(pssService)
+		ns = netstore.New(storer, recoverFunc, retrieve, logger)
 	} else {
-		ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator)
+		ns = netstore.New(storer, nil, retrieve, logger)
 	}
 
-	pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, pssService.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
+	chunkvalidatorWithCallback := swarm.NewMultiValidator([]swarm.Validator{content.NewValidator(), soc.NewValidator()}, pssService.TryUnwrap)
+	pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, chunkvalidatorWithCallback, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
 
 	// set the pushSyncer in the PSS
 	pssService.SetPushSyncer(pushSyncProtocol)
@@ -376,7 +378,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
 	if o.GlobalPinningEnabled {
 		// register function for chunk repair upon receiving a trojan message
 		chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol)
-		b.recoveryHandleCleanup = pssService.Register(recovery.RecoveryTopic, chunkRepairHandler)
+		b.recoveryHandleCleanup = pssService.Register(recovery.Topic, chunkRepairHandler)
 	}
 
 	pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagService, logger, tracer)
@@ -384,7 +386,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
 	pullStorage := pullstorage.New(storer)
-	pullSync := pullsync.New(p2ps, pullStorage, logger)
+	pullSync := pullsync.New(p2ps, pullStorage, chunkvalidator, logger)
 	b.pullSyncCloser = pullSync
 
 	if err = p2ps.AddProtocol(pullSync.Protocol()); err != nil {
...
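
Condensing the node wiring above: retrieval and pull-sync share a plain multi-validator, while push-sync gets a second instance that carries pssService.TryUnwrap as its delivery callback, so every chunk that validates during push-sync is also offered to pss for trojan unwrapping:

	// plain validation for retrieval and pull-sync
	chunkvalidator := swarm.NewMultiValidator(
		[]swarm.Validator{content.NewValidator(), soc.NewValidator()})

	// same validators, plus TryUnwrap as a positional callback for push-sync
	chunkvalidatorWithCallback := swarm.NewMultiValidator(
		[]swarm.Validator{content.NewValidator(), soc.NewValidator()},
		pssService.TryUnwrap)

Note that callbacks pair positionally with validators in the multi-validator (see the new swarm package file at the end of this commit), so the single TryUnwrap callback here fires for chunks accepted by the first (content) validator.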
@@ -31,7 +31,7 @@ type Interface interface {
 	// Register a Handler for a given Topic.
 	Register(Topic, Handler) func()
 	// TryUnwrap tries to unwrap a wrapped trojan message.
-	TryUnwrap(context.Context, swarm.Chunk)
+	TryUnwrap(swarm.Chunk)
 
 	SetPushSyncer(pushSyncer pushsync.PushSyncer)
 	io.Closer
@@ -128,10 +128,11 @@ func (p *pss) topics() []Topic {
 }
 
 // TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic.
-func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) {
+func (p *pss) TryUnwrap(c swarm.Chunk) {
 	if len(c.Data()) < swarm.ChunkWithSpanSize {
 		return // chunk not full
 	}
+	ctx := context.Background()
 	topic, msg, err := Unwrap(ctx, p.key, c, p.topics())
 	if err != nil {
 		return // cannot unwrap
...
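
Dropping the context parameter is what lets TryUnwrap double as a plain chunk callback: the context is now created internally via context.Background(). A hedged sketch of the pattern this enables, mirroring the node wiring (pssService and contentValidator are assumed values):

	// TryUnwrap now satisfies func(swarm.Chunk), so the method value can be
	// handed directly to swarm.NewMultiValidator as a per-chunk callback.
	var unwrap func(swarm.Chunk) = pssService.TryUnwrap
	validator := swarm.NewMultiValidator([]swarm.Validator{contentValidator}, unwrap)
	_ = validator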
@@ -75,8 +75,6 @@ type topicMessage struct {
 // TestDeliver verifies that registering a handler on pss for a given topic and then submitting a trojan chunk with said topic to it
 // results in the execution of the expected handler func
 func TestDeliver(t *testing.T) {
-	ctx := context.Background()
-
 	privkey, err := crypto.GenerateSecp256k1Key()
 	if err != nil {
 		t.Fatal(err)
@@ -108,7 +106,7 @@ func TestDeliver(t *testing.T) {
 	p.Register(topic, handler)
 
 	// call pss TryUnwrap on chunk and verify test topic variable value changes
-	p.TryUnwrap(ctx, chunk)
+	p.TryUnwrap(chunk)
 
 	var message topicMessage
 	select {
@@ -172,7 +170,7 @@ func TestRegister(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	p.TryUnwrap(context.Background(), chunk1)
+	p.TryUnwrap(chunk1)
 
 	waitHandlerCallback(t, &msgChan, 1)
@@ -181,7 +179,7 @@ func TestRegister(t *testing.T) {
 	// register another topic handler on the same topic
 	cleanup := p.Register(topic1, h3)
-	p.TryUnwrap(context.Background(), chunk1)
+	p.TryUnwrap(chunk1)
 
 	waitHandlerCallback(t, &msgChan, 2)
@@ -191,7 +189,7 @@ func TestRegister(t *testing.T) {
 	cleanup() // remove the last handler
-	p.TryUnwrap(context.Background(), chunk1)
+	p.TryUnwrap(chunk1)
 
 	waitHandlerCallback(t, &msgChan, 1)
@@ -203,7 +201,7 @@ func TestRegister(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
-	p.TryUnwrap(context.Background(), chunk2)
+	p.TryUnwrap(chunk2)
 
 	waitHandlerCallback(t, &msgChan, 1)
...
@@ -18,7 +18,6 @@ import (
 	"github.com/ethersphere/bee/pkg/logging"
 	"github.com/ethersphere/bee/pkg/p2p"
 	"github.com/ethersphere/bee/pkg/p2p/protobuf"
-	"github.com/ethersphere/bee/pkg/pullsync/pb"
 	"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
 	"github.com/ethersphere/bee/pkg/storage"
@@ -57,12 +56,13 @@
 }
 
 type Syncer struct {
 	streamer p2p.Streamer
 	metrics metrics
 	logger logging.Logger
 	storage pullstorage.Storer
 	quit chan struct{}
 	wg sync.WaitGroup
+	validator swarm.ValidatorWithCallback
 	ruidMtx sync.Mutex
 	ruidCtx map[uint32]func()
@@ -71,15 +71,16 @@ type Syncer struct {
 	io.Closer
 }
 
-func New(streamer p2p.Streamer, storage pullstorage.Storer, logger logging.Logger) *Syncer {
+func New(streamer p2p.Streamer, storage pullstorage.Storer, validator swarm.ValidatorWithCallback, logger logging.Logger) *Syncer {
 	return &Syncer{
 		streamer: streamer,
 		storage: storage,
 		metrics: newMetrics(),
+		validator: validator,
 		logger: logger,
 		ruidCtx: make(map[uint32]func()),
 		wg: sync.WaitGroup{},
 		quit: make(chan struct{}),
 	}
 }
@@ -212,9 +213,18 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8
 			delete(wantChunks, addr.String())
 			s.metrics.DbOpsCounter.Inc()
 			s.metrics.DeliveryCounter.Inc()
-			if err = s.storage.Put(ctx, storage.ModePutSync, swarm.NewChunk(addr, delivery.Data)); err != nil {
+
+			chunk := swarm.NewChunk(addr, delivery.Data)
+			valid, callback := s.validator.ValidWithCallback(chunk)
+			if !valid {
+				return 0, ru.Ruid, swarm.ErrInvalidChunk
+			}
+			if err = s.storage.Put(ctx, storage.ModePutSync, chunk); err != nil {
 				return 0, ru.Ruid, fmt.Errorf("delivery put: %w", err)
 			}
+			if callback != nil {
+				go callback()
+			}
 		}
 		return offer.Topmost, ru.Ruid, nil
 	}
...
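
The delivery path above settles into a validate, store, then notify order: the callback returned by ValidWithCallback runs only after the chunk is safely persisted. A condensed sketch of that contract, with assumed store and validator values:

	valid, callback := validator.ValidWithCallback(chunk) // callback may be nil
	if !valid {
		return swarm.ErrInvalidChunk
	}
	if err := store.Put(ctx, storage.ModePutSync, chunk); err != nil {
		return err
	}
	if callback != nil {
		go callback() // e.g. hand the chunk to pss once it is stored
	}
	return nil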
@@ -217,6 +217,20 @@ func haveChunks(t *testing.T, s *mock.PullStorage, addrs ...swarm.Address) {
 func newPullSync(s p2p.Streamer, o ...mock.Option) (*pullsync.Syncer, *mock.PullStorage) {
 	storage := mock.NewPullStorage(o...)
+	c := make(chan swarm.Chunk)
+	validator := &mockValidator{c}
 	logger := logging.New(ioutil.Discard, 0)
-	return pullsync.New(s, storage, logger), storage
+	return pullsync.New(s, storage, validator, logger), storage
+}
+
+type mockValidator struct {
+	c chan swarm.Chunk
+}
+
+func (*mockValidator) Validate(swarm.Chunk) bool {
+	return true
+}
+
+func (mv *mockValidator) ValidWithCallback(c swarm.Chunk) (bool, func()) {
+	return true, func() { mv.c <- c }
 }
@@ -38,32 +38,32 @@ type Receipt struct {
 }
 
 type PushSync struct {
 	streamer p2p.Streamer
 	storer storage.Putter
 	peerSuggester topology.ClosestPeerer
-	tagg *tags.Tags
-	deliveryCallback func(context.Context, swarm.Chunk) // callback func to be invoked to deliver chunks to PSS
+	tagger *tags.Tags
+	validator swarm.ValidatorWithCallback
 	logger logging.Logger
 	accounting accounting.Interface
 	pricer accounting.Pricer
 	metrics metrics
 	tracer *tracing.Tracer
 }
 
 var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk
 
-func New(streamer p2p.Streamer, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, deliveryCallback func(context.Context, swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync {
+func New(streamer p2p.Streamer, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, validator swarm.ValidatorWithCallback, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync {
 	ps := &PushSync{
 		streamer: streamer,
 		storer: storer,
 		peerSuggester: closestPeerer,
-		tagg: tagger,
-		deliveryCallback: deliveryCallback,
+		tagger: tagger,
+		validator: validator,
 		logger: logger,
 		accounting: accounting,
 		pricer: pricer,
 		metrics: newMetrics(),
 		tracer: tracer,
 	}
 	return ps
 }
@@ -101,6 +101,13 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
 	ps.metrics.ChunksReceivedCounter.Inc()
 
 	chunk := swarm.NewChunk(swarm.NewAddress(ch.Address), ch.Data)
+
+	// validate the chunk and returns the delivery callback for the validator
+	valid, callback := ps.validator.ValidWithCallback(chunk)
+	if !valid {
+		return swarm.ErrInvalidChunk
+	}
+
 	span, _, ctx := ps.tracer.StartSpanFromContext(ctx, "pushsync-handler", ps.logger, opentracing.Tag{Key: "address", Value: chunk.Address().String()})
 	defer span.Finish()
@@ -109,6 +116,9 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
 	if err != nil {
 		// If i am the closest peer then store the chunk and send receipt
 		if errors.Is(err, topology.ErrWantSelf) {
+			if callback != nil {
+				go callback()
+			}
 			return ps.handleDeliveryResponse(ctx, w, p, chunk)
 		}
 		return err
@@ -223,7 +233,7 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
 	if err != nil {
 		if errors.Is(err, topology.ErrWantSelf) {
 			// this is to make sure that the sent number does not diverge from the synced counter
-			t, err := ps.tagg.Get(ch.TagID())
+			t, err := ps.tagger.Get(ch.TagID())
 			if err == nil && t != nil {
 				err = t.Inc(tags.StateSent)
 				if err != nil {
@@ -260,7 +270,7 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
 	}
 
 	// if you manage to get a tag, just increment the respective counter
-	t, err := ps.tagg.Get(ch.TagID())
+	t, err := ps.tagger.Get(ch.TagID())
 	if err == nil && t != nil {
 		err = t.Inc(tags.StateSent)
 		if err != nil {
@@ -315,9 +325,5 @@ func (ps *PushSync) handleDeliveryResponse(ctx context.Context, w protobuf.Write
 		return err
 	}
 
-	if ps.deliveryCallback != nil {
-		ps.deliveryCallback(ctx, chunk)
-	}
 	return nil
 }
@@ -20,6 +20,7 @@ import (
 	"github.com/ethersphere/bee/pkg/pushsync"
 	"github.com/ethersphere/bee/pkg/pushsync/pb"
 	statestore "github.com/ethersphere/bee/pkg/statestore/mock"
+	mockvalidator "github.com/ethersphere/bee/pkg/storage/mock/validator"
 	"github.com/ethersphere/bee/pkg/swarm"
 	"github.com/ethersphere/bee/pkg/tags"
 	"github.com/ethersphere/bee/pkg/topology"
@@ -41,10 +42,11 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) {
 	// create a pivot node and a mocked closest node
 	pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000") // base is 0000
 	closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") // binary 0110 -> po 1
+	validator := testValidator(chunkAddress, chunkData, nil)
 
 	// peer is the node responding to the chunk receipt message
 	// mock should return ErrWantSelf since there's no one to forward to
-	psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
+	psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, validator, mock.WithClosestPeerErr(topology.ErrWantSelf))
 	defer storerPeer.Close()
 
 	recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()))
@@ -99,17 +101,19 @@ func TestPushChunkToClosest(t *testing.T) {
 	// create a pivot node and a mocked closest node
 	pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000") // base is 0000
 	closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") // binary 0110 -> po 1
+	callbackC := make(chan swarm.Chunk, 1)
+	validator := testValidator(chunkAddress, chunkData, callbackC)
 
 	// peer is the node responding to the chunk receipt message
 	// mock should return ErrWantSelf since there's no one to forward to
-	psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
+	psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, validator, mock.WithClosestPeerErr(topology.ErrWantSelf))
 	defer storerPeer.Close()
 
 	recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()))
+	validator = testValidator(chunkAddress, chunkData, nil)
 
 	// pivot node needs the streamer since the chunk is intercepted by
 	// the chunk worker, then gets sent by opening a new stream
-	psPivot, storerPivot, pivotTags, pivotAccounting := createPushSyncNode(t, pivotNode, recorder, nil, mock.WithClosestPeer(closestPeer))
+	psPivot, storerPivot, pivotTags, pivotAccounting := createPushSyncNode(t, pivotNode, recorder, validator, mock.WithClosestPeer(closestPeer))
 	defer storerPivot.Close()
 
 	ta, err := pivotTags.Create("test", 1)
@@ -168,6 +172,13 @@
 	if balance != int64(fixedPrice) {
 		t.Fatalf("unexpected balance on peer. want %d got %d", int64(fixedPrice), balance)
 	}
+
+	// check if the pss delivery hook is called
+	select {
+	case <-callbackC:
+	case <-time.After(100 * time.Millisecond):
+		t.Fatalf("delivery hook was not called")
+	}
 }
 // TestHandler expect a chunk from a node on a stream. It then stores the chunk in the local store and
@@ -186,27 +197,22 @@
 	pivotPeer := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000")
 	triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000")
 	closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000")
+	validator := testValidator(chunkAddress, chunkData, nil)
 
-	// mock call back function to see if pss message is delivered when it is received in the destination (closestPeer in this testcase)
-	hookWasCalled := make(chan bool, 1) // channel to check if hook is called
-	pssDeliver := func(ctx context.Context, ch swarm.Chunk) {
-		hookWasCalled <- true
-	}
-
 	// Create the closest peer
-	psClosestPeer, closestStorerPeerDB, _, closestAccounting := createPushSyncNode(t, closestPeer, nil, pssDeliver, mock.WithClosestPeerErr(topology.ErrWantSelf))
+	psClosestPeer, closestStorerPeerDB, _, closestAccounting := createPushSyncNode(t, closestPeer, nil, validator, mock.WithClosestPeerErr(topology.ErrWantSelf))
 	defer closestStorerPeerDB.Close()
 
 	closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol()))
 
 	// creating the pivot peer
-	psPivot, storerPivotDB, _, pivotAccounting := createPushSyncNode(t, pivotPeer, closestRecorder, nil, mock.WithClosestPeer(closestPeer))
+	psPivot, storerPivotDB, _, pivotAccounting := createPushSyncNode(t, pivotPeer, closestRecorder, validator, mock.WithClosestPeer(closestPeer))
 	defer storerPivotDB.Close()
 
 	pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol()))
 
 	// Creating the trigger peer
-	psTriggerPeer, triggerStorerDB, _, triggerAccounting := createPushSyncNode(t, triggerPeer, pivotRecorder, nil, mock.WithClosestPeer(pivotPeer))
+	psTriggerPeer, triggerStorerDB, _, triggerAccounting := createPushSyncNode(t, triggerPeer, pivotRecorder, validator, mock.WithClosestPeer(pivotPeer))
 	defer triggerStorerDB.Close()
 
 	receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk)
@@ -231,14 +237,6 @@
 	// In the received stream, check if a receipt is sent from pivot peer and check for its correctness.
 	waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunkAddress, nil)
 
-	// check if the pss delivery hook is called
-	select {
-	case <-hookWasCalled:
-		break
-	case <-time.After(100 * time.Millisecond):
-		t.Fatal("recovery hook was not called")
-	}
-
 	balance, err := triggerAccounting.Balance(pivotPeer)
 	if err != nil {
 		t.Fatal(err)
@@ -277,7 +275,8 @@
 	}
 }
 
-func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, pssDeliver func(context.Context, swarm.Chunk), mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) {
+func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, validator swarm.ValidatorWithCallback, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) {
+	t.Helper()
 	logger := logging.New(ioutil.Discard, 0)
 
 	storer, err := localstore.New("", addr.Bytes(), nil, logger)
@@ -288,11 +287,10 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R
 	mockTopology := mock.NewTopologyDriver(mockOpts...)
 	mockStatestore := statestore.NewStateStore()
 	mtag := tags.NewTags(mockStatestore, logger)
 	mockAccounting := accountingmock.NewAccounting()
 	mockPricer := accountingmock.NewPricer(fixedPrice, fixedPrice)
 
-	return pushsync.New(recorder, storer, mockTopology, mtag, pssDeliver, logger, mockAccounting, mockPricer, nil), storer, mtag, mockAccounting
+	return pushsync.New(recorder, storer, mockTopology, mtag, validator, logger, mockAccounting, mockPricer, nil), storer, mtag, mockAccounting
 }
 func waitOnRecordAndTest(t *testing.T, peer swarm.Address, recorder *streamtest.Recorder, add swarm.Address, data []byte) {
@@ -344,3 +342,11 @@ func waitOnRecordAndTest(t *testing.T, peer swarm.Address, recorder *streamtest.
 		}
 	}
 }
+
+func testValidator(chunkAddress swarm.Address, chunkData []byte, callbackC chan swarm.Chunk) swarm.ValidatorWithCallback {
+	validators := []swarm.Validator{mockvalidator.NewMockValidator(chunkAddress, chunkData)}
+	if callbackC != nil {
+		return swarm.NewMultiValidator(validators, func(c swarm.Chunk) { callbackC <- c })
+	}
+	return swarm.NewMultiValidator(validators)
+}
@@ -16,27 +16,26 @@ import (
 )
 
 const (
-	// RecoveryTopicText is the string used to construct the recovery topic.
-	RecoveryTopicText = "RECOVERY"
+	// TopicText is the string used to construct the recovery topic.
+	TopicText = "RECOVERY"
 )
 
 var (
-	// RecoveryTopic is the topic used for repairing globally pinned chunks.
-	RecoveryTopic = pss.NewTopic(RecoveryTopicText)
+	// Topic is the topic used for repairing globally pinned chunks.
+	Topic = pss.NewTopic(TopicText)
 )
 
-// RecoveryHook defines code to be executed upon failing to retrieve chunks.
-type RecoveryHook func(chunkAddress swarm.Address, targets pss.Targets) error
+// Callback defines code to be executed upon failing to retrieve chunks.
+type Callback func(chunkAddress swarm.Address, targets pss.Targets)
 
-// NewRecoveryHook returns a new RecoveryHook with the sender function defined.
-func NewRecoveryHook(pssSender pss.Sender) RecoveryHook {
-	privk := crypto.Secp256k1PrivateKeyFromBytes([]byte(RecoveryTopicText))
+// NewCallback returns a new Callback with the sender function defined.
+func NewCallback(pssSender pss.Sender) Callback {
+	privk := crypto.Secp256k1PrivateKeyFromBytes([]byte(TopicText))
 	recipient := privk.PublicKey
-	return func(chunkAddress swarm.Address, targets pss.Targets) error {
+	return func(chunkAddress swarm.Address, targets pss.Targets) {
 		payload := chunkAddress
 		ctx := context.Background()
-		err := pssSender.Send(ctx, RecoveryTopic, payload.Bytes(), &recipient, targets)
-		return err
+		_ = pssSender.Send(ctx, Topic, payload.Bytes(), &recipient, targets)
 	}
 }
...
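
A brief usage sketch of the renamed API (pssService stands in for any pss.Sender implementation; chunkAddress and targets are assumed values):

	recoverFn := recovery.NewCallback(pssService) // type recovery.Callback
	// fire-and-forget: Send errors are discarded inside the callback, so
	// callers need no error handling and may invoke it from a goroutine
	go recoverFn(chunkAddress, targets)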
@@ -30,40 +30,39 @@ import (
 	"github.com/ethersphere/bee/pkg/topology"
 )
 
-// TestRecoveryHook tests that a recovery hook can be created and called.
-func TestRecoveryHook(t *testing.T) {
-	// test variables needed to be correctly set for any recovery hook to reach the sender func
+// TestCallback tests that a callback can be created and called.
+func TestCallback(t *testing.T) {
+	// test variables needed to be correctly set for any recovery callback to reach the sender func
 	chunkAddr := chunktesting.GenerateTestRandomChunk().Address()
 	targets := pss.Targets{[]byte{0xED}}
 
 	//setup the sender
-	hookWasCalled := make(chan bool, 1) // channel to check if hook is called
+	callbackWasCalled := make(chan bool) // channel to check if callback is called
 	pssSender := &mockPssSender{
-		hookC: hookWasCalled,
+		callbackC: callbackWasCalled,
 	}
 
-	// create recovery hook and call it
-	recoveryHook := recovery.NewRecoveryHook(pssSender)
-	if err := recoveryHook(chunkAddr, targets); err != nil {
-		t.Fatal(err)
-	}
+	// create recovery callback and call it
+	recoveryCallback := recovery.NewCallback(pssSender)
+	go recoveryCallback(chunkAddr, targets)
 
 	select {
-	case <-hookWasCalled:
+	case <-callbackWasCalled:
 		break
 	case <-time.After(100 * time.Millisecond):
-		t.Fatal("recovery hook was not called")
+		t.Fatal("recovery callback was not called")
 	}
 }
-// RecoveryHookTestCase is a struct used as test cases for the TestRecoveryHookCalls func.
-type recoveryHookTestCase struct {
+// CallbackTestCase is a struct used as test cases for the TestCallbackCalls func.
+type recoveryCallbackTestCase struct {
 	name string
 	ctx context.Context
 	expectsFailure bool
 }
 
-// TestRecoveryHookCalls verifies that recovery hooks are being called as expected when net store attempts to get a chunk.
-func TestRecoveryHookCalls(t *testing.T) {
+// TestCallbackCalls verifies that recovery callbacks are being called as expected when net store attempts to get a chunk.
+func TestCallbackCalls(t *testing.T) {
 	// generate test chunk, store and publisher
 	c := chunktesting.GenerateTestRandomChunk()
 	ref := c.Address()
@@ -72,7 +71,7 @@
 	// test cases variables
 	targetContext := sctx.SetTargets(context.Background(), target)
 
-	for _, tc := range []recoveryHookTestCase{
+	for _, tc := range []recoveryCallbackTestCase{
 		{
 			name: "targets set in context",
 			ctx: targetContext,
@@ -80,13 +79,13 @@
 		},
 	} {
 		t.Run(tc.name, func(t *testing.T) {
-			hookWasCalled := make(chan bool, 1) // channel to check if hook is called
+			callbackWasCalled := make(chan bool, 1) // channel to check if callback is called
 
 			// setup the sender
 			pssSender := &mockPssSender{
-				hookC: hookWasCalled,
+				callbackC: callbackWasCalled,
 			}
 
-			recoverFunc := recovery.NewRecoveryHook(pssSender)
+			recoverFunc := recovery.NewCallback(pssSender)
 			ns := newTestNetStore(t, recoverFunc)
 
 			// fetch test chunk
@@ -97,16 +96,16 @@
 			// checks whether the callback is invoked or the test case times out
 			select {
-			case <-hookWasCalled:
+			case <-callbackWasCalled:
 				if !tc.expectsFailure {
 					return
 				}
-				t.Fatal("recovery hook was unexpectedly called")
+				t.Fatal("recovery callback was unexpectedly called")
 			case <-time.After(1000 * time.Millisecond):
 				if tc.expectsFailure {
 					return
 				}
-				t.Fatal("recovery hook was not called when expected")
+				t.Fatal("recovery callback was not called when expected")
 			}
 		})
 	}
@@ -212,7 +211,7 @@ func TestNewRepairHandler(t *testing.T) {
 }
 
 // newTestNetStore creates a test store with a set RemoteGet func.
-func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.Storer {
+func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Storer {
 	t.Helper()
 	storer := mock.NewStorer()
 	logger := logging.New(ioutil.Discard, 5)
@@ -231,7 +230,7 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.S
 		streamtest.WithProtocols(server.Protocol()),
 	))
 	retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil, nil)
-	ns := netstore.New(storer, recoveryFunc, retrieve, logger, nil)
+	ns := netstore.New(storer, recoveryFunc, retrieve, logger)
 	return ns
 }
@@ -247,11 +246,11 @@ func (s mockPeerSuggester) EachPeerRev(f topology.EachPeerFunc) error {
 }
 
 type mockPssSender struct {
-	hookC chan bool
+	callbackC chan bool
 }
 
 // Send mocks the pss Send function
 func (mp *mockPssSender) Send(ctx context.Context, topic pss.Topic, payload []byte, recipient *ecdsa.PublicKey, targets pss.Targets) error {
-	mp.hookC <- true
+	mp.callbackC <- true
 	return nil
 }
@@ -31,7 +31,7 @@ var testTimeout = 5 * time.Second
 // TestDelivery tests that a naive request -> delivery flow works.
 func TestDelivery(t *testing.T) {
 	logger := logging.New(ioutil.Discard, 0)
-	mockValidator := swarm.NewChunkValidator(mock.NewValidator(true))
+	mockValidator := mock.NewValidator(true)
 	mockStorer := storemock.NewStorer()
 
 	reqAddr, err := swarm.ParseHexAddress("00112233")
 	if err != nil {
@@ -135,7 +135,7 @@ func TestDelivery(t *testing.T) {
 func TestRetrieveChunk(t *testing.T) {
 	logger := logging.New(ioutil.Discard, 0)
-	mockValidator := swarm.NewChunkValidator(mock.NewValidator(true))
+	mockValidator := mock.NewValidator(true)
 	pricer := accountingmock.NewPricer(1, 1)
 
 	// requesting a chunk from downstream peer is expected
...
@@ -9,6 +9,7 @@ import (
 	"bytes"
 	"encoding/hex"
 	"encoding/json"
+	"errors"
 	"fmt"
 
 	"golang.org/x/crypto/sha3"
@@ -30,6 +31,10 @@ var (
 	NewHasher = sha3.NewLegacyKeccak256
 )
 
+var (
+	ErrInvalidChunk = errors.New("invalid chunk")
+)
+
 // Address represents an address in Swarm metric space of
 // Node and Chunk addresses.
 type Address struct {
@@ -164,28 +169,3 @@ func (c *chunk) String() string {
 func (c *chunk) Equal(cp Chunk) bool {
 	return c.Address().Equal(cp.Address()) && bytes.Equal(c.Data(), cp.Data())
 }
-
-type Validator interface {
-	Validate(ch Chunk) (valid bool)
-}
-
-type chunkValidator struct {
-	set []Validator
-	Validator
-}
-
-func NewChunkValidator(v ...Validator) Validator {
-	return &chunkValidator{
-		set: v,
-	}
-}
-
-func (c *chunkValidator) Validate(ch Chunk) bool {
-	for _, v := range c.set {
-		if v.Validate(ch) {
-			return true
-		}
-	}
-	return false
-}
+// Copyright 2020 The Swarm Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package swarm
+
+type Validator interface {
+	Validate(ch Chunk) (valid bool)
+}
+
+type ValidatorWithCallback interface {
+	ValidWithCallback(ch Chunk) (valid bool, callback func())
+	Validator
+}
+
+var _ Validator = (*validatorWithCallback)(nil)
+
+type validatorWithCallback struct {
+	v Validator
+	callback func(Chunk)
+}
+
+func (v *validatorWithCallback) Validate(ch Chunk) bool {
+	valid := v.v.Validate(ch)
+	if valid {
+		go v.callback(ch)
+	}
+	return valid
+}
+
+var _ ValidatorWithCallback = (*multiValidator)(nil)
+
+type multiValidator struct {
+	validators []Validator
+	callbacks []func(Chunk)
+}
+
+func NewMultiValidator(validators []Validator, callbacks ...func(Chunk)) ValidatorWithCallback {
+	return &multiValidator{validators, callbacks}
+}
+
+func (mv *multiValidator) Validate(ch Chunk) bool {
+	for _, v := range mv.validators {
+		if v.Validate(ch) {
+			return true
+		}
+	}
+	return false
+}
+
+func (mv *multiValidator) ValidWithCallback(ch Chunk) (bool, func()) {
+	for i, v := range mv.validators {
+		if v.Validate(ch) {
+			if i < len(mv.callbacks) {
+				return true, func() { mv.callbacks[i](ch) }
+			}
+			return true, nil
+		}
+	}
+	return false, nil
+}
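
A short usage sketch of the new type (the contentValidator, socValidator and callback values are illustrative). Callbacks pair positionally with validators, so a chunk accepted by validators[i] yields callbacks[i] when one exists:

	mv := swarm.NewMultiValidator(
		[]swarm.Validator{contentValidator, socValidator}, // tried in order, first match wins
		func(c swarm.Chunk) { /* runs only for chunks passing contentValidator */ },
	)

	if valid, callback := mv.ValidWithCallback(ch); valid && callback != nil {
		go callback() // the caller decides when, and whether, to run it
	}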