Commit aed7f865 authored by Viktor Trón's avatar Viktor Trón Committed by GitHub

Pss encryption and topic matching, refactor, cleaner test (#737)

parent 0722fbce
......@@ -163,6 +163,7 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7
github.com/ethereum/go-ethereum v1.9.14/go.mod h1:oP8FC5+TbICUyftkTWs+8JryntjIJLJvWvApK3z2AYw=
github.com/ethereum/go-ethereum v1.9.20 h1:kk/J5OIoaoz3DRrCXznz3RGi212mHHXwzXlY/ZQxcj0=
github.com/ethereum/go-ethereum v1.9.20/go.mod h1:JSSTypSMTkGZtAdAChH2wP5dZEvPGh3nUTuDpH+hNrg=
github.com/ethereum/go-ethereum v1.9.21 h1:8qRlhzrItnmUGdVlBzZLI2Tb46S0RdSNjFwICo781ws=
github.com/ethersphere/bmt v0.1.2 h1:FEuvQY9xuK+rDp3VwDVyde8T396Matv/u9PdtKa2r9Q=
github.com/ethersphere/bmt v0.1.2/go.mod h1:fqRBDmYwn3lX2MH4lkImXQgFWeNP8ikLkS/hgi/HRws=
github.com/ethersphere/langos v1.0.0 h1:NBtNKzXTTRSue95uOlzPN4py7Aofs0xWPzyj4AI1Vcc=
......
......@@ -6,15 +6,17 @@ package api
import (
"context"
"crypto/ecdsa"
"encoding/hex"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
......@@ -28,18 +30,13 @@ var (
targetMaxLength = 2 // max target length in bytes, in order to prevent grieving by excess computation
)
type PssMessage struct {
Topic string
Message string
}
func (s *server) pssPostHandler(w http.ResponseWriter, r *http.Request) {
t := mux.Vars(r)["topic"]
topic := trojan.NewTopic(t)
topicVar := mux.Vars(r)["topic"]
topic := pss.NewTopic(topicVar)
tg := mux.Vars(r)["targets"]
var targets trojan.Targets
tgts := strings.Split(tg, ",")
targetsVar := mux.Vars(r)["targets"]
var targets pss.Targets
tgts := strings.Split(targetsVar, ",")
for _, v := range tgts {
target, err := hex.DecodeString(v)
......@@ -52,6 +49,23 @@ func (s *server) pssPostHandler(w http.ResponseWriter, r *http.Request) {
targets = append(targets, target)
}
recipientQueryString := r.URL.Query().Get("recipient")
var recipient *ecdsa.PublicKey
if recipientQueryString == "" {
// use topic-based encryption
privkey := crypto.Secp256k1PrivateKeyFromBytes(topic[:])
recipient = &privkey.PublicKey
} else {
var err error
recipient, err = pss.ParseRecipient(recipientQueryString)
if err != nil {
s.Logger.Debugf("pss recipient: %v", err)
s.Logger.Error("pss recipient")
jsonhttp.BadRequest(w, nil)
return
}
}
payload, err := ioutil.ReadAll(r.Body)
if err != nil {
s.Logger.Debugf("pss read payload: %v", err)
......@@ -60,9 +74,9 @@ func (s *server) pssPostHandler(w http.ResponseWriter, r *http.Request) {
return
}
err = s.Pss.Send(r.Context(), targets, topic, payload)
err = s.Pss.Send(r.Context(), topic, payload, recipient, targets)
if err != nil {
s.Logger.Debugf("pss send payload: %v. topic: %s", err, t)
s.Logger.Debugf("pss send payload: %v. topic: %s", err, topicVar)
s.Logger.Error("pss send payload")
jsonhttp.InternalServerError(w, nil)
return
......@@ -91,7 +105,7 @@ func (s *server) pumpWs(conn *websocket.Conn, t string) {
var (
dataC = make(chan []byte)
gone = make(chan struct{})
topic = trojan.NewTopic(t)
topic = pss.NewTopic(t)
ticker = time.NewTicker(s.WsPingPeriod)
err error
)
......@@ -99,8 +113,8 @@ func (s *server) pumpWs(conn *websocket.Conn, t string) {
ticker.Stop()
_ = conn.Close()
}()
cleanup := s.Pss.Register(topic, func(_ context.Context, m *trojan.Message) {
dataC <- m.Payload
cleanup := s.Pss.Register(topic, func(_ context.Context, m []byte) {
dataC <- m
})
defer cleanup()
......
......@@ -7,6 +7,8 @@ package api_test
import (
"bytes"
"context"
"crypto/ecdsa"
"encoding/hex"
"fmt"
"io/ioutil"
"net/http"
......@@ -15,6 +17,8 @@ import (
"testing"
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
"github.com/ethersphere/bee/pkg/logging"
......@@ -22,22 +26,22 @@ import (
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/gorilla/websocket"
)
var (
target = trojan.Target([]byte{1})
targets = trojan.Targets([]trojan.Target{target})
target = pss.Target([]byte{1})
targets = pss.Targets([]pss.Target{target})
payload = []byte("testdata")
topic = trojan.NewTopic("testtopic")
topic = pss.NewTopic("testtopic")
timeout = 10 * time.Second
)
// creates a single websocket handler for an arbitrary topic, and receives a message
func TestPssWebsocketSingleHandler(t *testing.T) {
var (
pss, cl, _ = newPssTest(t, opts{})
p, publicKey, cl, _ = newPssTest(t, opts{})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
......@@ -52,20 +56,17 @@ func TestPssWebsocketSingleHandler(t *testing.T) {
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
err = p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, payload, &mtx)
}
......@@ -74,7 +75,8 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
// pss.TryUnwrap with a chunk designated for this handler and expect
// the handler to be notified
var (
pss, cl, _ = newPssTest(t, opts{})
p, publicKey, cl, _ = newPssTest(t, opts{})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
......@@ -89,12 +91,8 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets)
if err != nil {
t.Fatal(err)
}
......@@ -105,7 +103,7 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
err = p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
......@@ -115,7 +113,8 @@ func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
func TestPssWebsocketMultiHandler(t *testing.T) {
var (
pss, cl, listener = newPssTest(t, opts{})
p, publicKey, cl, listener = newPssTest(t, opts{})
u = url.URL{Scheme: "ws", Host: listener, Path: "/pss/subscribe/testtopic"}
cl2, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
......@@ -138,12 +137,8 @@ func TestPssWebsocketMultiHandler(t *testing.T) {
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
go waitReadMessage(t, &mtx, cl2, msgContent2, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets)
if err != nil {
t.Fatal(err)
}
......@@ -154,7 +149,7 @@ func TestPssWebsocketMultiHandler(t *testing.T) {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
err = p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
......@@ -169,28 +164,33 @@ func TestPssSend(t *testing.T) {
logger = logging.New(ioutil.Discard, 0)
mtx sync.Mutex
recievedTargets trojan.Targets
recievedTopic trojan.Topic
recievedBytes []byte
receivedTopic pss.Topic
receivedBytes []byte
receivedTargets pss.Targets
done bool
sendFn = func(_ context.Context, targets trojan.Targets, topic trojan.Topic, bytes []byte) error {
privk, _ = crypto.GenerateSecp256k1Key()
publicKeyBytes = (*btcec.PublicKey)(&privk.PublicKey).SerializeCompressed()
sendFn = func(ctx context.Context, targets pss.Targets, chunk swarm.Chunk) error {
mtx.Lock()
recievedTargets = targets
recievedTopic = topic
recievedBytes = bytes
topic, msg, err := pss.Unwrap(ctx, privk, chunk, []pss.Topic{topic})
receivedTopic = topic
receivedBytes = msg
receivedTargets = targets
done = true
mtx.Unlock()
return nil
return err
}
pss = newMockPss(sendFn)
p = newMockPss(sendFn)
client, _, _ = newTestServer(t, testServerOptions{
Pss: pss,
Pss: p,
Storer: mock.NewStorer(),
Logger: logger,
})
recipient = hex.EncodeToString(publicKeyBytes)
targets = fmt.Sprintf("[[%d]]", 0x12)
topic = "testtopic"
hasher = swarm.NewHasher()
......@@ -202,7 +202,7 @@ func TestPssSend(t *testing.T) {
}
t.Run("err - bad targets", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/badtarget", http.StatusBadRequest,
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/badtarget?recipient="+recipient, http.StatusBadRequest,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "Bad Request",
......@@ -212,6 +212,26 @@ func TestPssSend(t *testing.T) {
})
t.Run("ok", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12?recipient="+recipient, http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "OK",
Code: http.StatusOK,
}),
)
waitDone(t, &mtx, &done)
if !bytes.Equal(receivedBytes, payload) {
t.Fatalf("payload mismatch. want %v got %v", payload, receivedBytes)
}
if targets != fmt.Sprint(receivedTargets) {
t.Fatalf("targets mismatch. want %v got %v", targets, receivedTargets)
}
if string(topicHash) != string(receivedTopic[:]) {
t.Fatalf("topic mismatch. want %v got %v", topic, string(receivedTopic[:]))
}
})
t.Run("without recipient", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12", http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
......@@ -220,14 +240,14 @@ func TestPssSend(t *testing.T) {
}),
)
waitDone(t, &mtx, &done)
if !bytes.Equal(recievedBytes, payload) {
t.Fatalf("payload mismatch. want %v got %v", payload, recievedBytes)
if !bytes.Equal(receivedBytes, payload) {
t.Fatalf("payload mismatch. want %v got %v", payload, receivedBytes)
}
if targets != fmt.Sprint(recievedTargets) {
t.Fatalf("targets mismatch. want %v got %v", targets, recievedTargets)
if targets != fmt.Sprint(receivedTargets) {
t.Fatalf("targets mismatch. want %v got %v", targets, receivedTargets)
}
if string(topicHash) != string(recievedTopic[:]) {
t.Fatalf("topic mismatch. want %v got %v", topic, string(recievedTopic[:]))
if string(topicHash) != string(receivedTopic[:]) {
t.Fatalf("topic mismatch. want %v got %v", topic, string(receivedTopic[:]))
}
})
}
......@@ -237,7 +257,7 @@ func TestPssSend(t *testing.T) {
// The test opens a websocket, keeps it alive for 500ms, then receives a pss message.
func TestPssPingPong(t *testing.T) {
var (
pss, cl, _ = newPssTest(t, opts{pingPeriod: 90 * time.Millisecond})
p, publicKey, cl, _ = newPssTest(t, opts{pingPeriod: 90 * time.Millisecond})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
......@@ -254,18 +274,14 @@ func TestPssPingPong(t *testing.T) {
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
err = pss.TryUnwrap(context.Background(), tc)
err = p.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
......@@ -318,6 +334,8 @@ func waitDone(t *testing.T, mtx *sync.Mutex, done *bool) {
}
func waitMessage(t *testing.T, data, expData []byte, mtx *sync.Mutex) {
t.Helper()
ttl := time.After(timeout)
for {
select {
......@@ -342,10 +360,16 @@ type opts struct {
pingPeriod time.Duration
}
func newPssTest(t *testing.T, o opts) (pss.Interface, *websocket.Conn, string) {
func newPssTest(t *testing.T, o opts) (pss.Interface, *ecdsa.PublicKey, *websocket.Conn, string) {
t.Helper()
privkey, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
var (
logger = logging.New(ioutil.Discard, 0)
pss = pss.New(logger)
pss = pss.New(privkey, logger)
)
if o.pingPeriod == 0 {
o.pingPeriod = 10 * time.Second
......@@ -357,10 +381,10 @@ func newPssTest(t *testing.T, o opts) (pss.Interface, *websocket.Conn, string) {
Logger: logger,
WsPingPeriod: o.pingPeriod,
})
return pss, cl, listener
return pss, &privkey.PublicKey, cl, listener
}
type pssSendFn func(context.Context, trojan.Targets, trojan.Topic, []byte) error
type pssSendFn func(context.Context, pss.Targets, swarm.Chunk) error
type mpss struct {
f pssSendFn
}
......@@ -370,12 +394,16 @@ func newMockPss(f pssSendFn) *mpss {
}
// Send arbitrary byte slice with the given topic to Targets.
func (m *mpss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, bytes []byte) error {
return m.f(ctx, targets, topic, bytes)
func (m *mpss) Send(ctx context.Context, topic pss.Topic, payload []byte, recipient *ecdsa.PublicKey, targets pss.Targets) error {
chunk, err := pss.Wrap(ctx, topic, payload, recipient, targets)
if err != nil {
return err
}
return m.f(ctx, targets, chunk)
}
// Register a Handler for a given Topic.
func (m *mpss) Register(_ trojan.Topic, _ pss.Handler) func() {
func (m *mpss) Register(_ pss.Topic, _ pss.Handler) func() {
panic("not implemented") // TODO: Implement
}
......
......@@ -61,6 +61,13 @@ func DecodeSecp256k1PrivateKey(data []byte) (*ecdsa.PrivateKey, error) {
return (*ecdsa.PrivateKey)(privk), nil
}
// Secp256k1PrivateKeyFromBytes returns an ECDSA private key based on
// the byte slice.
func Secp256k1PrivateKeyFromBytes(data []byte) *ecdsa.PrivateKey {
privk, _ := btcec.PrivKeyFromBytes(btcec.S256(), data)
return (*ecdsa.PrivateKey)(privk)
}
// NewEthereumAddress returns a binary representation of ethereum blockchain address.
// This function is based on github.com/ethereum/go-ethereum/crypto.PubkeyToAddress.
func NewEthereumAddress(p ecdsa.PublicKey) ([]byte, error) {
......
......@@ -62,6 +62,24 @@ func TestEncodeSecp256k1PrivateKey(t *testing.T) {
}
}
func TestSecp256k1PrivateKeyFromBytes(t *testing.T) {
data := []byte("data")
k1 := crypto.Secp256k1PrivateKeyFromBytes(data)
if k1 == nil {
t.Fatal("nil key")
}
k2 := crypto.Secp256k1PrivateKeyFromBytes(data)
if k2 == nil {
t.Fatal("nil key")
}
if !bytes.Equal(k1.D.Bytes(), k2.D.Bytes()) {
t.Fatal("two generated keys are not equal")
}
}
func TestNewEthereumAddress(t *testing.T) {
privKeyHex := "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae"
privKeyBytes, err := hex.DecodeString(privKeyHex)
......
......@@ -23,10 +23,10 @@ func New(key *ecdsa.PrivateKey, pub *ecdsa.PublicKey, salt []byte, padding int,
return encryption.New(sk, padding, 0, hashfunc), nil
}
// NewEncrypter constructs an El-Gamal encryptor
// NewEncryptor constructs an El-Gamal encryptor
// this involves generating an ephemeral key pair the public part of which is returned
// as it is needed for the counterparty to decrypt
func NewEncrypter(pub *ecdsa.PublicKey, salt []byte, padding int, hashfunc func() hash.Hash) (encryption.Encrypter, *ecdsa.PublicKey, error) {
func NewEncryptor(pub *ecdsa.PublicKey, salt []byte, padding int, hashfunc func() hash.Hash) (encryption.Encrypter, *ecdsa.PublicKey, error) {
privKey, err := crypto.GenerateSecp256k1Key()
if err != nil {
return nil, nil, err
......
......@@ -27,7 +27,7 @@ func TestElgamalCorrect(t *testing.T) {
t.Fatal(err)
}
padding := 4032
enc, ephpub, err := elgamal.NewEncrypter(pub, salt, padding, swarm.NewHasher)
enc, ephpub, err := elgamal.NewEncryptor(pub, salt, padding, swarm.NewHasher)
if err != nil {
t.Fatal(err)
}
......
......@@ -17,11 +17,11 @@ import (
validatormock "github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
)
var chunkData = []byte("mockdata")
......@@ -173,7 +173,7 @@ type mockRecovery struct {
}
// Send mocks the pss Send function
func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets trojan.Targets) error {
func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets pss.Targets) error {
mr.hookC <- true
return nil
}
......
......@@ -341,7 +341,12 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
}
// instantiate the pss object
psss := pss.New(logger)
swarmPrivateKey, _, err := keystore.Key("swarm", o.Password)
if err != nil {
return nil, fmt.Errorf("swarm key: %w", err)
}
psss := pss.New(swarmPrivateKey, logger)
b.pssCloser = psss
var ns storage.Storer
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pss
var (
Contains = contains
)
......@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trojan_test
package pss_test
import (
"context"
......@@ -10,25 +10,21 @@ import (
"fmt"
"testing"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/pss"
)
func newTargets(length, depth int) trojan.Targets {
targets := make([]trojan.Target, length)
func newTargets(length, depth int) pss.Targets {
targets := make([]pss.Target, length)
for i := 0; i < length; i++ {
buf := make([]byte, 8)
binary.LittleEndian.PutUint64(buf, uint64(i))
targets[i] = trojan.Target(buf[:depth])
targets[i] = pss.Target(buf[:depth])
}
return trojan.Targets(targets)
return pss.Targets(targets)
}
func BenchmarkWrap(b *testing.B) {
payload := []byte("foopayload")
m, err := trojan.NewMessage(testTopic, payload)
if err != nil {
b.Fatal(err)
}
cases := []struct {
length int
depth int
......@@ -43,12 +39,19 @@ func BenchmarkWrap(b *testing.B) {
{4096, 3},
{16384, 3},
}
topic := pss.NewTopic("topic")
msg := []byte("this is my scariest")
key, err := crypto.GenerateSecp256k1Key()
if err != nil {
b.Fatal(err)
}
pubkey := &key.PublicKey
for _, c := range cases {
name := fmt.Sprintf("length:%d,depth:%d", c.length, c.depth)
b.Run(name, func(b *testing.B) {
targets := newTargets(c.length, c.depth)
for i := 0; i < b.N; i++ {
if _, err := m.Wrap(context.Background(), targets); err != nil {
if _, err := pss.Wrap(context.Background(), topic, msg, pubkey, targets); err != nil {
b.Fatal(err)
}
}
......
......@@ -6,6 +6,7 @@ package pss
import (
"context"
"crypto/ecdsa"
"errors"
"fmt"
"io"
......@@ -14,7 +15,6 @@ import (
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
)
var (
......@@ -22,11 +22,15 @@ var (
ErrNoHandler = errors.New("no handler found")
)
type Interface interface {
type Sender interface {
// Send arbitrary byte slice with the given topic to Targets.
Send(context.Context, trojan.Targets, trojan.Topic, []byte) error
Send(context.Context, Topic, []byte, *ecdsa.PublicKey, Targets) error
}
type Interface interface {
Sender
// Register a Handler for a given Topic.
Register(trojan.Topic, Handler) func()
Register(Topic, Handler) func()
// TryUnwrap tries to unwrap a wrapped trojan message.
TryUnwrap(context.Context, swarm.Chunk) error
......@@ -35,8 +39,9 @@ type Interface interface {
}
type pss struct {
key *ecdsa.PrivateKey
pusher pushsync.PushSyncer
handlers map[trojan.Topic][]*Handler
handlers map[Topic][]*Handler
handlersMu sync.Mutex
metrics metrics
logger logging.Logger
......@@ -44,10 +49,11 @@ type pss struct {
}
// New returns a new pss service.
func New(logger logging.Logger) Interface {
func New(key *ecdsa.PrivateKey, logger logging.Logger) Interface {
return &pss{
key: key,
logger: logger,
handlers: make(map[trojan.Topic][]*Handler),
handlers: make(map[Topic][]*Handler),
metrics: newMetrics(),
quit: make(chan struct{}),
}
......@@ -58,7 +64,7 @@ func (ps *pss) Close() error {
ps.handlersMu.Lock()
defer ps.handlersMu.Unlock()
ps.handlers = make(map[trojan.Topic][]*Handler) //unset handlers on shutdown
ps.handlers = make(map[Topic][]*Handler) //unset handlers on shutdown
return nil
}
......@@ -68,21 +74,15 @@ func (ps *pss) SetPushSyncer(pushSyncer pushsync.PushSyncer) {
}
// Handler defines code to be executed upon reception of a trojan message.
type Handler func(context.Context, *trojan.Message)
type Handler func(context.Context, []byte)
// Send constructs a padded message with topic and payload,
// wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address.
// Uses push-sync to deliver message.
func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error {
func (p *pss) Send(ctx context.Context, topic Topic, payload []byte, recipient *ecdsa.PublicKey, targets Targets) error {
p.metrics.TotalMessagesSentCounter.Inc()
m, err := trojan.NewMessage(topic, payload)
if err != nil {
return err
}
var tc swarm.Chunk
tc, err = m.Wrap(ctx, targets)
tc, err := Wrap(ctx, topic, payload, recipient, targets)
if err != nil {
return err
}
......@@ -96,7 +96,7 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top
}
// Register allows the definition of a Handler func for a specific topic on the pss struct.
func (p *pss) Register(topic trojan.Topic, handler Handler) (cleanup func()) {
func (p *pss) Register(topic Topic, handler Handler) (cleanup func()) {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
......@@ -116,18 +116,27 @@ func (p *pss) Register(topic trojan.Topic, handler Handler) (cleanup func()) {
}
}
func (p *pss) topics() []Topic {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
ts := make([]Topic, 0, len(p.handlers))
for t := range p.handlers {
ts = append(ts, t)
}
return ts
}
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic.
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
if !trojan.IsPotential(c) {
return nil
}
m, err := trojan.Unwrap(c)
topic, msg, err := Unwrap(ctx, p.key, c, p.topics())
if err != nil {
return err
}
h := p.getHandlers(m.Topic)
h := p.getHandlers(topic)
if h == nil {
return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler)
return fmt.Errorf("topic %v, %w", topic, ErrNoHandler)
}
ctx, cancel := context.WithCancel(ctx)
......@@ -144,7 +153,7 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
wg.Add(1)
go func(hh Handler) {
defer wg.Done()
hh(ctx, m)
hh(ctx, msg)
}(*hh)
}
go func() {
......@@ -155,7 +164,7 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
return nil
}
func (p *pss) getHandlers(topic trojan.Topic) []*Handler {
func (p *pss) getHandlers(topic Topic) []*Handler {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
......
......@@ -8,17 +8,15 @@ import (
"bytes"
"context"
"io/ioutil"
"runtime"
"sync"
"testing"
"time"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync"
pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
)
// TestSend creates a trojan chunk and sends it using push sync
......@@ -27,202 +25,225 @@ func TestSend(t *testing.T) {
ctx := context.Background()
// create a mock pushsync service to push the chunk to its destination
var receipt *pushsync.Receipt
var storedChunk swarm.Chunk
pushSyncService := pushsyncmock.New(func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
rcpt := &pushsync.Receipt{
Address: swarm.NewAddress(chunk.Address().Bytes()),
}
storedChunk = chunk
receipt = rcpt
return rcpt, nil
return nil, nil
})
pss := pss.New(logging.New(ioutil.Discard, 0))
pss.SetPushSyncer(pushSyncService)
target := trojan.Target([]byte{1}) // arbitrary test target
targets := trojan.Targets([]trojan.Target{target})
payload := []byte("RECOVERY CHUNK")
topic := trojan.NewTopic("RECOVERY TOPIC")
p := pss.New(nil, logging.New(ioutil.Discard, 0))
p.SetPushSyncer(pushSyncService)
target := pss.Target([]byte{1}) // arbitrary test target
targets := pss.Targets([]pss.Target{target})
payload := []byte("some payload")
topic := pss.NewTopic("topic")
privkey, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
recipient := &privkey.PublicKey
// call Send to store trojan chunk in localstore
if err = pss.Send(ctx, targets, topic, payload); err != nil {
if err = p.Send(ctx, topic, payload, recipient, targets); err != nil {
t.Fatal(err)
}
if receipt == nil {
t.Fatal("no receipt")
}
m, err := trojan.Unwrap(storedChunk)
topic1 := pss.NewTopic("topic-1")
topic2 := pss.NewTopic("topic-2")
topic3 := pss.NewTopic("topic-3")
topics := []pss.Topic{topic, topic1, topic2, topic3}
unwrapTopic, msg, err := pss.Unwrap(ctx, privkey, storedChunk, topics)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(m.Payload, payload) {
t.Fatalf("payload mismatch expected %v but is %v instead", m.Payload, payload)
if !bytes.Equal(msg, payload) {
t.Fatalf("message mismatch: expected %x, got %x", payload, msg)
}
if !bytes.Equal(m.Topic[:], topic[:]) {
t.Fatalf("topic mismatch expected %v but is %v instead", m.Topic, topic)
if !bytes.Equal(unwrapTopic[:], topic[:]) {
t.Fatalf("topic mismatch: expected %x, got %x", topic[:], unwrapTopic[:])
}
}
type topicMessage struct {
topic pss.Topic
msg []byte
}
// TestDeliver verifies that registering a handler on pss for a given topic and then submitting a trojan chunk with said topic to it
// results in the execution of the expected handler func
func TestDeliver(t *testing.T) {
pss := pss.New(logging.New(ioutil.Discard, 0))
ctx := context.Background()
var mtx sync.Mutex
// test message
topic := trojan.NewTopic("footopic")
payload := []byte("foopayload")
msg, err := trojan.NewMessage(topic, payload)
privkey, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
p := pss.New(privkey, logging.New(ioutil.Discard, 0))
target := pss.Target([]byte{1}) // arbitrary test target
targets := pss.Targets([]pss.Target{target})
payload := []byte("some payload")
topic := pss.NewTopic("topic")
recipient := &privkey.PublicKey
// test chunk
target := trojan.Target([]byte{1}) // arbitrary test target
targets := trojan.Targets([]trojan.Target{target})
c, err := msg.Wrap(ctx, targets)
chunk, err := pss.Wrap(context.Background(), topic, payload, recipient, targets)
if err != nil {
t.Fatal(err)
}
msgChan := make(chan topicMessage)
// create and register handler
var tt trojan.Topic // test variable to check handler func was correctly called
hndlr := func(ctx context.Context, m *trojan.Message) {
mtx.Lock()
copy(tt[:], m.Topic[:]) // copy the message topic to the test variable
mtx.Unlock()
handler := func(ctx context.Context, m []byte) {
msgChan <- topicMessage{
topic: topic,
msg: m,
}
}
pss.Register(topic, hndlr)
p.Register(topic, handler)
// call pss TryUnwrap on chunk and verify test topic variable value changes
err = pss.TryUnwrap(ctx, c)
err = p.TryUnwrap(ctx, chunk)
if err != nil {
t.Fatal(err)
}
runtime.Gosched() // schedule the handler goroutine
for i := 0; i < 10; i++ {
mtx.Lock()
eq := bytes.Equal(tt[:], msg.Topic[:])
mtx.Unlock()
if eq {
return
var message topicMessage
select {
case message = <-msgChan:
break
case <-time.After(1 * time.Second):
t.Fatal("reached timeout while waiting for message")
}
if !bytes.Equal(payload, message.msg) {
t.Fatalf("message mismatch: expected %x, got %x", payload, message.msg)
}
<-time.After(50 * time.Millisecond)
if !bytes.Equal(topic[:], message.topic[:]) {
t.Fatalf("topic mismatch: expected %x, got %x", topic[:], message.topic[:])
}
t.Fatalf("unexpected result for pss Deliver func, expected test variable to have a value of %v but is %v instead", msg.Topic, tt)
}
// TestRegister verifies that handler funcs are able to be registered correctly in pss
func TestRegister(t *testing.T) {
privkey, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
recipient := &privkey.PublicKey
var (
pss = pss.New(logging.New(ioutil.Discard, 0))
p = pss.New(privkey, logging.New(ioutil.Discard, 0))
h1Calls = 0
h2Calls = 0
h3Calls = 0
mtx sync.Mutex
topic1 = trojan.NewTopic("one")
topic2 = trojan.NewTopic("two")
msgChan = make(chan struct{})
topic1 = pss.NewTopic("one")
topic2 = pss.NewTopic("two")
payload = []byte("payload")
target = trojan.Target([]byte{1})
targets = trojan.Targets([]trojan.Target{target})
target = pss.Target([]byte{1})
targets = pss.Targets([]pss.Target{target})
h1 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h1 = func(_ context.Context, m []byte) {
h1Calls++
msgChan <- struct{}{}
}
h2 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h2 = func(_ context.Context, m []byte) {
h2Calls++
msgChan <- struct{}{}
}
h3 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h3 = func(_ context.Context, m []byte) {
h3Calls++
msgChan <- struct{}{}
}
)
_ = pss.Register(topic1, h1)
_ = pss.Register(topic2, h2)
_ = p.Register(topic1, h1)
_ = p.Register(topic2, h2)
// send a message on topic1, check that only h1 is called
msg, err := trojan.NewMessage(topic1, payload)
chunk1, err := pss.Wrap(context.Background(), topic1, payload, recipient, targets)
if err != nil {
t.Fatal(err)
}
c, err := msg.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), c)
err = p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 1)
ensureCalls(t, &mtx, &h2Calls, 0)
waitHandlerCallback(t, &msgChan, 1)
ensureCalls(t, &h1Calls, 1)
ensureCalls(t, &h2Calls, 0)
// register another topic handler on the same topic
cleanup := pss.Register(topic1, h3)
err = pss.TryUnwrap(context.Background(), c)
cleanup := p.Register(topic1, h3)
err = p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 2)
ensureCalls(t, &mtx, &h2Calls, 0)
ensureCalls(t, &mtx, &h3Calls, 1)
waitHandlerCallback(t, &msgChan, 2)
ensureCalls(t, &h1Calls, 2)
ensureCalls(t, &h2Calls, 0)
ensureCalls(t, &h3Calls, 1)
cleanup() // remove the last handler
err = pss.TryUnwrap(context.Background(), c)
err = p.TryUnwrap(context.Background(), chunk1)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 3)
ensureCalls(t, &mtx, &h2Calls, 0)
ensureCalls(t, &mtx, &h3Calls, 1)
waitHandlerCallback(t, &msgChan, 1)
msg, err = trojan.NewMessage(topic2, payload)
ensureCalls(t, &h1Calls, 3)
ensureCalls(t, &h2Calls, 0)
ensureCalls(t, &h3Calls, 1)
chunk2, err := pss.Wrap(context.Background(), topic2, payload, recipient, targets)
if err != nil {
t.Fatal(err)
}
c, err = msg.Wrap(context.Background(), targets)
err = p.TryUnwrap(context.Background(), chunk2)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
waitHandlerCallback(t, &msgChan, 1)
ensureCalls(t, &mtx, &h1Calls, 3)
ensureCalls(t, &mtx, &h2Calls, 1)
ensureCalls(t, &mtx, &h3Calls, 1)
ensureCalls(t, &h1Calls, 3)
ensureCalls(t, &h2Calls, 1)
ensureCalls(t, &h3Calls, 1)
}
func ensureCalls(t *testing.T, mtx *sync.Mutex, calls *int, exp int) {
func waitHandlerCallback(t *testing.T, msgChan *chan struct{}, count int) {
t.Helper()
for i := 0; i < 10; i++ {
mtx.Lock()
if *calls == exp {
mtx.Unlock()
return
for received := 0; received < count; received++ {
select {
case <-*msgChan:
case <-time.After(1 * time.Second):
t.Fatal("reached timeout while waiting for handler message")
}
}
mtx.Unlock()
<-time.After(100 * time.Millisecond)
}
func ensureCalls(t *testing.T, calls *int, exp int) {
t.Helper()
if exp != *calls {
t.Fatalf("expected %d calls, found %d", exp, *calls)
}
t.Fatal("timed out waiting for value")
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pss
import (
"bytes"
"context"
"crypto/ecdsa"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
random "math/rand"
"github.com/btcsuite/btcd/btcec"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/encryption/elgamal"
"github.com/ethersphere/bee/pkg/swarm"
bmtlegacy "github.com/ethersphere/bmt/legacy"
)
var (
// ErrPayloadTooBig is returned when a given payload for a Message type is longer than the maximum amount allowed
ErrPayloadTooBig = fmt.Errorf("message payload size cannot be greater than %d bytes", MaxPayloadSize)
// ErrEmptyTargets is returned when the given target list for a trojan chunk is empty
ErrEmptyTargets = errors.New("target list cannot be empty")
// ErrVarLenTargets is returned when the given target list for a trojan chunk has addresses of different lengths
ErrVarLenTargets = errors.New("target list cannot have targets of different length")
)
// Topic is the type that classifies messages, allows client applications to subscribe to
type Topic [32]byte
// NewTopic creates a new Topic from an input string by taking its hash
func NewTopic(text string) Topic {
bytes, _ := crypto.LegacyKeccak256([]byte(text))
var topic Topic
copy(topic[:], bytes[:32])
return topic
}
// Target is an alias for a partial address (overlay prefix) serving as potential destination
type Target []byte
// Targets is an alias for a collection of targets
type Targets []Target
const (
// MaxPayloadSize is the maximum allowed payload size for the Message type, in bytes
MaxPayloadSize = swarm.ChunkSize - 3*swarm.HashSize
)
// Wrap creates a new serialised message with the given topic, payload and recipient public key used
// for encryption
// - span as topic hint (H(key|topic)[0:8]) to match topic
// chunk payload:
// - nonce is chosen so that the chunk address will have one of the targets as its prefix and thus will be forwarded to the neighbourhood of the recipient overlay address the target is derived from
// trojan payload:
// - ephemeral public key for el-Gamal encryption
// ciphertext - plaintext:
// - plaintext length encoding
// - integrity protection
// message:
func Wrap(ctx context.Context, topic Topic, msg []byte, recipient *ecdsa.PublicKey, targets Targets) (swarm.Chunk, error) {
if len(msg) > MaxPayloadSize {
return nil, ErrPayloadTooBig
}
// integrity protection and plaintext msg length encoding
integrity, err := crypto.LegacyKeccak256(msg)
if err != nil {
return nil, err
}
binary.BigEndian.PutUint16(integrity[:2], uint16(len(msg)))
// integrity segment prepended to msg
plaintext := append(integrity, msg...)
// use el-Gamal with ECDH on an ephemeral key, recipient public key and topic as salt
enc, ephpub, err := elgamal.NewEncryptor(recipient, topic[:], 4032, swarm.NewHasher)
if err != nil {
return nil, err
}
ciphertext, err := enc.Encrypt(plaintext)
if err != nil {
return nil, err
}
// prepend serialised ephemeral public key to the ciphertext
// NOTE: only the random bytes of the compressed public key are used
// in order not to leak anything, the one bit parity info of the magic byte
// is encoded in the parity of the 28th byte of the mined nonce
ephpubBytes := (*btcec.PublicKey)(ephpub).SerializeCompressed()
payload := append(ephpubBytes[1:], ciphertext...)
odd := ephpubBytes[0]&0x1 != 0
if err := checkTargets(targets); err != nil {
return nil, err
}
targetsLen := len(targets[0])
// topic hash, the first 8 bytes is used as the span of the chunk
hash, err := crypto.LegacyKeccak256(append(enc.Key(), topic[:]...))
if err != nil {
return nil, err
}
hint := hash[:8]
h := hasher(hint, payload)
// f is evaluating the mined nonce
// it accepts the nonce if it has the parity required by the ephemeral public key AND
// the chunk hashes to an address matching one of the targets
f := func(nonce []byte) (swarm.Chunk, error) {
hash, err := h(nonce)
if err != nil {
return nil, err
}
if !contains(targets, hash[:targetsLen]) {
return nil, nil
}
chunk := swarm.NewChunk(swarm.NewAddress(hash), append(hint, append(nonce, payload...)...))
return chunk, nil
}
return mine(ctx, odd, f)
}
// Unwrap takes a chunk, a topic and a private key, and tries to decrypt the payload
// using the private key, the prepended ephemeral public key for el-Gamal using the topic as salt
func Unwrap(ctx context.Context, key *ecdsa.PrivateKey, chunk swarm.Chunk, topics []Topic) (topic Topic, msg []byte, err error) {
chunkData := chunk.Data()
pubkey, err := extractPublicKey(chunkData)
if err != nil {
return Topic{}, nil, err
}
hint := chunkData[:8]
for _, topic = range topics {
select {
case <-ctx.Done():
return Topic{}, nil, ctx.Err()
default:
}
dec, err := matchTopic(key, pubkey, hint, topic[:])
if err != nil {
privk := crypto.Secp256k1PrivateKeyFromBytes(topic[:])
dec, err = matchTopic(privk, pubkey, hint, topic[:])
if err != nil {
continue
}
}
ciphertext := chunkData[72:]
msg, err = decryptAndCheck(dec, ciphertext)
if err != nil {
continue
}
break
}
return topic, msg, nil
}
// checkTargets verifies that the list of given targets is non empty and with elements of matching size
func checkTargets(targets Targets) error {
if len(targets) == 0 {
return ErrEmptyTargets
}
validLen := len(targets[0]) // take first element as allowed length
for i := 1; i < len(targets); i++ {
if len(targets[i]) != validLen {
return ErrVarLenTargets
}
}
return nil
}
func hasher(span, b []byte) func([]byte) ([]byte, error) {
hashPool := bmtlegacy.NewTreePool(swarm.NewHasher, swarm.Branches, bmtlegacy.PoolSize)
return func(nonce []byte) ([]byte, error) {
s := append(nonce, b...)
hasher := bmtlegacy.New(hashPool)
if err := hasher.SetSpanBytes(span); err != nil {
return nil, err
}
if _, err := hasher.Write(s); err != nil {
return nil, err
}
return hasher.Sum(nil), nil
}
}
// contains returns whether the given collection contains the given element
func contains(col Targets, elem []byte) bool {
for i := range col {
if bytes.Equal(elem, col[i]) {
return true
}
}
return false
}
// mine iteratively enumerates different nonces until the address (BMT hash) of the chunkhas one of the targets as its prefix
func mine(ctx context.Context, odd bool, f func(nonce []byte) (swarm.Chunk, error)) (swarm.Chunk, error) {
seeds := make([]uint32, 8)
for i := range seeds {
seeds[i] = random.Uint32()
}
initnonce := make([]byte, 32)
for i := 0; i < 8; i++ {
binary.LittleEndian.PutUint32(initnonce[i*4:i*4+4], seeds[i])
}
if odd {
initnonce[28] |= 0x01
} else {
initnonce[28] &= 0xfe
}
seeds[7] = binary.LittleEndian.Uint32(initnonce[28:32])
quit := make(chan struct{})
// make both errs and result channels buffered so they never block
result := make(chan swarm.Chunk, 8)
errs := make(chan error, 8)
for i := 0; i < 8; i++ {
go func(j int) {
nonce := make([]byte, 32)
copy(nonce, initnonce)
for seed := seeds[j]; ; seed++ {
binary.LittleEndian.PutUint32(nonce[j*4:j*4+4], seed)
res, err := f(nonce)
if err != nil {
errs <- err
return
}
if res != nil {
result <- res
return
}
select {
case <-quit:
return
default:
}
}
}(i)
}
defer close(quit)
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errs:
return nil, err
case res := <-result:
return res, nil
}
}
// extracts ephemeral public key from the chunk data to use with el-Gamal
func extractPublicKey(chunkData []byte) (*ecdsa.PublicKey, error) {
pubkeyBytes := make([]byte, 33)
pubkeyBytes[0] |= 0x2
copy(pubkeyBytes[1:], chunkData[40:72])
if chunkData[36]|0x1 != 0 {
pubkeyBytes[0] |= 0x1
}
pubkey, err := btcec.ParsePubKey(pubkeyBytes, btcec.S256())
return (*ecdsa.PublicKey)(pubkey), err
}
// topic is needed to decrypt the trojan payload, but no need to perform decryption with each
// instead the hash of the secret key and the topic is matched against a hint (64 bit meta info)q
// proper integrity check will disambiguate any potential collisions (false positives)
// if the topic matches the hint, it returns the el-Gamal decryptor, otherwise an error
func matchTopic(key *ecdsa.PrivateKey, pubkey *ecdsa.PublicKey, hint []byte, topic []byte) (encryption.Decrypter, error) {
dec, err := elgamal.NewDecrypter(key, pubkey, topic, swarm.NewHasher)
if err != nil {
return nil, err
}
match, err := crypto.LegacyKeccak256(append(dec.Key(), topic...))
if err != nil {
return nil, err
}
if !bytes.Equal(hint, match[:8]) {
return nil, errors.New("topic does not match hint")
}
return dec, nil
}
// decrypts the ciphertext with an el-Gamal decryptor using a topic that matched the hint
// the msg is extracted from the plaintext and its integrity is checked
func decryptAndCheck(dec encryption.Decrypter, ciphertext []byte) ([]byte, error) {
plaintext, err := dec.Decrypt(ciphertext)
if err != nil {
return nil, err
}
length := int(binary.BigEndian.Uint16(plaintext[:2]))
if length > MaxPayloadSize {
return nil, errors.New("invalid length")
}
msg := plaintext[32 : 32+length]
integrity := plaintext[2:32]
hash, err := crypto.LegacyKeccak256(msg)
if err != nil {
return nil, err
}
if !bytes.Equal(integrity, hash[2:]) {
return nil, errors.New("invalid message")
}
// bingo
return msg, nil
}
// ParseRecipient extract ephemeral public key from the hexadecimal string to use with el-Gamal.
func ParseRecipient(recipientHexString string) (*ecdsa.PublicKey, error) {
publicKeyBytes, err := hex.DecodeString(recipientHexString)
if err != nil {
return nil, err
}
pubkey, err := btcec.ParsePubKey(publicKeyBytes, btcec.S256())
if err != nil {
return nil, err
}
return (*ecdsa.PublicKey)(pubkey), err
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pss_test
import (
"bytes"
"context"
"testing"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestWrap(t *testing.T) {
topic := pss.NewTopic("topic")
msg := []byte("some payload")
key, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
pubkey := &key.PublicKey
depth := 1
targets := newTargets(4, depth)
chunk, err := pss.Wrap(context.Background(), topic, msg, pubkey, targets)
if err != nil {
t.Fatal(err)
}
contains := pss.Contains(targets, chunk.Address().Bytes()[0:depth])
if !contains {
t.Fatal("trojan address was expected to match one of the targets with prefix")
}
if len(chunk.Data()) != swarm.ChunkWithSpanSize {
t.Fatalf("expected trojan data size to be %d, was %d", swarm.ChunkWithSpanSize, len(chunk.Data()))
}
}
func TestUnwrap(t *testing.T) {
topic := pss.NewTopic("topic")
msg := []byte("some payload")
key, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
pubkey := &key.PublicKey
depth := 1
targets := newTargets(4, depth)
chunk, err := pss.Wrap(context.Background(), topic, msg, pubkey, targets)
if err != nil {
t.Fatal(err)
}
topic1 := pss.NewTopic("topic-1")
topic2 := pss.NewTopic("topic-2")
unwrapTopic, unwrapMsg, err := pss.Unwrap(context.Background(), key, chunk, []pss.Topic{topic1, topic2, topic})
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(msg, unwrapMsg) {
t.Fatalf("message mismatch: expected %x, got %x", msg, unwrapMsg)
}
if !bytes.Equal(topic[:], unwrapTopic[:]) {
t.Fatalf("topic mismatch: expected %x, got %x", topic[:], unwrapTopic[:])
}
}
func TestUnwrapTopicEncrypted(t *testing.T) {
topic := pss.NewTopic("topic")
msg := []byte("some payload")
privk := crypto.Secp256k1PrivateKeyFromBytes(topic[:])
pubkey := privk.PublicKey
depth := 1
targets := newTargets(4, depth)
chunk, err := pss.Wrap(context.Background(), topic, msg, &pubkey, targets)
if err != nil {
t.Fatal(err)
}
key, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
topic1 := pss.NewTopic("topic-1")
topic2 := pss.NewTopic("topic-2")
unwrapTopic, unwrapMsg, err := pss.Unwrap(context.Background(), key, chunk, []pss.Topic{topic1, topic2, topic})
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(msg, unwrapMsg) {
t.Fatalf("message mismatch: expected %x, got %x", msg, unwrapMsg)
}
if !bytes.Equal(topic[:], unwrapTopic[:]) {
t.Fatalf("topic mismatch: expected %x, got %x", topic[:], unwrapTopic[:])
}
}
......@@ -7,12 +7,12 @@ package recovery
import (
"context"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
)
const (
......@@ -22,31 +22,28 @@ const (
var (
// RecoveryTopic is the topic used for repairing globally pinned chunks.
RecoveryTopic = trojan.NewTopic(RecoveryTopicText)
RecoveryTopic = pss.NewTopic(RecoveryTopicText)
)
// RecoveryHook defines code to be executed upon failing to retrieve chunks.
type RecoveryHook func(chunkAddress swarm.Address, targets trojan.Targets) error
// sender is the function call for sending trojan chunks.
type PssSender interface {
Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error
}
type RecoveryHook func(chunkAddress swarm.Address, targets pss.Targets) error
// NewRecoveryHook returns a new RecoveryHook with the sender function defined.
func NewRecoveryHook(pss PssSender) RecoveryHook {
return func(chunkAddress swarm.Address, targets trojan.Targets) error {
func NewRecoveryHook(pssSender pss.Sender) RecoveryHook {
privk := crypto.Secp256k1PrivateKeyFromBytes([]byte(RecoveryTopicText))
recipient := privk.PublicKey
return func(chunkAddress swarm.Address, targets pss.Targets) error {
payload := chunkAddress
ctx := context.Background()
err := pss.Send(ctx, targets, RecoveryTopic, payload.Bytes())
err := pssSender.Send(ctx, RecoveryTopic, payload.Bytes(), &recipient, targets)
return err
}
}
// NewRepairHandler creates a repair function to re-upload globally pinned chunks to the network with the given store.
func NewRepairHandler(s storage.Storer, logger logging.Logger, pushSyncer pushsync.PushSyncer) pss.Handler {
return func(ctx context.Context, m *trojan.Message) {
chAddr := m.Payload
return func(ctx context.Context, m []byte) {
chAddr := m
// check if the chunk exists in the local store and proceed.
// otherwise the Get will trigger a unnecessary network retrieve
......
......@@ -6,6 +6,7 @@ package recovery_test
import (
"context"
"crypto/ecdsa"
"errors"
"io/ioutil"
"testing"
......@@ -15,6 +16,7 @@ import (
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync"
pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock"
"github.com/ethersphere/bee/pkg/recovery"
......@@ -26,14 +28,13 @@ import (
chunktesting "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/trojan"
)
// TestRecoveryHook tests that a recovery hook can be created and called.
func TestRecoveryHook(t *testing.T) {
// test variables needed to be correctly set for any recovery hook to reach the sender func
chunkAddr := chunktesting.GenerateTestRandomChunk().Address()
targets := trojan.Targets{[]byte{0xED}}
targets := pss.Targets{[]byte{0xED}}
//setup the sender
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
......@@ -139,17 +140,8 @@ func TestNewRepairHandler(t *testing.T) {
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c1.Address().Bytes())
if msg, err = trojan.NewMessage(testTopic, maxPayload); err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
repairHandler(context.Background(), &msg)
repairHandler(context.Background(), c1.Address().Bytes())
// check if receipt is received
if receipt == nil {
......@@ -180,18 +172,8 @@ func TestNewRepairHandler(t *testing.T) {
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c2.Address().Bytes())
msg, err := trojan.NewMessage(testTopic, maxPayload)
if err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
repairHandler(context.Background(), &msg)
repairHandler(context.Background(), c2.Address().Bytes())
if pushServiceCalled {
t.Fatal("push service called even if the chunk is not present")
......@@ -220,18 +202,8 @@ func TestNewRepairHandler(t *testing.T) {
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c3.Address().Bytes())
msg, err = trojan.NewMessage(testTopic, maxPayload)
if err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
repairHandler(context.Background(), &msg)
repairHandler(context.Background(), c3.Address().Bytes())
if receiptError == nil {
t.Fatal("pushsync did not generate a receipt error")
......@@ -281,7 +253,7 @@ type mockPssSender struct {
}
// Send mocks the pss Send function
func (mp *mockPssSender) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error {
func (mp *mockPssSender) Send(ctx context.Context, topic pss.Topic, payload []byte, recipient *ecdsa.PublicKey, targets pss.Targets) error {
mp.hookC <- true
return nil
}
......@@ -6,14 +6,12 @@ package sctx
import (
"context"
"encoding/hex"
"errors"
"strings"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/trojan"
)
var (
......@@ -63,16 +61,16 @@ func SetTargets(ctx context.Context, targets string) context.Context {
// GetTargets returns the specific target pinners for a corresponding chunk by
// reading the prefix targets sent in the download API.
func GetTargets(ctx context.Context) trojan.Targets {
func GetTargets(ctx context.Context) pss.Targets {
targetString, ok := ctx.Value(targetsContextKey{}).(string)
if !ok {
return nil
}
prefixes := strings.Split(targetString, ",")
var targets trojan.Targets
var targets pss.Targets
for _, prefix := range prefixes {
var target trojan.Target
var target pss.Target
target, err := hex.DecodeString(prefix)
if err != nil {
continue
......
package trojan
import (
"errors"
"fmt"
)
var (
// ErrPayloadTooBig is returned when a given payload for a Message type is longer than the maximum amount allowed
ErrPayloadTooBig = fmt.Errorf("message payload size cannot be greater than %d bytes", MaxPayloadSize)
// ErrEmptyTargets is returned when the given target list for a trojan chunk is empty
ErrEmptyTargets = errors.New("target list cannot be empty")
// ErrVarLenTargets is returned when the given target list for a trojan chunk has addresses of different lengths
ErrVarLenTargets = errors.New("target list cannot have targets of different length")
// ErrUnmarshal is returned when a trojan message could not be de-serialized
ErrUnmarshal = errors.New("trojan message unmarshall error")
)
package trojan
var (
Contains = contains
)
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trojan
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
random "math/rand"
"github.com/ethersphere/bee/pkg/swarm"
bmtlegacy "github.com/ethersphere/bmt/legacy"
)
// Topic is an alias for a 32 byte fixed-size array which contains an encoding of a message topic
type Topic [32]byte
// Target is an alias for an address which can be mined to construct a trojan message.
// Target is like partial address which helps to send message to a particular PO.
type Target []byte
// Targets is an alias for a collection of targets
type Targets []Target
// Message represents a trojan message, which is a message that will be hidden within a chunk payload as part of its data
type Message struct {
length [2]byte // big-endian encoding of Message payload length
Topic Topic
Payload []byte // contains the chunk address to be repaired
padding []byte
}
const (
// MaxPayloadSize is the maximum allowed payload size for the Message type, in bytes
// MaxPayloadSize + Topic + Length + Nonce = Default ChunkSize
// (4030) + (32) + (2) + (32) = 4096 Bytes
MaxPayloadSize = swarm.ChunkSize - NonceSize - LengthSize - TopicSize
// NonceSize is a hash bit sequence
NonceSize = 32
// LengthSize is the byte length to represent message
LengthSize = 2
// TopicSize is a hash bit sequence
TopicSize = 32
)
// NewTopic creates a new Topic variable with the given input string
// the input string is taken as a byte slice and hashed
func NewTopic(topic string) Topic {
var tpc Topic
hasher := swarm.NewHasher()
_, err := hasher.Write([]byte(topic))
if err != nil {
return tpc
}
sum := hasher.Sum(nil)
copy(tpc[:], sum)
return tpc
}
// NewMessage creates a new Message variable with the given topic and payload
// it finds a length and nonce for the message according to the given input and maximum payload size
func NewMessage(topic Topic, payload []byte) (Message, error) {
if len(payload) > MaxPayloadSize {
return Message{}, ErrPayloadTooBig
}
// get length as array of 2 bytes
payloadSize := uint16(len(payload))
// set random bytes as padding
paddingLen := MaxPayloadSize - payloadSize
padding := make([]byte, paddingLen)
if _, err := rand.Read(padding); err != nil {
return Message{}, err
}
// create new Message var and set fields
m := new(Message)
binary.BigEndian.PutUint16(m.length[:], payloadSize)
m.Topic = topic
m.Payload = payload
m.padding = padding
return *m, nil
}
// Wrap creates a new trojan chunk for the given targets and Message
// a trojan chunk is a content-addressed chunk made up of span, a nonce, and a payload which contains the Message
// the chunk address will have one of the targets as its prefix and thus will be forwarded to the neighbourhood of the recipient overlay address the target is derived from
// this is done by iteratively enumerating different nonces until the BMT hash of the serialization of the trojan chunk fields results in a chunk address that has one of the targets as its prefix
func (m *Message) Wrap(ctx context.Context, targets Targets) (swarm.Chunk, error) {
if err := checkTargets(targets); err != nil {
return nil, err
}
targetsLen := len(targets[0])
// serialize message
b, err := m.MarshalBinary() // TODO: this should be encrypted
if err != nil {
return nil, err
}
span := make([]byte, 8)
binary.LittleEndian.PutUint64(span, uint64(len(b)+NonceSize))
h := hasher(span, b)
f := func(nonce []byte) (swarm.Chunk, error) {
hash, err := h(nonce)
if err != nil {
return nil, err
}
if !contains(targets, hash[:targetsLen]) {
return nil, nil
}
chunk := swarm.NewChunk(swarm.NewAddress(hash), append(span, append(nonce, b...)...))
return chunk, nil
}
return mine(ctx, f)
}
// Unwrap creates a new trojan message from the given chunk payload
// this function assumes the chunk has been validated as a content-addressed chunk
// it will return the resulting message if the unwrapping is successful, and an error otherwise
func Unwrap(c swarm.Chunk) (*Message, error) {
d := c.Data()
// unmarshal chunk payload into message
m := new(Message)
// first 40 bytes are span + nonce
err := m.UnmarshalBinary(d[40:])
return m, err
}
// IsPotential returns true if the given chunk is a potential trojan
func IsPotential(c swarm.Chunk) bool {
data := c.Data()
// check for minimum chunk data length
trojanChunkMinDataLen := swarm.SpanSize + NonceSize + TopicSize + LengthSize
if len(data) < trojanChunkMinDataLen {
return false
}
// check for valid trojan message length in bytes #41 and #42
messageLen := int(binary.BigEndian.Uint16(data[40:42]))
return trojanChunkMinDataLen+messageLen <= len(data)
}
// checkTargets verifies that the list of given targets is non empty and with elements of matching size
func checkTargets(targets Targets) error {
if len(targets) == 0 {
return ErrEmptyTargets
}
validLen := len(targets[0]) // take first element as allowed length
for i := 1; i < len(targets); i++ {
if len(targets[i]) != validLen {
return ErrVarLenTargets
}
}
return nil
}
func hasher(span, b []byte) func([]byte) ([]byte, error) {
hashPool := bmtlegacy.NewTreePool(swarm.NewHasher, swarm.Branches, bmtlegacy.PoolSize)
return func(nonce []byte) ([]byte, error) {
s := append(nonce, b...) // serialize chunk fields
hasher := bmtlegacy.New(hashPool)
if err := hasher.SetSpanBytes(span); err != nil {
return nil, err
}
if _, err := hasher.Write(s); err != nil {
return nil, err
}
return hasher.Sum(nil), nil
}
}
// contains returns whether the given collection contains the given element
func contains(col Targets, elem []byte) bool {
for i := range col {
if bytes.Equal(elem, col[i]) {
return true
}
}
return false
}
// MarshalBinary serializes a message struct
func (m *Message) MarshalBinary() (data []byte, err error) {
data = append(m.length[:], m.Topic[:]...)
data = append(data, m.Payload...)
data = append(data, m.padding...)
return data, nil
}
// UnmarshalBinary deserializes a message struct
func (m *Message) UnmarshalBinary(data []byte) (err error) {
if len(data) < LengthSize+TopicSize {
return ErrUnmarshal
}
copy(m.length[:], data[:LengthSize]) // first 2 bytes are length
copy(m.Topic[:], data[LengthSize:LengthSize+TopicSize]) // following 32 bytes are topic
length := binary.BigEndian.Uint16(m.length[:])
if (len(data) - LengthSize - TopicSize) < int(length) {
return ErrUnmarshal
}
// rest of the bytes are payload and padding
payloadEnd := LengthSize + TopicSize + length
m.Payload = data[LengthSize+TopicSize : payloadEnd]
m.padding = data[payloadEnd:]
return nil
}
func mine(ctx context.Context, f func(nonce []byte) (swarm.Chunk, error)) (swarm.Chunk, error) {
seeds := make([]uint32, 8)
for i := range seeds {
seeds[i] = random.Uint32()
}
initnonce := make([]byte, 32)
for i := 0; i < 8; i++ {
binary.LittleEndian.PutUint32(initnonce[i*4:i*4+4], seeds[i])
}
quit := make(chan struct{})
// make both errs and result channels buffered so they never block
result := make(chan swarm.Chunk, 8)
errs := make(chan error, 8)
for i := 0; i < 8; i++ {
go func(j int) {
nonce := make([]byte, 32)
copy(nonce, initnonce)
for seed := seeds[j]; ; seed++ {
binary.LittleEndian.PutUint32(nonce[j*4:j*4+4], seed)
res, err := f(nonce)
if err != nil {
errs <- err
return
}
if res != nil {
result <- res
return
}
select {
case <-quit:
return
default:
}
}
}(i)
}
defer close(quit)
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errs:
return nil, err
case res := <-result:
return res, nil
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trojan_test
import (
"context"
"encoding/binary"
"errors"
"reflect"
"testing"
"time"
chunktesting "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
)
// arbitrary targets for tests
var t1 = trojan.Target([]byte{57})
var t2 = trojan.Target([]byte{209})
var t3 = trojan.Target([]byte{156})
var t4 = trojan.Target([]byte{89})
var t5 = trojan.Target([]byte{22})
var testTargets = trojan.Targets([]trojan.Target{t1, t2, t3, t4, t5})
// arbitrary topic for tests
var testTopic = trojan.NewTopic("foo")
// newTestMessage creates an arbitrary Message for tests
func newTestMessage(t *testing.T) trojan.Message {
payload := []byte("foopayload")
m, err := trojan.NewMessage(testTopic, payload)
if err != nil {
t.Fatal(err)
}
return m
}
// TestNewMessage tests the correct and incorrect creation of a Message struct
func TestNewMessage(t *testing.T) {
smallPayload := make([]byte, 32)
m, err := trojan.NewMessage(testTopic, smallPayload)
if err != nil {
t.Fatal(err)
}
// verify topic
if m.Topic != testTopic {
t.Fatalf("expected message topic to be %v but is %v instead", testTopic, m.Topic)
}
maxPayload := make([]byte, trojan.MaxPayloadSize)
if _, err := trojan.NewMessage(testTopic, maxPayload); err != nil {
t.Fatal(err)
}
// the creation should fail if the payload is too big
invalidPayload := make([]byte, trojan.MaxPayloadSize+1)
if _, err := trojan.NewMessage(testTopic, invalidPayload); err != trojan.ErrPayloadTooBig {
t.Fatalf("expected error when creating message of invalid payload size to be %q, but got %v", trojan.ErrPayloadTooBig, err)
}
}
// TestWrap tests the creation of a chunk from a list of targets
// its address length and span should be correct
// its resulting address should have a prefix which matches one of the given targets
// its resulting data should have a hash that matches its address exactly
func TestWrap(t *testing.T) {
m := newTestMessage(t)
c, err := m.Wrap(context.Background(), testTargets)
if err != nil {
t.Fatal(err)
}
addr := c.Address()
addrLen := len(addr.Bytes())
if addrLen != swarm.HashSize {
t.Fatalf("chunk has an unexpected address length of %d rather than %d", addrLen, swarm.HashSize)
}
addrPrefix := addr.Bytes()[:len(testTargets[0])]
if !trojan.Contains(testTargets, addrPrefix) {
t.Fatal("chunk address prefix does not match any of the targets")
}
data := c.Data()
dataSize := len(data)
expectedSize := swarm.ChunkWithSpanSize // span + payload
if dataSize != expectedSize {
t.Fatalf("chunk data has an unexpected size of %d rather than %d", dataSize, expectedSize)
}
span := binary.LittleEndian.Uint64(data[:8])
remainingDataLen := len(data[8:])
if int(span) != remainingDataLen {
t.Fatalf("chunk span set to %d, but rest of chunk data is of size %d", span, remainingDataLen)
}
}
// TestWrapError tests that the creation of a chunk fails when given targets are invalid
func TestWrapError(t *testing.T) {
m := newTestMessage(t)
ctx := context.Background()
emptyTargets := trojan.Targets([]trojan.Target{})
if _, err := m.Wrap(ctx, emptyTargets); err != trojan.ErrEmptyTargets {
t.Fatalf("expected error when creating chunk for empty targets to be %q, but got %v", trojan.ErrEmptyTargets, err)
}
t1 := trojan.Target([]byte{34})
t2 := trojan.Target([]byte{25, 120})
t3 := trojan.Target([]byte{180, 18, 255})
varLenTargets := trojan.Targets([]trojan.Target{t1, t2, t3})
if _, err := m.Wrap(ctx, varLenTargets); err != trojan.ErrVarLenTargets {
t.Fatalf("expected error when creating chunk for variable-length targets to be %q, but got %v", trojan.ErrVarLenTargets, err)
}
}
// TestWrapTimeout tests for mining timeout and avoid forever loop
func TestWrapTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
m := newTestMessage(t)
// a large target will take more than MinerTimeout seconds, so timeout error will be triggered
buf := make([]byte, 16)
target := trojan.Target(buf)
targets := trojan.Targets([]trojan.Target{target})
if _, err := m.Wrap(ctx, targets); err == nil || !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected context timeout, got %v", err)
}
}
// TestUnwrap tests the correct unwrapping of a trojan chunk to obtain a message
func TestUnwrap(t *testing.T) {
m := newTestMessage(t)
c, err := m.Wrap(context.Background(), testTargets)
if err != nil {
t.Fatal(err)
}
um, err := trojan.Unwrap(c)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(m, *um) {
t.Fatalf("original message does not match unwrapped one")
}
}
// TestIsPotential tests if chunks are correctly interpreted as potentially trojan
func TestIsPotential(t *testing.T) {
c := chunktesting.GenerateTestRandomChunk()
// valid type, but invalid trojan message length
length := len(c.Data()) - 73 // go 1 byte over the maximum allowed
lengthBuf := make([]byte, 2)
binary.BigEndian.PutUint16(lengthBuf, uint16(length))
// put invalid length into bytes #41 and #42
copy(c.Data()[40:42], lengthBuf)
if trojan.IsPotential(c) {
t.Fatal("chunk with invalid trojan message length marked as potential trojan")
}
// valid type, but invalid chunk data length
data := make([]byte, 10)
c = swarm.NewChunk(swarm.ZeroAddress, data)
if trojan.IsPotential(c) {
t.Fatal("chunk with invalid data length marked as potential trojan")
}
// valid potential trojan
m := newTestMessage(t)
c, err := m.Wrap(context.Background(), testTargets)
if err != nil {
t.Fatal(err)
}
if !trojan.IsPotential(c) {
t.Fatal("valid test trojan chunk not marked as potential trojan")
}
}
// TestMessageSerialization tests that the Message type can be correctly serialized and deserialized
func TestMessageSerialization(t *testing.T) {
m := newTestMessage(t)
sm, err := m.MarshalBinary()
if err != nil {
t.Fatal(err)
}
dsm := new(trojan.Message)
err = dsm.UnmarshalBinary(sm)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(m, *dsm) {
t.Fatalf("original message does not match deserialized one")
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment