Commit d45fb176 authored by Viktor Trón's avatar Viktor Trón Committed by GitHub

Pss trojan length check and no error return in delivery callback (#754)

* pss,api,pushsync: TryUnwrap/deliveryCallBack returns no error

* pss: add length  check to TryUnwrap

* node: TryUnwrap/deliveryCallBack returns no error

* pushsync: TryUnwrap/deliveryCallBack returns no error

* pss: use constant and fix typo
parent f220749a
...@@ -62,10 +62,7 @@ func TestPssWebsocketSingleHandler(t *testing.T) { ...@@ -62,10 +62,7 @@ func TestPssWebsocketSingleHandler(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = p.TryUnwrap(context.Background(), tc) p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, payload, &mtx) waitMessage(t, msgContent, payload, &mtx)
} }
...@@ -103,10 +100,7 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) { ...@@ -103,10 +100,7 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = p.TryUnwrap(context.Background(), tc) p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx) waitMessage(t, msgContent, nil, &mtx)
} }
...@@ -149,10 +143,7 @@ func TestPssWebsocketMultiHandler(t *testing.T) { ...@@ -149,10 +143,7 @@ func TestPssWebsocketMultiHandler(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = p.TryUnwrap(context.Background(), tc) p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx) waitMessage(t, msgContent, nil, &mtx)
waitMessage(t, msgContent2, nil, &mtx) waitMessage(t, msgContent2, nil, &mtx)
...@@ -281,10 +272,7 @@ func TestPssPingPong(t *testing.T) { ...@@ -281,10 +272,7 @@ func TestPssPingPong(t *testing.T) {
time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
err = p.TryUnwrap(context.Background(), tc) p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx) waitMessage(t, msgContent, nil, &mtx)
} }
...@@ -408,7 +396,7 @@ func (m *mpss) Register(_ pss.Topic, _ pss.Handler) func() { ...@@ -408,7 +396,7 @@ func (m *mpss) Register(_ pss.Topic, _ pss.Handler) func() {
} }
// TryUnwrap tries to unwrap a wrapped trojan message. // TryUnwrap tries to unwrap a wrapped trojan message.
func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) error { func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) {
panic("not implemented") // TODO: Implement panic("not implemented") // TODO: Implement
} }
......
...@@ -360,15 +360,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -360,15 +360,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
} }
retrieve.SetStorer(ns) retrieve.SetStorer(ns)
silenceNoHandlerFunc := func(ctx context.Context, ch swarm.Chunk) error { pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
err := psss.TryUnwrap(ctx, ch)
if errors.Is(err, pss.ErrNoHandler) {
return nil
}
return err
}
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, silenceNoHandlerFunc, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
// set the pushSyncer in the PSS // set the pushSyncer in the PSS
psss.SetPushSyncer(pushSyncProtocol) psss.SetPushSyncer(pushSyncProtocol)
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"context" "context"
"crypto/ecdsa" "crypto/ecdsa"
"errors" "errors"
"fmt"
"io" "io"
"sync" "sync"
...@@ -32,7 +31,7 @@ type Interface interface { ...@@ -32,7 +31,7 @@ type Interface interface {
// Register a Handler for a given Topic. // Register a Handler for a given Topic.
Register(Topic, Handler) func() Register(Topic, Handler) func()
// TryUnwrap tries to unwrap a wrapped trojan message. // TryUnwrap tries to unwrap a wrapped trojan message.
TryUnwrap(context.Context, swarm.Chunk) error TryUnwrap(context.Context, swarm.Chunk)
SetPushSyncer(pushSyncer pushsync.PushSyncer) SetPushSyncer(pushSyncer pushsync.PushSyncer)
io.Closer io.Closer
...@@ -129,14 +128,17 @@ func (p *pss) topics() []Topic { ...@@ -129,14 +128,17 @@ func (p *pss) topics() []Topic {
} }
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic. // TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic.
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error { func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) {
if len(c.Data()) < swarm.ChunkWithSpanSize {
return // chunk not full
}
topic, msg, err := Unwrap(ctx, p.key, c, p.topics()) topic, msg, err := Unwrap(ctx, p.key, c, p.topics())
if err != nil { if err != nil {
return err return // cannot unwrap
} }
h := p.getHandlers(topic) h := p.getHandlers(topic)
if h == nil { if h == nil {
return fmt.Errorf("topic %v, %w", topic, ErrNoHandler) return // no handler
} }
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
...@@ -160,8 +162,6 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error { ...@@ -160,8 +162,6 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
wg.Wait() wg.Wait()
close(done) close(done)
}() }()
return nil
} }
func (p *pss) getHandlers(topic Topic) []*Handler { func (p *pss) getHandlers(topic Topic) []*Handler {
......
...@@ -108,13 +108,9 @@ func TestDeliver(t *testing.T) { ...@@ -108,13 +108,9 @@ func TestDeliver(t *testing.T) {
p.Register(topic, handler) p.Register(topic, handler)
// call pss TryUnwrap on chunk and verify test topic variable value changes // call pss TryUnwrap on chunk and verify test topic variable value changes
err = p.TryUnwrap(ctx, chunk) p.TryUnwrap(ctx, chunk)
if err != nil {
t.Fatal(err)
}
var message topicMessage var message topicMessage
select { select {
case message = <-msgChan: case message = <-msgChan:
break break
...@@ -176,10 +172,7 @@ func TestRegister(t *testing.T) { ...@@ -176,10 +172,7 @@ func TestRegister(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = p.TryUnwrap(context.Background(), chunk1) p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
waitHandlerCallback(t, &msgChan, 1) waitHandlerCallback(t, &msgChan, 1)
...@@ -188,10 +181,7 @@ func TestRegister(t *testing.T) { ...@@ -188,10 +181,7 @@ func TestRegister(t *testing.T) {
// register another topic handler on the same topic // register another topic handler on the same topic
cleanup := p.Register(topic1, h3) cleanup := p.Register(topic1, h3)
err = p.TryUnwrap(context.Background(), chunk1) p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
waitHandlerCallback(t, &msgChan, 2) waitHandlerCallback(t, &msgChan, 2)
...@@ -201,10 +191,7 @@ func TestRegister(t *testing.T) { ...@@ -201,10 +191,7 @@ func TestRegister(t *testing.T) {
cleanup() // remove the last handler cleanup() // remove the last handler
err = p.TryUnwrap(context.Background(), chunk1) p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
waitHandlerCallback(t, &msgChan, 1) waitHandlerCallback(t, &msgChan, 1)
...@@ -216,10 +203,7 @@ func TestRegister(t *testing.T) { ...@@ -216,10 +203,7 @@ func TestRegister(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = p.TryUnwrap(context.Background(), chunk2) p.TryUnwrap(context.Background(), chunk2)
if err != nil {
t.Fatal(err)
}
waitHandlerCallback(t, &msgChan, 1) waitHandlerCallback(t, &msgChan, 1)
......
...@@ -20,7 +20,7 @@ import ( ...@@ -20,7 +20,7 @@ import (
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/tracing" "github.com/ethersphere/bee/pkg/tracing"
"github.com/opentracing/opentracing-go" opentracing "github.com/opentracing/opentracing-go"
) )
const ( const (
...@@ -42,7 +42,7 @@ type PushSync struct { ...@@ -42,7 +42,7 @@ type PushSync struct {
storer storage.Putter storer storage.Putter
peerSuggester topology.ClosestPeerer peerSuggester topology.ClosestPeerer
tagg *tags.Tags tagg *tags.Tags
deliveryCallback func(context.Context, swarm.Chunk) error // callback func to be invoked to deliver chunks to PSS deliveryCallback func(context.Context, swarm.Chunk) // callback func to be invoked to deliver chunks to PSS
logger logging.Logger logger logging.Logger
accounting accounting.Interface accounting accounting.Interface
pricer accounting.Pricer pricer accounting.Pricer
...@@ -52,7 +52,7 @@ type PushSync struct { ...@@ -52,7 +52,7 @@ type PushSync struct {
var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk
func New(streamer p2p.Streamer, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, deliveryCallback func(context.Context, swarm.Chunk) error, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync { func New(streamer p2p.Streamer, storer storage.Putter, closestPeerer topology.ClosestPeerer, tagger *tags.Tags, deliveryCallback func(context.Context, swarm.Chunk), logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, tracer *tracing.Tracer) *PushSync {
ps := &PushSync{ ps := &PushSync{
streamer: streamer, streamer: streamer,
storer: storer, storer: storer,
...@@ -310,10 +310,7 @@ func (ps *PushSync) handleDeliveryResponse(ctx context.Context, w protobuf.Write ...@@ -310,10 +310,7 @@ func (ps *PushSync) handleDeliveryResponse(ctx context.Context, w protobuf.Write
} }
if ps.deliveryCallback != nil { if ps.deliveryCallback != nil {
err = ps.deliveryCallback(ctx, chunk) ps.deliveryCallback(ctx, chunk)
if err != nil {
ps.logger.Debugf("pushsync delivery callback: %v", err)
}
} }
return nil return nil
......
...@@ -11,8 +11,6 @@ import ( ...@@ -11,8 +11,6 @@ import (
"testing" "testing"
"time" "time"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/accounting" "github.com/ethersphere/bee/pkg/accounting"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock" accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/localstore" "github.com/ethersphere/bee/pkg/localstore"
...@@ -21,6 +19,7 @@ import ( ...@@ -21,6 +19,7 @@ import (
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/pushsync/pb" "github.com/ethersphere/bee/pkg/pushsync/pb"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/topology"
...@@ -190,9 +189,8 @@ func TestHandler(t *testing.T) { ...@@ -190,9 +189,8 @@ func TestHandler(t *testing.T) {
// mock call back function to see if pss message is delivered when it is received in the destination (closestPeer in this testcase) // mock call back function to see if pss message is delivered when it is received in the destination (closestPeer in this testcase)
hookWasCalled := make(chan bool, 1) // channel to check if hook is called hookWasCalled := make(chan bool, 1) // channel to check if hook is called
pssDeliver := func(ctx context.Context, ch swarm.Chunk) error { pssDeliver := func(ctx context.Context, ch swarm.Chunk) {
hookWasCalled <- true hookWasCalled <- true
return nil
} }
// Create the closest peer // Create the closest peer
...@@ -279,7 +277,7 @@ func TestHandler(t *testing.T) { ...@@ -279,7 +277,7 @@ func TestHandler(t *testing.T) {
} }
} }
func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, pssDeliver func(context.Context, swarm.Chunk) error, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) { func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, pssDeliver func(context.Context, swarm.Chunk), mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags, accounting.Interface) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
storer, err := localstore.New("", addr.Bytes(), nil, logger) storer, err := localstore.New("", addr.Bytes(), nil, logger)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment