Commit d45fb176 authored by Viktor Trón's avatar Viktor Trón Committed by GitHub

Pss trojan length check and no error return in delivery callback (#754)

* pss,api,pushsync: TryUnwrap/deliveryCallBack returns no error

* pss: add length  check to TryUnwrap

* node: TryUnwrap/deliveryCallBack returns no error

* pushsync: TryUnwrap/deliveryCallBack returns no error

* pss: use constant and fix typo
parent f220749a
......@@ -62,10 +62,7 @@ func TestPssWebsocketSingleHandler(t *testing.T) {
t.Fatal(err)
}
err = p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), tc)
waitMessage(t, msgContent, payload, &mtx)
}
......@@ -103,10 +100,7 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
t.Fatal(err)
}
err = p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), tc)
waitMessage(t, msgContent, nil, &mtx)
}
......@@ -149,10 +143,7 @@ func TestPssWebsocketMultiHandler(t *testing.T) {
t.Fatal(err)
}
err = p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), tc)
waitMessage(t, msgContent, nil, &mtx)
waitMessage(t, msgContent2, nil, &mtx)
......@@ -281,10 +272,7 @@ func TestPssPingPong(t *testing.T) {
time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
err = p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), tc)
waitMessage(t, msgContent, nil, &mtx)
}
......@@ -408,7 +396,7 @@ func (m *mpss) Register(_ pss.Topic, _ pss.Handler) func() {
}
// TryUnwrap tries to unwrap a wrapped trojan message.
func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) error {
func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) {
panic("not implemented") // TODO: Implement
}
......
......@@ -360,15 +360,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
}
retrieve.SetStorer(ns)
silenceNoHandlerFunc := func(ctx context.Context, ch swarm.Chunk) error {
err := psss.TryUnwrap(ctx, ch)
if errors.Is(err, pss.ErrNoHandler) {
return nil
}
return err
}
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, silenceNoHandlerFunc, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
// set the pushSyncer in the PSS
psss.SetPushSyncer(pushSyncProtocol)
......
......@@ -8,7 +8,6 @@ import (
"context"
"crypto/ecdsa"
"errors"
"fmt"
"io"
"sync"
......@@ -32,7 +31,7 @@ type Interface interface {
// Register a Handler for a given Topic.
Register(Topic, Handler) func()
// TryUnwrap tries to unwrap a wrapped trojan message.
TryUnwrap(context.Context, swarm.Chunk) error
TryUnwrap(context.Context, swarm.Chunk)
SetPushSyncer(pushSyncer pushsync.PushSyncer)
io.Closer
......@@ -129,14 +128,17 @@ func (p *pss) topics() []Topic {
}
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic.
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) {
if len(c.Data()) < swarm.ChunkWithSpanSize {
return // chunk not full
}
topic, msg, err := Unwrap(ctx, p.key, c, p.topics())
if err != nil {
return err
return // cannot unwrap
}
h := p.getHandlers(topic)
if h == nil {
return fmt.Errorf("topic %v, %w", topic, ErrNoHandler)
return // no handler
}
ctx, cancel := context.WithCancel(ctx)
......@@ -160,8 +162,6 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
wg.Wait()
close(done)
}()
return nil
}
func (p *pss) getHandlers(topic Topic) []*Handler {
......
......@@ -108,13 +108,9 @@ func TestDeliver(t *testing.T) {
p.Register(topic, handler)
// call pss TryUnwrap on chunk and verify test topic variable value changes
err = p.TryUnwrap(ctx, chunk)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(ctx, chunk)
var message topicMessage
select {
case message = <-msgChan:
break
......@@ -176,10 +172,7 @@ func TestRegister(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), chunk1)
waitHandlerCallback(t, &msgChan, 1)
......@@ -188,10 +181,7 @@ func TestRegister(t *testing.T) {
// register another topic handler on the same topic
cleanup := p.Register(topic1, h3)
err = p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), chunk1)
waitHandlerCallback(t, &msgChan, 2)
......@@ -201,10 +191,7 @@ func TestRegister(t *testing.T) {
cleanup() // remove the last handler
err = p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), chunk1)
waitHandlerCallback(t, &msgChan, 1)
......@@ -216,10 +203,7 @@ func TestRegister(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = p.TryUnwrap(context.Background(), chunk2)
if err != nil {
t.Fatal(err)
}
p.TryUnwrap(context.Background(), chunk2)
waitHandlerCallback(t, &msgChan, 1)
......
......@@ -20,7 +20,7 @@ import (
"github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/tracing"
"github.com/opentracing/opentracing-go"
opentracing "github.com/opentracing/opentracing-go"
)
const (
......@@ -42,7 +42,7 @@ type PushSync struct {
storer storage.Putter
peerSuggester topology.ClosestPeerer
tagg *tags.Tags
deliveryCallback func(context.Context, swarm.Chunk) error // callback func to be invoked to deliver chunks to PSS
deliveryCallback func(context.Context, swarm.Chunk) // callback func to be invoked to deliver chunks to PSS
logger logging.Logger
accounting accounting.Interface
pricer accounting.Pricer
......@@ -52,7 +52,7 @@ type PushSync struct {
var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk
func New(streamer p2p.Streamer, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, deliveryCallback func(context.Context, swarm.Chunk) error, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync {
func New(streamer p2p.Streamer, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, deliveryCallback func(context.Context, swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync {
ps := &PushSync{
streamer: streamer,
storer: storer,
......@@ -310,10 +310,7 @@ func (ps *PushSync) handleDeliveryResponse(ctx context.Context, w protobuf.Write
}
if ps.deliveryCallback != nil {
err = ps.deliveryCallback(ctx, chunk)
if err != nil {
ps.logger.Debugf("pushsync delivery callback: %v", err)
}
ps.deliveryCallback(ctx, chunk)
}
return nil
......
......@@ -11,8 +11,6 @@ import (
"testing"
"time"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/accounting"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/localstore"
......@@ -21,6 +19,7 @@ import (
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/pushsync/pb"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/topology"
......@@ -190,9 +189,8 @@ func TestHandler(t *testing.T) {
// mock call back function to see if pss message is delivered when it is received in the destination (closestPeer in this testcase)
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
pssDeliver := func(ctx context.Context, ch swarm.Chunk) error {
pssDeliver := func(ctx context.Context, ch swarm.Chunk) {
hookWasCalled <- true
return nil
}
// Create the closest peer
......@@ -279,7 +277,7 @@ func TestHandler(t *testing.T) {
}
}
func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, pssDeliver func(context.Context, swarm.Chunk) error, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) {
func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, pssDeliver func(context.Context, swarm.Chunk), mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) {
logger := logging.New(ioutil.Discard, 0)
storer, err := localstore.New("", addr.Bytes(), nil, logger)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment