Commit 09507fc3 authored by Viktor Trón's avatar Viktor Trón Committed by GitHub

validators/callbacks reorg to enable pss mailboxing (#942)

parent 828d95cd
......@@ -65,7 +65,7 @@ func TestPssWebsocketSingleHandler(t *testing.T) {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), tc)
p.TryUnwrap(tc)
waitMessage(t, msgContent, payload, &mtx)
}
......@@ -103,7 +103,7 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), tc)
p.TryUnwrap(tc)
waitMessage(t, msgContent, nil, &mtx)
}
......@@ -146,7 +146,7 @@ func TestPssWebsocketMultiHandler(t *testing.T) {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), tc)
p.TryUnwrap(tc)
waitMessage(t, msgContent, nil, &mtx)
waitMessage(t, msgContent2, nil, &mtx)
......@@ -275,7 +275,7 @@ func TestPssPingPong(t *testing.T) {
time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
p.TryUnwrap(context.Background(), tc)
p.TryUnwrap(tc)
waitMessage(t, msgContent, nil, &mtx)
}
......@@ -399,7 +399,7 @@ func (m *mpss) Register(_ pss.Topic, _ pss.Handler) func() {
}
// TryUnwrap tries to unwrap a wrapped trojan message.
func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) {
func (m *mpss) TryUnwrap(_ swarm.Chunk) {
panic("not implemented") // TODO: Implement
}
......
......@@ -20,9 +20,8 @@ import (
type store struct {
storage.Storer
retrieval retrieval.Interface
validator swarm.Validator
logger logging.Logger
recoveryCallback recovery.RecoveryHook // this is the callback to be executed when a chunk fails to be retrieved
recoveryCallback recovery.Callback // this is the callback to be executed when a chunk fails to be retrieved
}
var (
......@@ -30,9 +29,8 @@ var (
)
// New returns a new NetStore that wraps a given Storer.
func New(s storage.Storer, rcb recovery.RecoveryHook, r retrieval.Interface, logger logging.Logger,
validator swarm.Validator) storage.Storer {
return &store{Storer: s, recoveryCallback: rcb, retrieval: r, logger: logger, validator: validator}
func New(s storage.Storer, rcb recovery.Callback, r retrieval.Interface, logger logging.Logger) storage.Storer {
return &store{Storer: s, recoveryCallback: rcb, retrieval: r, logger: logger}
}
// Get retrieves a given chunk address.
......@@ -44,20 +42,12 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
// request from network
ch, err = s.retrieval.RetrieveChunk(ctx, addr)
if err != nil {
if s.recoveryCallback == nil {
return nil, err
}
targets := sctx.GetTargets(ctx)
if targets != nil {
go func() {
err := s.recoveryCallback(addr, targets)
if err != nil {
s.logger.Debugf("netstore: error while recovering chunk: %v", err)
}
}()
return nil, ErrRecoveryAttempt
if targets == nil || s.recoveryCallback == nil {
return nil, err
}
return nil, fmt.Errorf("netstore retrieve chunk: %w", err)
go s.recoveryCallback(addr, targets)
return nil, ErrRecoveryAttempt
}
_, err = s.Storer.Put(ctx, storage.ModePutRequest, ch)
......@@ -70,15 +60,3 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
}
return ch, nil
}
// Put stores a given chunk in the local storage.
// returns a storage.ErrInvalidChunk error when
// encountering an invalid chunk.
func (s *store) Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err error) {
for _, ch := range chs {
if !s.validator.Validate(ch) {
return nil, storage.ErrInvalidChunk
}
}
return s.Storer.Put(ctx, mode, chs...)
}
......@@ -14,10 +14,10 @@ import (
"testing"
"time"
validatormock "github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
......@@ -98,12 +98,12 @@ func TestNetstoreNoRetrieval(t *testing.T) {
}
func TestRecovery(t *testing.T) {
hookWasCalled := make(chan bool, 1)
callbackWasCalled := make(chan bool, 1)
rec := &mockRecovery{
hookC: hookWasCalled,
callbackC: callbackWasCalled,
}
retrieve, _, nstore := newRetrievingNetstore(rec)
retrieve, _, nstore := newRetrievingNetstore(rec.recovery)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
......@@ -115,10 +115,10 @@ func TestRecovery(t *testing.T) {
}
select {
case <-hookWasCalled:
case <-callbackWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
t.Fatal("recovery callback was not called")
}
}
......@@ -136,19 +136,11 @@ func TestInvalidRecoveryFunction(t *testing.T) {
}
// returns a mock retrieval protocol, a mock local storage and a netstore
func newRetrievingNetstore(rec *mockRecovery) (ret *retrievalMock, mockStore, ns storage.Storer) {
func newRetrievingNetstore(rec recovery.Callback) (ret *retrievalMock, mockStore, ns storage.Storer) {
retrieve := &retrievalMock{}
store := mock.NewStorer()
logger := logging.New(ioutil.Discard, 0)
validator := swarm.NewChunkValidator(validatormock.NewValidator(true))
var nstore storage.Storer
if rec != nil {
nstore = netstore.New(store, rec.recovery, retrieve, logger, validator)
} else {
nstore = netstore.New(store, nil, retrieve, logger, validator)
}
return retrieve, store, nstore
return retrieve, store, netstore.New(store, rec, retrieve, logger)
}
type retrievalMock struct {
......@@ -169,13 +161,12 @@ func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (
}
type mockRecovery struct {
hookC chan bool
callbackC chan bool
}
// Send mocks the pss Send function
func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets pss.Targets) error {
mr.hookC <- true
return nil
func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets pss.Targets) {
mr.callbackC <- true
}
func (r *mockRecovery) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
......
......@@ -334,7 +334,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
}
b.localstoreCloser = storer
chunkvalidator := swarm.NewChunkValidator(content.NewValidator(), soc.NewValidator())
chunkvalidator := swarm.NewMultiValidator([]swarm.Validator{content.NewValidator(), soc.NewValidator()})
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator, tracer)
tagService := tags.NewTags(stateStore, logger)
......@@ -358,13 +358,15 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
var ns storage.Storer
if o.GlobalPinningEnabled {
// create recovery callback for content repair
recoverFunc := recovery.NewRecoveryHook(pssService)
ns = netstore.New(storer, recoverFunc, retrieve, logger, chunkvalidator)
recoverFunc := recovery.NewCallback(pssService)
ns = netstore.New(storer, recoverFunc, retrieve, logger)
} else {
ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator)
ns = netstore.New(storer, nil, retrieve, logger)
}
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, pssService.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
chunkvalidatorWithCallback := swarm.NewMultiValidator([]swarm.Validator{content.NewValidator(), soc.NewValidator()}, pssService.TryUnwrap)
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, chunkvalidatorWithCallback, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
// set the pushSyncer in the PSS
pssService.SetPushSyncer(pushSyncProtocol)
......@@ -376,7 +378,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
if o.GlobalPinningEnabled {
// register function for chunk repair upon receiving a trojan message
chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol)
b.recoveryHandleCleanup = pssService.Register(recovery.RecoveryTopic, chunkRepairHandler)
b.recoveryHandleCleanup = pssService.Register(recovery.Topic, chunkRepairHandler)
}
pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagService, logger, tracer)
......@@ -384,7 +386,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
pullStorage := pullstorage.New(storer)
pullSync := pullsync.New(p2ps, pullStorage, logger)
pullSync := pullsync.New(p2ps, pullStorage, chunkvalidator, logger)
b.pullSyncCloser = pullSync
if err = p2ps.AddProtocol(pullSync.Protocol()); err != nil {
......
......@@ -31,7 +31,7 @@ type Interface interface {
// Register a Handler for a given Topic.
Register(Topic, Handler) func()
// TryUnwrap tries to unwrap a wrapped trojan message.
TryUnwrap(context.Context, swarm.Chunk)
TryUnwrap(swarm.Chunk)
SetPushSyncer(pushSyncer pushsync.PushSyncer)
io.Closer
......@@ -128,10 +128,11 @@ func (p *pss) topics() []Topic {
}
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic.
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) {
func (p *pss) TryUnwrap(c swarm.Chunk) {
if len(c.Data()) < swarm.ChunkWithSpanSize {
return // chunk not full
}
ctx := context.Background()
topic, msg, err := Unwrap(ctx, p.key, c, p.topics())
if err != nil {
return // cannot unwrap
......
......@@ -75,8 +75,6 @@ type topicMessage struct {
// TestDeliver verifies that registering a handler on pss for a given topic and then submitting a trojan chunk with said topic to it
// results in the execution of the expected handler func
func TestDeliver(t *testing.T) {
ctx := context.Background()
privkey, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
......@@ -108,7 +106,7 @@ func TestDeliver(t *testing.T) {
p.Register(topic, handler)
// call pss TryUnwrap on chunk and verify test topic variable value changes
p.TryUnwrap(ctx, chunk)
p.TryUnwrap(chunk)
var message topicMessage
select {
......@@ -172,7 +170,7 @@ func TestRegister(t *testing.T) {
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), chunk1)
p.TryUnwrap(chunk1)
waitHandlerCallback(t, &msgChan, 1)
......@@ -181,7 +179,7 @@ func TestRegister(t *testing.T) {
// register another topic handler on the same topic
cleanup := p.Register(topic1, h3)
p.TryUnwrap(context.Background(), chunk1)
p.TryUnwrap(chunk1)
waitHandlerCallback(t, &msgChan, 2)
......@@ -191,7 +189,7 @@ func TestRegister(t *testing.T) {
cleanup() // remove the last handler
p.TryUnwrap(context.Background(), chunk1)
p.TryUnwrap(chunk1)
waitHandlerCallback(t, &msgChan, 1)
......@@ -203,7 +201,7 @@ func TestRegister(t *testing.T) {
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), chunk2)
p.TryUnwrap(chunk2)
waitHandlerCallback(t, &msgChan, 1)
......
......@@ -18,7 +18,6 @@ import (
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pullsync/pb"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/storage"
......@@ -57,12 +56,13 @@ type Interface interface {
}
type Syncer struct {
streamer p2p.Streamer
metrics metrics
logger logging.Logger
storage pullstorage.Storer
quit chan struct{}
wg sync.WaitGroup
streamer p2p.Streamer
metrics metrics
logger logging.Logger
storage pullstorage.Storer
quit chan struct{}
wg sync.WaitGroup
validator swarm.ValidatorWithCallback
ruidMtx sync.Mutex
ruidCtx map[uint32]func()
......@@ -71,15 +71,16 @@ type Syncer struct {
io.Closer
}
func New(streamer p2p.Streamer, storage pullstorage.Storer, logger logging.Logger) *Syncer {
func New(streamer p2p.Streamer, storage pullstorage.Storer, validator swarm.ValidatorWithCallback, logger logging.Logger) *Syncer {
return &Syncer{
streamer: streamer,
storage: storage,
metrics: newMetrics(),
logger: logger,
ruidCtx: make(map[uint32]func()),
wg: sync.WaitGroup{},
quit: make(chan struct{}),
streamer: streamer,
storage: storage,
metrics: newMetrics(),
validator: validator,
logger: logger,
ruidCtx: make(map[uint32]func()),
wg: sync.WaitGroup{},
quit: make(chan struct{}),
}
}
......@@ -212,9 +213,18 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8
delete(wantChunks, addr.String())
s.metrics.DbOpsCounter.Inc()
s.metrics.DeliveryCounter.Inc()
if err = s.storage.Put(ctx, storage.ModePutSync, swarm.NewChunk(addr, delivery.Data)); err != nil {
chunk := swarm.NewChunk(addr, delivery.Data)
valid, callback := s.validator.ValidWithCallback(chunk)
if !valid {
return 0, ru.Ruid, swarm.ErrInvalidChunk
}
if err = s.storage.Put(ctx, storage.ModePutSync, chunk); err != nil {
return 0, ru.Ruid, fmt.Errorf("delivery put: %w", err)
}
if callback != nil {
go callback()
}
}
return offer.Topmost, ru.Ruid, nil
}
......
......@@ -217,6 +217,20 @@ func haveChunks(t *testing.T, s *mock.PullStorage, addrs ...swarm.Address) {
func newPullSync(s p2p.Streamer, o ...mock.Option) (*pullsync.Syncer, *mock.PullStorage) {
storage := mock.NewPullStorage(o...)
c := make(chan swarm.Chunk)
validator := &mockValidator{c}
logger := logging.New(ioutil.Discard, 0)
return pullsync.New(s, storage, logger), storage
return pullsync.New(s, storage, validator, logger), storage
}
type mockValidator struct {
c chan swarm.Chunk
}
func (*mockValidator) Validate(swarm.Chunk) bool {
return true
}
func (mv *mockValidator) ValidWithCallback(c swarm.Chunk) (bool, func()) {
return true, func() { mv.c <- c }
}
......@@ -38,32 +38,32 @@ type Receipt struct {
}
type PushSync struct {
streamer p2p.Streamer
storer storage.Putter
peerSuggester topology.ClosestPeerer
tagg *tags.Tags
deliveryCallback func(context.Context, swarm.Chunk) // callback func to be invoked to deliver chunks to PSS
logger logging.Logger
accounting accounting.Interface
pricer accounting.Pricer
metrics metrics
tracer *tracing.Tracer
streamer p2p.Streamer
storer storage.Putter
peerSuggester topology.ClosestPeerer
tagger *tags.Tags
validator swarm.ValidatorWithCallback
logger logging.Logger
accounting accounting.Interface
pricer accounting.Pricer
metrics metrics
tracer *tracing.Tracer
}
var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk
func New(streamer p2p.Streamer, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, deliveryCallback func(context.Context, swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync {
func New(streamer p2p.Streamer, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, validator swarm.ValidatorWithCallback, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync {
ps := &PushSync{
streamer: streamer,
storer: storer,
peerSuggester: closestPeerer,
tagg: tagger,
deliveryCallback: deliveryCallback,
logger: logger,
accounting: accounting,
pricer: pricer,
metrics: newMetrics(),
tracer: tracer,
streamer: streamer,
storer: storer,
peerSuggester: closestPeerer,
tagger: tagger,
validator: validator,
logger: logger,
accounting: accounting,
pricer: pricer,
metrics: newMetrics(),
tracer: tracer,
}
return ps
}
......@@ -101,6 +101,13 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
ps.metrics.ChunksReceivedCounter.Inc()
chunk := swarm.NewChunk(swarm.NewAddress(ch.Address), ch.Data)
// validate the chunk and returns the delivery callback for the validator
valid, callback := ps.validator.ValidWithCallback(chunk)
if !valid {
return swarm.ErrInvalidChunk
}
span, _, ctx := ps.tracer.StartSpanFromContext(ctx, "pushsync-handler", ps.logger, opentracing.Tag{Key: "address", Value: chunk.Address().String()})
defer span.Finish()
......@@ -109,6 +116,9 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
if err != nil {
// If i am the closest peer then store the chunk and send receipt
if errors.Is(err, topology.ErrWantSelf) {
if callback != nil {
go callback()
}
return ps.handleDeliveryResponse(ctx, w, p, chunk)
}
return err
......@@ -223,7 +233,7 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
if err != nil {
if errors.Is(err, topology.ErrWantSelf) {
// this is to make sure that the sent number does not diverge from the synced counter
t, err := ps.tagg.Get(ch.TagID())
t, err := ps.tagger.Get(ch.TagID())
if err == nil && t != nil {
err = t.Inc(tags.StateSent)
if err != nil {
......@@ -260,7 +270,7 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
}
// if you manage to get a tag, just increment the respective counter
t, err := ps.tagg.Get(ch.TagID())
t, err := ps.tagger.Get(ch.TagID())
if err == nil && t != nil {
err = t.Inc(tags.StateSent)
if err != nil {
......@@ -315,9 +325,5 @@ func (ps *PushSync) handleDeliveryResponse(ctx context.Context, w protobuf.Write
return err
}
if ps.deliveryCallback != nil {
ps.deliveryCallback(ctx, chunk)
}
return nil
}
......@@ -20,6 +20,7 @@ import (
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/pushsync/pb"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
mockvalidator "github.com/ethersphere/bee/pkg/storage/mock/validator"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/topology"
......@@ -41,10 +42,11 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) {
// create a pivot node and a mocked closest node
pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000") // base is 0000
closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") // binary 0110 -> po 1
validator := testValidator(chunkAddress, chunkData, nil)
// peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, validator, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()))
......@@ -99,17 +101,19 @@ func TestPushChunkToClosest(t *testing.T) {
// create a pivot node and a mocked closest node
pivotNode := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000") // base is 0000
closestPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") // binary 0110 -> po 1
callbackC := make(chan swarm.Chunk, 1)
validator := testValidator(chunkAddress, chunkData, callbackC)
// peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
psPeer, storerPeer, _, peerAccounting := createPushSyncNode(t, closestPeer, nil, validator, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()))
validator = testValidator(chunkAddress, chunkData, nil)
// pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, pivotTags, pivotAccounting := createPushSyncNode(t, pivotNode, recorder, nil, mock.WithClosestPeer(closestPeer))
psPivot, storerPivot, pivotTags, pivotAccounting := createPushSyncNode(t, pivotNode, recorder, validator, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close()
ta, err := pivotTags.Create("test", 1)
......@@ -168,6 +172,13 @@ func TestPushChunkToClosest(t *testing.T) {
if balance != int64(fixedPrice) {
t.Fatalf("unexpected balance on peer. want %d got %d", int64(fixedPrice), balance)
}
// check if the pss delivery hook is called
select {
case <-callbackC:
case <-time.After(100 * time.Millisecond):
t.Fatalf("delivery hook was not called")
}
}
// TestHandler expect a chunk from a node on a stream. It then stores the chunk in the local store and
......@@ -186,27 +197,22 @@ func TestHandler(t *testing.T) {
pivotPeer := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000000")
triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000")
closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000")
// mock call back function to see if pss message is delivered when it is received in the destination (closestPeer in this testcase)
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
pssDeliver := func(ctx context.Context, ch swarm.Chunk) {
hookWasCalled <- true
}
validator := testValidator(chunkAddress, chunkData, nil)
// Create the closest peer
psClosestPeer, closestStorerPeerDB, _, closestAccounting := createPushSyncNode(t, closestPeer, nil, pssDeliver, mock.WithClosestPeerErr(topology.ErrWantSelf))
psClosestPeer, closestStorerPeerDB, _, closestAccounting := createPushSyncNode(t, closestPeer, nil, validator, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer closestStorerPeerDB.Close()
closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol()))
// creating the pivot peer
psPivot, storerPivotDB, _, pivotAccounting := createPushSyncNode(t, pivotPeer, closestRecorder, nil, mock.WithClosestPeer(closestPeer))
psPivot, storerPivotDB, _, pivotAccounting := createPushSyncNode(t, pivotPeer, closestRecorder, validator, mock.WithClosestPeer(closestPeer))
defer storerPivotDB.Close()
pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol()))
// Creating the trigger peer
psTriggerPeer, triggerStorerDB, _, triggerAccounting := createPushSyncNode(t, triggerPeer, pivotRecorder, nil, mock.WithClosestPeer(pivotPeer))
psTriggerPeer, triggerStorerDB, _, triggerAccounting := createPushSyncNode(t, triggerPeer, pivotRecorder, validator, mock.WithClosestPeer(pivotPeer))
defer triggerStorerDB.Close()
receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk)
......@@ -231,14 +237,6 @@ func TestHandler(t *testing.T) {
// In the received stream, check if a receipt is sent from pivot peer and check for its correctness.
waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunkAddress, nil)
// check if the pss delivery hook is called
select {
case <-hookWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
}
balance, err := triggerAccounting.Balance(pivotPeer)
if err != nil {
t.Fatal(err)
......@@ -277,7 +275,8 @@ func TestHandler(t *testing.T) {
}
}
func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, pssDeliver func(context.Context, swarm.Chunk), mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) {
func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, validator swarm.ValidatorWithCallback, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) {
t.Helper()
logger := logging.New(ioutil.Discard, 0)
storer, err := localstore.New("", addr.Bytes(), nil, logger)
......@@ -288,11 +287,10 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R
mockTopology := mock.NewTopologyDriver(mockOpts...)
mockStatestore := statestore.NewStateStore()
mtag := tags.NewTags(mockStatestore, logger)
mockAccounting := accountingmock.NewAccounting()
mockPricer := accountingmock.NewPricer(fixedPrice, fixedPrice)
return pushsync.New(recorder, storer, mockTopology, mtag, pssDeliver, logger, mockAccounting, mockPricer, nil), storer, mtag, mockAccounting
return pushsync.New(recorder, storer, mockTopology, mtag, validator, logger, mockAccounting, mockPricer, nil), storer, mtag, mockAccounting
}
func waitOnRecordAndTest(t *testing.T, peer swarm.Address, recorder *streamtest.Recorder, add swarm.Address, data []byte) {
......@@ -344,3 +342,11 @@ func waitOnRecordAndTest(t *testing.T, peer swarm.Address, recorder *streamtest.
}
}
}
func testValidator(chunkAddress swarm.Address, chunkData []byte, callbackC chan swarm.Chunk) swarm.ValidatorWithCallback {
validators := []swarm.Validator{mockvalidator.NewMockValidator(chunkAddress, chunkData)}
if callbackC != nil {
return swarm.NewMultiValidator(validators, func(c swarm.Chunk) { callbackC <- c })
}
return swarm.NewMultiValidator(validators)
}
......@@ -16,27 +16,26 @@ import (
)
const (
// RecoveryTopicText is the string used to construct the recovery topic.
RecoveryTopicText = "RECOVERY"
// TopicText is the string used to construct the recovery topic.
TopicText = "RECOVERY"
)
var (
// RecoveryTopic is the topic used for repairing globally pinned chunks.
RecoveryTopic = pss.NewTopic(RecoveryTopicText)
// Topic is the topic used for repairing globally pinned chunks.
Topic = pss.NewTopic(TopicText)
)
// RecoveryHook defines code to be executed upon failing to retrieve chunks.
type RecoveryHook func(chunkAddress swarm.Address, targets pss.Targets) error
// Callback defines code to be executed upon failing to retrieve chunks.
type Callback func(chunkAddress swarm.Address, targets pss.Targets)
// NewRecoveryHook returns a new RecoveryHook with the sender function defined.
func NewRecoveryHook(pssSender pss.Sender) RecoveryHook {
privk := crypto.Secp256k1PrivateKeyFromBytes([]byte(RecoveryTopicText))
// NewsCallback returns a new Callback with the sender function defined.
func NewCallback(pssSender pss.Sender) Callback {
privk := crypto.Secp256k1PrivateKeyFromBytes([]byte(TopicText))
recipient := privk.PublicKey
return func(chunkAddress swarm.Address, targets pss.Targets) error {
return func(chunkAddress swarm.Address, targets pss.Targets) {
payload := chunkAddress
ctx := context.Background()
err := pssSender.Send(ctx, RecoveryTopic, payload.Bytes(), &recipient, targets)
return err
_ = pssSender.Send(ctx, Topic, payload.Bytes(), &recipient, targets)
}
}
......
......@@ -30,40 +30,39 @@ import (
"github.com/ethersphere/bee/pkg/topology"
)
// TestRecoveryHook tests that a recovery hook can be created and called.
func TestRecoveryHook(t *testing.T) {
// test variables needed to be correctly set for any recovery hook to reach the sender func
// TestCallback tests that a callback can be created and called.
func TestCallback(t *testing.T) {
// test variables needed to be correctly set for any recovery callback to reach the sender func
chunkAddr := chunktesting.GenerateTestRandomChunk().Address()
targets := pss.Targets{[]byte{0xED}}
//setup the sender
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
callbackWasCalled := make(chan bool) // channel to check if callback is called
pssSender := &mockPssSender{
hookC: hookWasCalled,
callbackC: callbackWasCalled,
}
// create recovery hook and call it
recoveryHook := recovery.NewRecoveryHook(pssSender)
if err := recoveryHook(chunkAddr, targets); err != nil {
t.Fatal(err)
}
// create recovery callback and call it
recoveryCallback := recovery.NewCallback(pssSender)
go recoveryCallback(chunkAddr, targets)
select {
case <-hookWasCalled:
case <-callbackWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
t.Fatal("recovery callback was not called")
}
}
// RecoveryHookTestCase is a struct used as test cases for the TestRecoveryHookCalls func.
type recoveryHookTestCase struct {
// CallbackTestCase is a struct used as test cases for the TestCallbackCalls func.
type recoveryCallbackTestCase struct {
name string
ctx context.Context
expectsFailure bool
}
// TestRecoveryHookCalls verifies that recovery hooks are being called as expected when net store attempts to get a chunk.
func TestRecoveryHookCalls(t *testing.T) {
// TestCallbackCalls verifies that recovery callbacks are being called as expected when net store attempts to get a chunk.
func TestCallbackCalls(t *testing.T) {
// generate test chunk, store and publisher
c := chunktesting.GenerateTestRandomChunk()
ref := c.Address()
......@@ -72,7 +71,7 @@ func TestRecoveryHookCalls(t *testing.T) {
// test cases variables
targetContext := sctx.SetTargets(context.Background(), target)
for _, tc := range []recoveryHookTestCase{
for _, tc := range []recoveryCallbackTestCase{
{
name: "targets set in context",
ctx: targetContext,
......@@ -80,13 +79,13 @@ func TestRecoveryHookCalls(t *testing.T) {
},
} {
t.Run(tc.name, func(t *testing.T) {
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
callbackWasCalled := make(chan bool, 1) // channel to check if callback is called
// setup the sender
pssSender := &mockPssSender{
hookC: hookWasCalled,
callbackC: callbackWasCalled,
}
recoverFunc := recovery.NewRecoveryHook(pssSender)
recoverFunc := recovery.NewCallback(pssSender)
ns := newTestNetStore(t, recoverFunc)
// fetch test chunk
......@@ -97,16 +96,16 @@ func TestRecoveryHookCalls(t *testing.T) {
// checks whether the callback is invoked or the test case times out
select {
case <-hookWasCalled:
case <-callbackWasCalled:
if !tc.expectsFailure {
return
}
t.Fatal("recovery hook was unexpectedly called")
t.Fatal("recovery callback was unexpectedly called")
case <-time.After(1000 * time.Millisecond):
if tc.expectsFailure {
return
}
t.Fatal("recovery hook was not called when expected")
t.Fatal("recovery callback was not called when expected")
}
})
}
......@@ -212,7 +211,7 @@ func TestNewRepairHandler(t *testing.T) {
}
// newTestNetStore creates a test store with a set RemoteGet func.
func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.Storer {
func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Storer {
t.Helper()
storer := mock.NewStorer()
logger := logging.New(ioutil.Discard, 5)
......@@ -231,7 +230,7 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.S
streamtest.WithProtocols(server.Protocol()),
))
retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil, nil)
ns := netstore.New(storer, recoveryFunc, retrieve, logger, nil)
ns := netstore.New(storer, recoveryFunc, retrieve, logger)
return ns
}
......@@ -247,11 +246,11 @@ func (s mockPeerSuggester) EachPeerRev(f topology.EachPeerFunc) error {
}
type mockPssSender struct {
hookC chan bool
callbackC chan bool
}
// Send mocks the pss Send function
func (mp *mockPssSender) Send(ctx context.Context, topic pss.Topic, payload []byte, recipient *ecdsa.PublicKey, targets pss.Targets) error {
mp.hookC <- true
mp.callbackC <- true
return nil
}
......@@ -31,7 +31,7 @@ var testTimeout = 5 * time.Second
// TestDelivery tests that a naive request -> delivery flow works.
func TestDelivery(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
mockValidator := swarm.NewChunkValidator(mock.NewValidator(true))
mockValidator := mock.NewValidator(true)
mockStorer := storemock.NewStorer()
reqAddr, err := swarm.ParseHexAddress("00112233")
if err != nil {
......@@ -135,7 +135,7 @@ func TestDelivery(t *testing.T) {
func TestRetrieveChunk(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
mockValidator := swarm.NewChunkValidator(mock.NewValidator(true))
mockValidator := mock.NewValidator(true)
pricer := accountingmock.NewPricer(1, 1)
// requesting a chunk from downstream peer is expected
......
......@@ -9,6 +9,7 @@ import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"golang.org/x/crypto/sha3"
......@@ -30,6 +31,10 @@ var (
NewHasher = sha3.NewLegacyKeccak256
)
var (
ErrInvalidChunk = errors.New("invalid chunk")
)
// Address represents an address in Swarm metric space of
// Node and Chunk addresses.
type Address struct {
......@@ -164,28 +169,3 @@ func (c *chunk) String() string {
func (c *chunk) Equal(cp Chunk) bool {
return c.Address().Equal(cp.Address()) && bytes.Equal(c.Data(), cp.Data())
}
type Validator interface {
Validate(ch Chunk) (valid bool)
}
type chunkValidator struct {
set []Validator
Validator
}
func NewChunkValidator(v ...Validator) Validator {
return &chunkValidator{
set: v,
}
}
func (c *chunkValidator) Validate(ch Chunk) bool {
for _, v := range c.set {
if v.Validate(ch) {
return true
}
}
return false
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package swarm
type Validator interface {
Validate(ch Chunk) (valid bool)
}
type ValidatorWithCallback interface {
ValidWithCallback(ch Chunk) (valid bool, callback func())
Validator
}
var _ Validator = (*validatorWithCallback)(nil)
type validatorWithCallback struct {
v Validator
callback func(Chunk)
}
func (v *validatorWithCallback) Validate(ch Chunk) bool {
valid := v.v.Validate(ch)
if valid {
go v.callback(ch)
}
return valid
}
var _ ValidatorWithCallback = (*multiValidator)(nil)
type multiValidator struct {
validators []Validator
callbacks []func(Chunk)
}
func NewMultiValidator(validators []Validator, callbacks ...func(Chunk)) ValidatorWithCallback {
return &multiValidator{validators, callbacks}
}
func (mv *multiValidator) Validate(ch Chunk) bool {
for _, v := range mv.validators {
if v.Validate(ch) {
return true
}
}
return false
}
func (mv *multiValidator) ValidWithCallback(ch Chunk) (bool, func()) {
for i, v := range mv.validators {
if v.Validate(ch) {
if i < len(mv.callbacks) {
return true, func() { mv.callbacks[i](ch) }
}
return true, nil
}
}
return false, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment