Commit 2db9d81f authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

Signing in handshake (#196)

* signing in handshake & Signer interface
parent 70ff4fa9
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package crypto
import (
"crypto/ecdsa"
"github.com/btcsuite/btcd/btcec"
)
type Signer interface {
Sign(data []byte) ([]byte, error)
PublicKey() (*ecdsa.PublicKey, error)
}
// Recover verifies signature with the data base provided.
// It is using `btcec.RecoverCompact` function
func Recover(signature, data []byte) (*ecdsa.PublicKey, error) {
p, _, err := btcec.RecoverCompact(btcec.S256(), signature, data)
return (*ecdsa.PublicKey)(p), err
}
type defaultSigner struct {
key *ecdsa.PrivateKey
}
func NewDefaultSigner(key *ecdsa.PrivateKey) Signer {
return &defaultSigner{
key: key,
}
}
func (d *defaultSigner) PublicKey() (*ecdsa.PublicKey, error) {
return &d.key.PublicKey, nil
}
func (d *defaultSigner) Sign(data []byte) (signature []byte, err error) {
return btcec.SignCompact(btcec.S256(), (*btcec.PrivateKey)(d.key), data, true)
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package crypto_test
import (
"testing"
"github.com/ethersphere/bee/pkg/crypto"
)
func TestDefaultSigner(t *testing.T) {
testBytes := []byte("test string")
privKey, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
signer := crypto.NewDefaultSigner(privKey)
signature, err := signer.Sign(testBytes)
if err != nil {
t.Fatal(err)
}
t.Run("OK - sign & recover", func(t *testing.T) {
pubKey, err := crypto.Recover(signature, testBytes)
if err != nil {
t.Fatal(err)
}
if pubKey.X.Cmp(privKey.PublicKey.X) != 0 || pubKey.Y.Cmp(privKey.PublicKey.Y) != 0 {
t.Fatalf("wanted %v but got %v", pubKey, &privKey.PublicKey)
}
})
t.Run("OK - recover with invalid data", func(t *testing.T) {
pubKey, err := crypto.Recover(signature, []byte("invalid"))
if err != nil {
t.Fatal(err)
}
if pubKey.X.Cmp(privKey.PublicKey.X) == 0 && pubKey.Y.Cmp(privKey.PublicKey.Y) == 0 {
t.Fatal("expected different public key")
}
})
}
...@@ -134,13 +134,10 @@ func NewBee(o Options) (*Bee, error) { ...@@ -134,13 +134,10 @@ func NewBee(o Options) (*Bee, error) {
b.stateStoreCloser = stateStore b.stateStoreCloser = stateStore
addressbook := addressbook.New(stateStore) addressbook := addressbook.New(stateStore)
p2ps, err := libp2p.New(p2pCtx, libp2p.Options{ p2ps, err := libp2p.New(p2pCtx, crypto.NewDefaultSigner(swarmPrivateKey), o.NetworkID, address, o.Addr, libp2p.Options{
PrivateKey: libp2pPrivateKey, PrivateKey: libp2pPrivateKey,
Overlay: address,
Addr: o.Addr,
DisableWS: o.DisableWS, DisableWS: o.DisableWS,
DisableQUIC: o.DisableQUIC, DisableQUIC: o.DisableQUIC,
NetworkID: o.NetworkID,
Addressbook: addressbook, Addressbook: addressbook,
Logger: logger, Logger: logger,
Tracer: tracer, Tracer: tracer,
......
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
) )
func TestAddresses(t *testing.T) { func TestAddresses(t *testing.T) {
s, _, cleanup := newService(t, libp2p.Options{NetworkID: 1}) s, _, cleanup := newService(t, 1, libp2p.Options{})
defer cleanup() defer cleanup()
addrs, err := s.Addresses() addrs, err := s.Addresses()
...@@ -32,10 +32,10 @@ func TestConnectDisconnect(t *testing.T) { ...@@ -32,10 +32,10 @@ func TestConnectDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -60,10 +60,10 @@ func TestDoubleConnect(t *testing.T) { ...@@ -60,10 +60,10 @@ func TestDoubleConnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -87,10 +87,10 @@ func TestDoubleDisconnect(t *testing.T) { ...@@ -87,10 +87,10 @@ func TestDoubleDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -122,10 +122,10 @@ func TestMultipleConnectDisconnect(t *testing.T) { ...@@ -122,10 +122,10 @@ func TestMultipleConnectDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -165,10 +165,10 @@ func TestConnectDisconnectOnAllAddresses(t *testing.T) { ...@@ -165,10 +165,10 @@ func TestConnectDisconnectOnAllAddresses(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addrs, err := s1.Addresses() addrs, err := s1.Addresses()
...@@ -197,10 +197,10 @@ func TestDoubleConnectOnAllAddresses(t *testing.T) { ...@@ -197,10 +197,10 @@ func TestDoubleConnectOnAllAddresses(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addrs, err := s1.Addresses() addrs, err := s1.Addresses()
...@@ -235,10 +235,10 @@ func TestDifferentNetworkIDs(t *testing.T) { ...@@ -235,10 +235,10 @@ func TestDifferentNetworkIDs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, _, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, _, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{NetworkID: 2}) s2, _, cleanup2 := newService(t, 2, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -255,15 +255,13 @@ func TestConnectWithDisabledQUICAndWSTransports(t *testing.T) { ...@@ -255,15 +255,13 @@ func TestConnectWithDisabledQUICAndWSTransports(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{ s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{
NetworkID: 1,
DisableQUIC: true, DisableQUIC: true,
DisableWS: true, DisableWS: true,
}) })
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{ s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{
NetworkID: 1,
DisableQUIC: true, DisableQUIC: true,
DisableWS: true, DisableWS: true,
}) })
...@@ -284,10 +282,10 @@ func TestConnectRepeatHandshake(t *testing.T) { ...@@ -284,10 +282,10 @@ func TestConnectRepeatHandshake(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
......
...@@ -23,10 +23,10 @@ func TestHeaders(t *testing.T) { ...@@ -23,10 +23,10 @@ func TestHeaders(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
var gotHeaders p2p.Headers var gotHeaders p2p.Headers
...@@ -72,10 +72,10 @@ func TestHeaders_empty(t *testing.T) { ...@@ -72,10 +72,10 @@ func TestHeaders_empty(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
var gotHeaders p2p.Headers var gotHeaders p2p.Headers
...@@ -130,10 +130,10 @@ func TestHeadler(t *testing.T) { ...@@ -130,10 +130,10 @@ func TestHeadler(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
var gotReceivedHeaders p2p.Headers var gotReceivedHeaders p2p.Headers
......
...@@ -5,11 +5,14 @@ ...@@ -5,11 +5,14 @@
package handshake package handshake
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"strconv"
"sync" "sync"
"time" "time"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb"
...@@ -27,13 +30,19 @@ const ( ...@@ -27,13 +30,19 @@ const (
messageTimeout = 5 * time.Second // maximum allowed time for a message to be read or written. messageTimeout = 5 * time.Second // maximum allowed time for a message to be read or written.
) )
// ErrNetworkIDIncompatible should be returned by handshake handlers if var (
// response from the other peer does not have valid networkID. // ErrNetworkIDIncompatible is returned if response from the other peer does not have valid networkID.
var ErrNetworkIDIncompatible = errors.New("incompatible network ID") ErrNetworkIDIncompatible = errors.New("incompatible network ID")
// ErrHandshakeDuplicate should be returned by handshake handlers if // ErrHandshakeDuplicate is returned if the handshake response has been received by an already processed peer.
// the handshake response has been received by an already processed peer. ErrHandshakeDuplicate = errors.New("duplicate handshake")
var ErrHandshakeDuplicate = errors.New("duplicate handshake")
// ErrInvalidSignature is returned if peer info was received with invalid signature
ErrInvalidSignature = errors.New("invalid signature")
// ErrInvalidAck is returned if ack does not match the syn provided
ErrInvalidAck = errors.New("invalid ack")
)
// PeerFinder has the information if the peer already exists in swarm. // PeerFinder has the information if the peer already exists in swarm.
type PeerFinder interface { type PeerFinder interface {
...@@ -42,6 +51,9 @@ type PeerFinder interface { ...@@ -42,6 +51,9 @@ type PeerFinder interface {
type Service struct { type Service struct {
overlay swarm.Address overlay swarm.Address
underlay []byte
signature []byte
signer crypto.Signer
networkID uint64 networkID uint64
receivedHandshakes map[libp2ppeer.ID]struct{} receivedHandshakes map[libp2ppeer.ID]struct{}
receivedHandshakesMu sync.Mutex receivedHandshakesMu sync.Mutex
...@@ -50,20 +62,32 @@ type Service struct { ...@@ -50,20 +62,32 @@ type Service struct {
network.Notifiee // handhsake service can be the receiver for network.Notify network.Notifiee // handhsake service can be the receiver for network.Notify
} }
func New(overlay swarm.Address, networkID uint64, logger logging.Logger) *Service { func New(overlay swarm.Address, underlay string, signer crypto.Signer, networkID uint64, logger logging.Logger) (*Service, error) {
signature, err := signer.Sign([]byte(underlay + strconv.FormatUint(networkID, 10)))
if err != nil {
return nil, err
}
return &Service{ return &Service{
overlay: overlay, overlay: overlay,
underlay: []byte(underlay),
signature: signature,
signer: signer,
networkID: networkID, networkID: networkID,
receivedHandshakes: make(map[libp2ppeer.ID]struct{}), receivedHandshakes: make(map[libp2ppeer.ID]struct{}),
logger: logger, logger: logger,
Notifiee: new(network.NoopNotifiee), Notifiee: new(network.NoopNotifiee),
} }, nil
} }
func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{
Address: s.overlay.Bytes(), BzzAddress: &pb.BzzAddress{
Underlay: s.underlay,
Signature: s.signature,
Overlay: s.overlay.Bytes(),
},
NetworkID: s.networkID, NetworkID: s.networkID,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write syn message: %w", err) return nil, fmt.Errorf("write syn message: %w", err)
...@@ -74,21 +98,25 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { ...@@ -74,21 +98,25 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("read synack message: %w", err) return nil, fmt.Errorf("read synack message: %w", err)
} }
address := swarm.NewAddress(resp.Syn.Address) if err := s.checkAck(resp.Ack); err != nil {
if resp.Syn.NetworkID != s.networkID { return nil, err
return nil, ErrNetworkIDIncompatible }
if err := s.checkSyn(resp.Syn); err != nil {
return nil, err
} }
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{
Address: resp.Syn.Address, BzzAddress: resp.Syn.BzzAddress,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write ack message: %w", err) return nil, fmt.Errorf("write ack message: %w", err)
} }
s.logger.Tracef("handshake finished for peer %s", address) s.logger.Tracef("handshake finished for peer %s", swarm.NewAddress(resp.Syn.BzzAddress.Overlay).String())
return &Info{ return &Info{
Address: address, Overlay: swarm.NewAddress(resp.Syn.BzzAddress.Overlay),
Underlay: resp.Syn.BzzAddress.Underlay,
NetworkID: resp.Syn.NetworkID, NetworkID: resp.Syn.NetworkID,
Light: resp.Syn.Light, Light: resp.Syn.Light,
}, nil }, nil
...@@ -110,17 +138,20 @@ func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err ...@@ -110,17 +138,20 @@ func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err
return nil, fmt.Errorf("read syn message: %w", err) return nil, fmt.Errorf("read syn message: %w", err)
} }
address := swarm.NewAddress(req.Address) if err := s.checkSyn(&req); err != nil {
if req.NetworkID != s.networkID { return nil, err
return nil, ErrNetworkIDIncompatible
} }
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{
Syn: &pb.Syn{ Syn: &pb.Syn{
Address: s.overlay.Bytes(), BzzAddress: &pb.BzzAddress{
Underlay: s.underlay,
Signature: s.signature,
Overlay: s.overlay.Bytes(),
},
NetworkID: s.networkID, NetworkID: s.networkID,
}, },
Ack: &pb.Ack{Address: req.Address}, Ack: &pb.Ack{BzzAddress: req.BzzAddress},
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write synack message: %w", err) return nil, fmt.Errorf("write synack message: %w", err)
} }
...@@ -130,9 +161,14 @@ func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err ...@@ -130,9 +161,14 @@ func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err
return nil, fmt.Errorf("read ack message: %w", err) return nil, fmt.Errorf("read ack message: %w", err)
} }
s.logger.Tracef("handshake finished for peer %s", address) if err := s.checkAck(&ack); err != nil {
return nil, err
}
s.logger.Tracef("handshake finished for peer %s", req.BzzAddress.Overlay)
return &Info{ return &Info{
Address: address, Overlay: swarm.NewAddress(req.BzzAddress.Overlay),
Underlay: req.BzzAddress.Underlay,
NetworkID: req.NetworkID, NetworkID: req.NetworkID,
Light: req.Light, Light: req.Light,
}, nil }, nil
...@@ -144,8 +180,37 @@ func (s *Service) Disconnected(_ network.Network, c network.Conn) { ...@@ -144,8 +180,37 @@ func (s *Service) Disconnected(_ network.Network, c network.Conn) {
delete(s.receivedHandshakes, c.RemotePeer()) delete(s.receivedHandshakes, c.RemotePeer())
} }
func (s *Service) checkSyn(syn *pb.Syn) error {
if syn.NetworkID != s.networkID {
return ErrNetworkIDIncompatible
}
recoveredPK, err := crypto.Recover(syn.BzzAddress.Signature, append(syn.BzzAddress.Underlay, strconv.FormatUint(syn.NetworkID, 10)...))
if err != nil {
return ErrInvalidSignature
}
recoveredOverlay := crypto.NewOverlayAddress(*recoveredPK, syn.NetworkID)
if !bytes.Equal(recoveredOverlay.Bytes(), syn.BzzAddress.Overlay) {
return ErrInvalidSignature
}
return nil
}
func (s *Service) checkAck(ack *pb.Ack) error {
if !bytes.Equal(ack.BzzAddress.Overlay, s.overlay.Bytes()) ||
!bytes.Equal(ack.BzzAddress.Underlay, s.underlay) ||
!bytes.Equal(ack.BzzAddress.Signature, s.signature) {
return ErrInvalidAck
}
return nil
}
type Info struct { type Info struct {
Address swarm.Address Overlay swarm.Address
Underlay []byte
NetworkID uint64 NetworkID uint64
Light bool Light bool
} }
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package handshake package handshake_test
import ( import (
"bytes" "bytes"
...@@ -11,7 +11,9 @@ import ( ...@@ -11,7 +11,9 @@ import (
"io/ioutil" "io/ioutil"
"testing" "testing"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/mock" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/mock"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
...@@ -22,24 +24,67 @@ import ( ...@@ -22,24 +24,67 @@ import (
) )
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
node1Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c") node1Underlay := []byte("underlay1")
node2Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59b") node2Underlay := []byte("underlay2")
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
info := Info{
Address: node1Addr, privateKey1, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
privateKey2, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
node1Overlay := crypto.NewOverlayAddress(privateKey1.PublicKey, 0)
node2Overlay := crypto.NewOverlayAddress(privateKey2.PublicKey, 0)
signer1 := crypto.NewDefaultSigner(privateKey1)
signer2 := crypto.NewDefaultSigner(privateKey2)
signature1, err := signer1.Sign([]byte("underlay10"))
if err != nil {
t.Fatal(err)
}
signature2, err := signer2.Sign([]byte("underlay20"))
if err != nil {
t.Fatal(err)
}
node1Info := handshake.Info{
Overlay: node1Overlay,
Underlay: node1Underlay,
NetworkID: 0, NetworkID: 0,
Light: false, Light: false,
} }
handshakeService := New(info.Address, info.NetworkID, logger) node1BzzAddress := &pb.BzzAddress{
Overlay: node1Info.Overlay.Bytes(),
Underlay: node1Info.Underlay,
Signature: signature1,
}
t.Run("OK", func(t *testing.T) { node2Info := handshake.Info{
expectedInfo := Info{ Overlay: node2Overlay,
Address: node2Addr, Underlay: node2Underlay,
NetworkID: 0, NetworkID: 0,
Light: false, Light: false,
} }
node2BzzAddress := &pb.BzzAddress{
Overlay: node2Info.Overlay.Bytes(),
Underlay: node2Info.Underlay,
Signature: signature2,
}
handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
if err != nil {
t.Fatal(err)
}
t.Run("OK", func(t *testing.T) {
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2) stream1 := mock.NewStream(&buffer1, &buffer2)
...@@ -48,11 +93,11 @@ func TestHandshake(t *testing.T) { ...@@ -48,11 +93,11 @@ func TestHandshake(t *testing.T) {
w, r := protobuf.NewWriterAndReader(stream2) w, r := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.SynAck{ if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{ Syn: &pb.Syn{
Address: expectedInfo.Address.Bytes(), BzzAddress: node2BzzAddress,
NetworkID: expectedInfo.NetworkID, NetworkID: node2Info.NetworkID,
Light: expectedInfo.Light, Light: node2Info.Light,
}, },
Ack: &pb.Ack{Address: info.Address.Bytes()}, Ack: &pb.Ack{BzzAddress: node1BzzAddress},
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -62,7 +107,7 @@ func TestHandshake(t *testing.T) { ...@@ -62,7 +107,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
testInfo(t, *res, expectedInfo) testInfo(t, *res, node2Info)
if err := r.ReadMsg(&pb.Ack{}); err != nil { if err := r.ReadMsg(&pb.Ack{}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -101,26 +146,20 @@ func TestHandshake(t *testing.T) { ...@@ -101,26 +146,20 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - ack write error", func(t *testing.T) { t.Run("ERROR - ack write error", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write ack message: %w", testErr) expectedErr := fmt.Errorf("write ack message: %w", testErr)
expectedInfo := Info{
Address: node2Addr,
NetworkID: 0,
Light: false,
}
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2) stream1 := mock.NewStream(&buffer1, &buffer2)
stream1.SetWriteErr(testErr, 1) stream1.SetWriteErr(testErr, 1)
stream2 := mock.NewStream(&buffer2, &buffer1) stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2) w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.SynAck{ if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{ Syn: &pb.Syn{
Address: expectedInfo.Address.Bytes(), BzzAddress: node2BzzAddress,
NetworkID: expectedInfo.NetworkID, NetworkID: node2Info.NetworkID,
Light: expectedInfo.Light, Light: node2Info.Light,
}, },
Ack: &pb.Ack{Address: info.Address.Bytes()}, Ack: &pb.Ack{BzzAddress: node1BzzAddress},
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -136,25 +175,79 @@ func TestHandshake(t *testing.T) { ...@@ -136,25 +175,79 @@ func TestHandshake(t *testing.T) {
}) })
t.Run("ERROR - networkID mismatch", func(t *testing.T) { t.Run("ERROR - networkID mismatch", func(t *testing.T) {
node2Info := Info{ var buffer1 bytes.Buffer
Address: node2Addr, var buffer2 bytes.Buffer
NetworkID: 2, stream1 := mock.NewStream(&buffer1, &buffer2)
Light: false, stream2 := mock.NewStream(&buffer2, &buffer1)
w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{
BzzAddress: node2BzzAddress,
NetworkID: 5,
Light: node2Info.Light,
},
Ack: &pb.Ack{BzzAddress: node1BzzAddress},
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handshake(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != handshake.ErrNetworkIDIncompatible {
t.Fatalf("expected %s, got %s", handshake.ErrNetworkIDIncompatible, err)
}
})
t.Run("ERROR - invalid ack", func(t *testing.T) {
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{
BzzAddress: node2BzzAddress,
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
},
Ack: &pb.Ack{BzzAddress: node2BzzAddress},
}); err != nil {
t.Fatal(err)
} }
res, err := handshakeService.Handshake(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != handshake.ErrInvalidAck {
t.Fatalf("expected %s, got %s", handshake.ErrInvalidAck, err)
}
})
t.Run("ERROR - invalid signature", func(t *testing.T) {
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2) stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1) stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2) w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.SynAck{ if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{ Syn: &pb.Syn{
Address: node2Info.Address.Bytes(), BzzAddress: &pb.BzzAddress{
Underlay: node2BzzAddress.Underlay,
Signature: []byte("wrong signature"),
Overlay: node2BzzAddress.Overlay,
},
NetworkID: node2Info.NetworkID, NetworkID: node2Info.NetworkID,
Light: node2Info.Light, Light: node2Info.Light,
}, },
Ack: &pb.Ack{Address: info.Address.Bytes()}, Ack: &pb.Ack{BzzAddress: node1BzzAddress},
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -164,37 +257,80 @@ func TestHandshake(t *testing.T) { ...@@ -164,37 +257,80 @@ func TestHandshake(t *testing.T) {
t.Fatal("res should be nil") t.Fatal("res should be nil")
} }
if err != ErrNetworkIDIncompatible { if err != handshake.ErrInvalidSignature {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err) t.Fatalf("expected %s, got %s", handshake.ErrInvalidSignature, err)
} }
}) })
} }
func TestHandle(t *testing.T) { func TestHandle(t *testing.T) {
node1Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c") privateKey1, err := crypto.GenerateSecp256k1Key()
node2Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59b")
multiaddress, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/7070/p2p/16Uiu2HAkx8ULY8cTXhdVAcMmLcH9AsTKz6uBQ7DPLKRjMLgBVYkS")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
info, err := peer.AddrInfoFromP2pAddr(multiaddress) privateKey2, err := crypto.GenerateSecp256k1Key()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
nodeInfo := Info{
Address: node1Addr, node1Overlay := crypto.NewOverlayAddress(privateKey1.PublicKey, 0)
node2Overlay := crypto.NewOverlayAddress(privateKey2.PublicKey, 0)
node2ma, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/7070/p2p/16Uiu2HAkx8ULY8cTXhdVAcMmLcH9AsTKz6uBQ7DPLKRjMLgBVYkS")
if err != nil {
t.Fatal(err)
}
node2AddrInfo, err := peer.AddrInfoFromP2pAddr(node2ma)
if err != nil {
t.Fatal(err)
}
node1Underlay := []byte("underlay1")
node2Underlay := []byte("16Uiu2HAkx8ULY8cTXhdVAcMmLcH9AsTKz6uBQ7DPLKRjMLgBVYkS")
node1Info := handshake.Info{
Overlay: node1Overlay,
Underlay: node1Underlay,
NetworkID: 0, NetworkID: 0,
Light: false, Light: false,
} }
signer1 := crypto.NewDefaultSigner(privateKey1)
signer2 := crypto.NewDefaultSigner(privateKey2)
signature1, err := signer1.Sign([]byte("underlay10"))
if err != nil {
t.Fatal(err)
}
signature2, err := signer2.Sign([]byte("16Uiu2HAkx8ULY8cTXhdVAcMmLcH9AsTKz6uBQ7DPLKRjMLgBVYkS0"))
if err != nil {
t.Fatal(err)
}
node2Info := handshake.Info{
Overlay: node2Overlay,
Underlay: node2Underlay,
NetworkID: 0,
Light: false,
}
node1BzzAddress := &pb.BzzAddress{
Overlay: node1Info.Overlay.Bytes(),
Underlay: node1Info.Underlay,
Signature: signature1,
}
node2BzzAddress := &pb.BzzAddress{
Overlay: node2Info.Overlay.Bytes(),
Underlay: node2Info.Underlay,
Signature: signature2,
}
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger) handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
node2Info := Info{ if err != nil {
Address: node2Addr, t.Fatal(err)
NetworkID: 0,
Light: false,
} }
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
...@@ -202,20 +338,20 @@ func TestHandle(t *testing.T) { ...@@ -202,20 +338,20 @@ func TestHandle(t *testing.T) {
stream1 := mock.NewStream(&buffer1, &buffer2) stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1) stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2) w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.Syn{ if err := w.WriteMsg(&pb.Syn{
Address: node2Info.Address.Bytes(), BzzAddress: node2BzzAddress,
NetworkID: node2Info.NetworkID, NetworkID: node2Info.NetworkID,
Light: node2Info.Light, Light: node2Info.Light,
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := w.WriteMsg(&pb.Ack{Address: node2Info.Address.Bytes()}); err != nil { if err := w.WriteMsg(&pb.Ack{BzzAddress: node1BzzAddress}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, info.ID) res, err := handshakeService.Handle(stream1, node2AddrInfo.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -228,20 +364,25 @@ func TestHandle(t *testing.T) { ...@@ -228,20 +364,25 @@ func TestHandle(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
testInfo(t, nodeInfo, Info{ testInfo(t, node1Info, handshake.Info{
Address: swarm.NewAddress(got.Syn.Address), Overlay: swarm.NewAddress(got.Syn.BzzAddress.Overlay),
Underlay: got.Syn.BzzAddress.Underlay,
NetworkID: got.Syn.NetworkID, NetworkID: got.Syn.NetworkID,
Light: got.Syn.Light, Light: got.Syn.Light,
}) })
}) })
t.Run("ERROR - read error ", func(t *testing.T) { t.Run("ERROR - read error ", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger) handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
if err != nil {
t.Fatal(err)
}
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("read syn message: %w", testErr) expectedErr := fmt.Errorf("read syn message: %w", testErr)
stream := &mock.Stream{} stream := &mock.Stream{}
stream.SetReadErr(testErr, 0) stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handle(stream, info.ID) res, err := handshakeService.Handle(stream, node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -252,22 +393,26 @@ func TestHandle(t *testing.T) { ...@@ -252,22 +393,26 @@ func TestHandle(t *testing.T) {
}) })
t.Run("ERROR - write error ", func(t *testing.T) { t.Run("ERROR - write error ", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger) handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
if err != nil {
t.Fatal(err)
}
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write synack message: %w", testErr) expectedErr := fmt.Errorf("write synack message: %w", testErr)
var buffer bytes.Buffer var buffer bytes.Buffer
stream := mock.NewStream(&buffer, &buffer) stream := mock.NewStream(&buffer, &buffer)
stream.SetWriteErr(testErr, 1) stream.SetWriteErr(testErr, 1)
w, _ := protobuf.NewWriterAndReader(stream) w := protobuf.NewWriter(stream)
if err := w.WriteMsg(&pb.Syn{ if err := w.WriteMsg(&pb.Syn{
Address: node1Addr.Bytes(), BzzAddress: node2BzzAddress,
NetworkID: 0, NetworkID: node2Info.NetworkID,
Light: false, Light: node2Info.Light,
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream, info.ID) res, err := handshakeService.Handle(stream, node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -278,30 +423,28 @@ func TestHandle(t *testing.T) { ...@@ -278,30 +423,28 @@ func TestHandle(t *testing.T) {
}) })
t.Run("ERROR - ack read error ", func(t *testing.T) { t.Run("ERROR - ack read error ", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger) handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
testErr := errors.New("test error") if err != nil {
expectedErr := fmt.Errorf("read ack message: %w", testErr) t.Fatal(err)
node2Info := Info{
Address: node2Addr,
NetworkID: 0,
Light: false,
} }
testErr := errors.New("test error")
expectedErr := fmt.Errorf("read ack message: %w", testErr)
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2) stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1) stream2 := mock.NewStream(&buffer2, &buffer1)
stream1.SetReadErr(testErr, 1) stream1.SetReadErr(testErr, 1)
w, _ := protobuf.NewWriterAndReader(stream2) w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.Syn{ if err := w.WriteMsg(&pb.Syn{
Address: node2Info.Address.Bytes(), BzzAddress: node2BzzAddress,
NetworkID: node2Info.NetworkID, NetworkID: node2Info.NetworkID,
Light: node2Info.Light, Light: node2Info.Light,
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, info.ID) res, err := handshakeService.Handle(stream1, node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -312,11 +455,9 @@ func TestHandle(t *testing.T) { ...@@ -312,11 +455,9 @@ func TestHandle(t *testing.T) {
}) })
t.Run("ERROR - networkID mismatch ", func(t *testing.T) { t.Run("ERROR - networkID mismatch ", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger) handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
node2Info := Info{ if err != nil {
Address: node2Addr, t.Fatal(err)
NetworkID: 2,
Light: false,
} }
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
...@@ -324,31 +465,29 @@ func TestHandle(t *testing.T) { ...@@ -324,31 +465,29 @@ func TestHandle(t *testing.T) {
stream1 := mock.NewStream(&buffer1, &buffer2) stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1) stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2) w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.Syn{ if err := w.WriteMsg(&pb.Syn{
Address: node2Info.Address.Bytes(), BzzAddress: node2BzzAddress,
NetworkID: node2Info.NetworkID, NetworkID: 5,
Light: node2Info.Light, Light: node2Info.Light,
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, info.ID) res, err := handshakeService.Handle(stream1, node2AddrInfo.ID)
if res != nil { if res != nil {
t.Fatal("res should be nil") t.Fatal("res should be nil")
} }
if err != ErrNetworkIDIncompatible { if err != handshake.ErrNetworkIDIncompatible {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err) t.Fatalf("expected %s, got %s", handshake.ErrNetworkIDIncompatible, err)
} }
}) })
t.Run("ERROR - duplicate handshake", func(t *testing.T) { t.Run("ERROR - duplicate handshake", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger) handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
node2Info := Info{ if err != nil {
Address: node2Addr, t.Fatal(err)
NetworkID: 0,
Light: false,
} }
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
...@@ -356,20 +495,20 @@ func TestHandle(t *testing.T) { ...@@ -356,20 +495,20 @@ func TestHandle(t *testing.T) {
stream1 := mock.NewStream(&buffer1, &buffer2) stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1) stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2) w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.Syn{ if err := w.WriteMsg(&pb.Syn{
Address: node2Info.Address.Bytes(), BzzAddress: node2BzzAddress,
NetworkID: node2Info.NetworkID, NetworkID: node2Info.NetworkID,
Light: node2Info.Light, Light: node2Info.Light,
}); err != nil { }); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if err := w.WriteMsg(&pb.Ack{Address: node2Info.Address.Bytes()}); err != nil { if err := w.WriteMsg(&pb.Ack{BzzAddress: node1BzzAddress}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, info.ID) res, err := handshakeService.Handle(stream1, node2AddrInfo.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -382,24 +521,89 @@ func TestHandle(t *testing.T) { ...@@ -382,24 +521,89 @@ func TestHandle(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
testInfo(t, nodeInfo, Info{ testInfo(t, node1Info, handshake.Info{
Address: swarm.NewAddress(got.Syn.Address), Overlay: swarm.NewAddress(got.Syn.BzzAddress.Overlay),
Underlay: got.Syn.BzzAddress.Underlay,
NetworkID: got.Syn.NetworkID, NetworkID: got.Syn.NetworkID,
Light: got.Syn.Light, Light: got.Syn.Light,
}) })
_, err = handshakeService.Handle(stream1, info.ID) _, err = handshakeService.Handle(stream1, node2AddrInfo.ID)
if err != ErrHandshakeDuplicate { if err != handshake.ErrHandshakeDuplicate {
t.Fatalf("expected %s err, got %s err", ErrHandshakeDuplicate, err) t.Fatalf("expected %s, got %s", handshake.ErrHandshakeDuplicate, err)
}
})
t.Run("Error - invalid ack", func(t *testing.T) {
handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
if err != nil {
t.Fatal(err)
}
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.Syn{
BzzAddress: node2BzzAddress,
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
}); err != nil {
t.Fatal(err)
}
if err := w.WriteMsg(&pb.Ack{BzzAddress: node2BzzAddress}); err != nil {
t.Fatal(err)
}
_, err = handshakeService.Handle(stream1, node2AddrInfo.ID)
if err != handshake.ErrInvalidAck {
t.Fatalf("expected %s, got %s", handshake.ErrInvalidAck, err)
}
})
t.Run("ERROR - invalid signature ", func(t *testing.T) {
handshakeService, err := handshake.New(node1Info.Overlay, string(node1Info.Underlay), signer1, 0, logger)
if err != nil {
t.Fatal(err)
}
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w := protobuf.NewWriter(stream2)
if err := w.WriteMsg(&pb.Syn{
BzzAddress: &pb.BzzAddress{
Underlay: node2BzzAddress.Underlay,
Signature: []byte("wrong signature"),
Overlay: node2BzzAddress.Overlay,
},
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handle(stream1, node2AddrInfo.ID)
if res != nil {
t.Fatal("res should be nil")
}
if err != handshake.ErrInvalidSignature {
t.Fatalf("expected %s, got %s", handshake.ErrInvalidSignature, err)
} }
}) })
} }
// testInfo validates if two Info instances are equal. // testInfo validates if two Info instances are equal.
func testInfo(t *testing.T, got, want Info) { func testInfo(t *testing.T, got, want handshake.Info) {
t.Helper() t.Helper()
if !got.Address.Equal(want.Address) || got.NetworkID != want.NetworkID || got.Light != want.Light { if !got.Overlay.Equal(want.Overlay) || !bytes.Equal(got.Underlay, want.Underlay) || got.NetworkID != want.NetworkID || got.Light != want.Light {
t.Fatalf("got info %+v, want %+v", got, want) t.Fatalf("got info %+v, want %+v", got, want)
} }
} }
...@@ -23,9 +23,9 @@ var _ = math.Inf ...@@ -23,9 +23,9 @@ var _ = math.Inf
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type Syn struct { type Syn struct {
Address []byte `protobuf:"bytes,1,opt,name=Address,proto3" json:"Address,omitempty"` BzzAddress *BzzAddress `protobuf:"bytes,1,opt,name=BzzAddress,proto3" json:"BzzAddress,omitempty"`
NetworkID uint64 `protobuf:"varint,2,opt,name=NetworkID,proto3" json:"NetworkID,omitempty"` NetworkID uint64 `protobuf:"varint,2,opt,name=NetworkID,proto3" json:"NetworkID,omitempty"`
Light bool `protobuf:"varint,3,opt,name=Light,proto3" json:"Light,omitempty"` Light bool `protobuf:"varint,3,opt,name=Light,proto3" json:"Light,omitempty"`
} }
func (m *Syn) Reset() { *m = Syn{} } func (m *Syn) Reset() { *m = Syn{} }
...@@ -61,9 +61,9 @@ func (m *Syn) XXX_DiscardUnknown() { ...@@ -61,9 +61,9 @@ func (m *Syn) XXX_DiscardUnknown() {
var xxx_messageInfo_Syn proto.InternalMessageInfo var xxx_messageInfo_Syn proto.InternalMessageInfo
func (m *Syn) GetAddress() []byte { func (m *Syn) GetBzzAddress() *BzzAddress {
if m != nil { if m != nil {
return m.Address return m.BzzAddress
} }
return nil return nil
} }
...@@ -82,6 +82,50 @@ func (m *Syn) GetLight() bool { ...@@ -82,6 +82,50 @@ func (m *Syn) GetLight() bool {
return false return false
} }
type Ack struct {
BzzAddress *BzzAddress `protobuf:"bytes,1,opt,name=BzzAddress,proto3" json:"BzzAddress,omitempty"`
}
func (m *Ack) Reset() { *m = Ack{} }
func (m *Ack) String() string { return proto.CompactTextString(m) }
func (*Ack) ProtoMessage() {}
func (*Ack) Descriptor() ([]byte, []int) {
return fileDescriptor_a77305914d5d202f, []int{1}
}
func (m *Ack) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Ack) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Ack.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Ack) XXX_Merge(src proto.Message) {
xxx_messageInfo_Ack.Merge(m, src)
}
func (m *Ack) XXX_Size() int {
return m.Size()
}
func (m *Ack) XXX_DiscardUnknown() {
xxx_messageInfo_Ack.DiscardUnknown(m)
}
var xxx_messageInfo_Ack proto.InternalMessageInfo
func (m *Ack) GetBzzAddress() *BzzAddress {
if m != nil {
return m.BzzAddress
}
return nil
}
type SynAck struct { type SynAck struct {
Syn *Syn `protobuf:"bytes,1,opt,name=Syn,proto3" json:"Syn,omitempty"` Syn *Syn `protobuf:"bytes,1,opt,name=Syn,proto3" json:"Syn,omitempty"`
Ack *Ack `protobuf:"bytes,2,opt,name=Ack,proto3" json:"Ack,omitempty"` Ack *Ack `protobuf:"bytes,2,opt,name=Ack,proto3" json:"Ack,omitempty"`
...@@ -91,7 +135,7 @@ func (m *SynAck) Reset() { *m = SynAck{} } ...@@ -91,7 +135,7 @@ func (m *SynAck) Reset() { *m = SynAck{} }
func (m *SynAck) String() string { return proto.CompactTextString(m) } func (m *SynAck) String() string { return proto.CompactTextString(m) }
func (*SynAck) ProtoMessage() {} func (*SynAck) ProtoMessage() {}
func (*SynAck) Descriptor() ([]byte, []int) { func (*SynAck) Descriptor() ([]byte, []int) {
return fileDescriptor_a77305914d5d202f, []int{1} return fileDescriptor_a77305914d5d202f, []int{2}
} }
func (m *SynAck) XXX_Unmarshal(b []byte) error { func (m *SynAck) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b) return m.Unmarshal(b)
...@@ -134,22 +178,24 @@ func (m *SynAck) GetAck() *Ack { ...@@ -134,22 +178,24 @@ func (m *SynAck) GetAck() *Ack {
return nil return nil
} }
type Ack struct { type BzzAddress struct {
Address []byte `protobuf:"bytes,1,opt,name=Address,proto3" json:"Address,omitempty"` Underlay []byte `protobuf:"bytes,1,opt,name=Underlay,proto3" json:"Underlay,omitempty"`
Signature []byte `protobuf:"bytes,2,opt,name=Signature,proto3" json:"Signature,omitempty"`
Overlay []byte `protobuf:"bytes,3,opt,name=Overlay,proto3" json:"Overlay,omitempty"`
} }
func (m *Ack) Reset() { *m = Ack{} } func (m *BzzAddress) Reset() { *m = BzzAddress{} }
func (m *Ack) String() string { return proto.CompactTextString(m) } func (m *BzzAddress) String() string { return proto.CompactTextString(m) }
func (*Ack) ProtoMessage() {} func (*BzzAddress) ProtoMessage() {}
func (*Ack) Descriptor() ([]byte, []int) { func (*BzzAddress) Descriptor() ([]byte, []int) {
return fileDescriptor_a77305914d5d202f, []int{2} return fileDescriptor_a77305914d5d202f, []int{3}
} }
func (m *Ack) XXX_Unmarshal(b []byte) error { func (m *BzzAddress) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b) return m.Unmarshal(b)
} }
func (m *Ack) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { func (m *BzzAddress) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic { if deterministic {
return xxx_messageInfo_Ack.Marshal(b, m, deterministic) return xxx_messageInfo_BzzAddress.Marshal(b, m, deterministic)
} else { } else {
b = b[:cap(b)] b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b) n, err := m.MarshalToSizedBuffer(b)
...@@ -159,48 +205,67 @@ func (m *Ack) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { ...@@ -159,48 +205,67 @@ func (m *Ack) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return b[:n], nil return b[:n], nil
} }
} }
func (m *Ack) XXX_Merge(src proto.Message) { func (m *BzzAddress) XXX_Merge(src proto.Message) {
xxx_messageInfo_Ack.Merge(m, src) xxx_messageInfo_BzzAddress.Merge(m, src)
} }
func (m *Ack) XXX_Size() int { func (m *BzzAddress) XXX_Size() int {
return m.Size() return m.Size()
} }
func (m *Ack) XXX_DiscardUnknown() { func (m *BzzAddress) XXX_DiscardUnknown() {
xxx_messageInfo_Ack.DiscardUnknown(m) xxx_messageInfo_BzzAddress.DiscardUnknown(m)
} }
var xxx_messageInfo_Ack proto.InternalMessageInfo var xxx_messageInfo_BzzAddress proto.InternalMessageInfo
func (m *BzzAddress) GetUnderlay() []byte {
if m != nil {
return m.Underlay
}
return nil
}
func (m *BzzAddress) GetSignature() []byte {
if m != nil {
return m.Signature
}
return nil
}
func (m *Ack) GetAddress() []byte { func (m *BzzAddress) GetOverlay() []byte {
if m != nil { if m != nil {
return m.Address return m.Overlay
} }
return nil return nil
} }
func init() { func init() {
proto.RegisterType((*Syn)(nil), "handshake.Syn") proto.RegisterType((*Syn)(nil), "handshake.Syn")
proto.RegisterType((*SynAck)(nil), "handshake.SynAck")
proto.RegisterType((*Ack)(nil), "handshake.Ack") proto.RegisterType((*Ack)(nil), "handshake.Ack")
proto.RegisterType((*SynAck)(nil), "handshake.SynAck")
proto.RegisterType((*BzzAddress)(nil), "handshake.BzzAddress")
} }
func init() { proto.RegisterFile("handshake.proto", fileDescriptor_a77305914d5d202f) } func init() { proto.RegisterFile("handshake.proto", fileDescriptor_a77305914d5d202f) }
var fileDescriptor_a77305914d5d202f = []byte{ var fileDescriptor_a77305914d5d202f = []byte{
// 199 bytes of a gzipped FileDescriptorProto // 257 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcf, 0x48, 0xcc, 0x4b, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcf, 0x48, 0xcc, 0x4b,
0x29, 0xce, 0x48, 0xcc, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x84, 0x0b, 0x28, 0x29, 0xce, 0x48, 0xcc, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x84, 0x0b, 0x28,
0x05, 0x73, 0x31, 0x07, 0x57, 0xe6, 0x09, 0x49, 0x70, 0xb1, 0x3b, 0xa6, 0xa4, 0x14, 0xa5, 0x16, 0x15, 0x70, 0x31, 0x07, 0x57, 0xe6, 0x09, 0x99, 0x72, 0x71, 0x39, 0x55, 0x55, 0x39, 0xa6, 0xa4,
0x17, 0x4b, 0x30, 0x2a, 0x30, 0x6a, 0xf0, 0x04, 0xc1, 0xb8, 0x42, 0x32, 0x5c, 0x9c, 0x7e, 0xa9, 0x14, 0xa5, 0x16, 0x17, 0x4b, 0x30, 0x2a, 0x30, 0x6a, 0x70, 0x1b, 0x89, 0xea, 0x21, 0xf4, 0x21,
0x25, 0xe5, 0xf9, 0x45, 0xd9, 0x9e, 0x2e, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0x2c, 0x41, 0x08, 0x01, 0x24, 0x83, 0x90, 0x14, 0x0a, 0xc9, 0x70, 0x71, 0xfa, 0xa5, 0x96, 0x94, 0xe7, 0x17, 0x65, 0x7b,
0x21, 0x11, 0x2e, 0x56, 0x9f, 0xcc, 0xf4, 0x8c, 0x12, 0x09, 0x66, 0x05, 0x46, 0x0d, 0x8e, 0x20, 0xba, 0x48, 0x30, 0x29, 0x30, 0x6a, 0xb0, 0x04, 0x21, 0x04, 0x84, 0x44, 0xb8, 0x58, 0x7d, 0x32,
0x08, 0x47, 0xc9, 0x87, 0x8b, 0x2d, 0xb8, 0x32, 0xcf, 0x31, 0x39, 0x5b, 0x48, 0x01, 0x6c, 0x3c, 0xd3, 0x33, 0x4a, 0x24, 0x98, 0x15, 0x18, 0x35, 0x38, 0x82, 0x20, 0x1c, 0x25, 0x1b, 0x2e, 0x66,
0xd8, 0x4c, 0x6e, 0x23, 0x3e, 0x3d, 0x84, 0x43, 0x82, 0x2b, 0xf3, 0x82, 0xc0, 0x36, 0x2b, 0x70, 0xc7, 0xe4, 0x6c, 0x32, 0x6d, 0x54, 0xf2, 0xe1, 0x62, 0x0b, 0xae, 0xcc, 0x03, 0x19, 0xa0, 0x00,
0x31, 0x3b, 0x26, 0x67, 0x83, 0x4d, 0x46, 0x55, 0xe1, 0x98, 0x9c, 0x1d, 0x04, 0x92, 0x52, 0x92, 0x76, 0x39, 0x54, 0x27, 0x1f, 0x92, 0xce, 0xe0, 0xca, 0xbc, 0x20, 0xb0, 0xa7, 0x14, 0xc0, 0x36,
0x07, 0xab, 0xc0, 0xed, 0x44, 0x27, 0x99, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0x81, 0xdd, 0x85, 0xaa, 0xc2, 0x31, 0x39, 0x3b, 0x08, 0x24, 0xa5, 0x94, 0x80, 0xec, 0x08, 0x21,
0xf0, 0x48, 0x8e, 0x71, 0xc2, 0x63, 0x39, 0x86, 0x0b, 0x8f, 0xe5, 0x18, 0x6e, 0x3c, 0x96, 0x63, 0x29, 0x2e, 0x8e, 0xd0, 0xbc, 0x94, 0xd4, 0xa2, 0x9c, 0xc4, 0x4a, 0xb0, 0xb1, 0x3c, 0x41, 0x70,
0x88, 0x62, 0x2a, 0x48, 0x4a, 0x62, 0x03, 0xfb, 0xd9, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x04, 0x3e, 0xc8, 0xa7, 0xc1, 0x99, 0xe9, 0x79, 0x89, 0x25, 0xa5, 0x45, 0xa9, 0x60, 0x13, 0x79, 0x82,
0x2c, 0x04, 0x89, 0x06, 0x01, 0x00, 0x00, 0x10, 0x02, 0x42, 0x12, 0x5c, 0xec, 0xfe, 0x65, 0x10, 0x8d, 0xcc, 0x60, 0x39, 0x18, 0xd7, 0x49,
0xe6, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58,
0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0xa2, 0x98, 0x0a, 0x92, 0x92, 0xd8,
0xc0, 0xf1, 0x61, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x02, 0xb3, 0x32, 0xdf, 0xa2, 0x01, 0x00,
0x00,
} }
func (m *Syn) Marshal() (dAtA []byte, err error) { func (m *Syn) Marshal() (dAtA []byte, err error) {
...@@ -238,10 +303,50 @@ func (m *Syn) MarshalToSizedBuffer(dAtA []byte) (int, error) { ...@@ -238,10 +303,50 @@ func (m *Syn) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i-- i--
dAtA[i] = 0x10 dAtA[i] = 0x10
} }
if len(m.Address) > 0 { if m.BzzAddress != nil {
i -= len(m.Address) {
copy(dAtA[i:], m.Address) size, err := m.BzzAddress.MarshalToSizedBuffer(dAtA[:i])
i = encodeVarintHandshake(dAtA, i, uint64(len(m.Address))) if err != nil {
return 0, err
}
i -= size
i = encodeVarintHandshake(dAtA, i, uint64(size))
}
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func (m *Ack) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Ack) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.BzzAddress != nil {
{
size, err := m.BzzAddress.MarshalToSizedBuffer(dAtA[:i])
if err != nil {
return 0, err
}
i -= size
i = encodeVarintHandshake(dAtA, i, uint64(size))
}
i-- i--
dAtA[i] = 0xa dAtA[i] = 0xa
} }
...@@ -295,7 +400,7 @@ func (m *SynAck) MarshalToSizedBuffer(dAtA []byte) (int, error) { ...@@ -295,7 +400,7 @@ func (m *SynAck) MarshalToSizedBuffer(dAtA []byte) (int, error) {
return len(dAtA) - i, nil return len(dAtA) - i, nil
} }
func (m *Ack) Marshal() (dAtA []byte, err error) { func (m *BzzAddress) Marshal() (dAtA []byte, err error) {
size := m.Size() size := m.Size()
dAtA = make([]byte, size) dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size]) n, err := m.MarshalToSizedBuffer(dAtA[:size])
...@@ -305,20 +410,34 @@ func (m *Ack) Marshal() (dAtA []byte, err error) { ...@@ -305,20 +410,34 @@ func (m *Ack) Marshal() (dAtA []byte, err error) {
return dAtA[:n], nil return dAtA[:n], nil
} }
func (m *Ack) MarshalTo(dAtA []byte) (int, error) { func (m *BzzAddress) MarshalTo(dAtA []byte) (int, error) {
size := m.Size() size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size]) return m.MarshalToSizedBuffer(dAtA[:size])
} }
func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) { func (m *BzzAddress) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA) i := len(dAtA)
_ = i _ = i
var l int var l int
_ = l _ = l
if len(m.Address) > 0 { if len(m.Overlay) > 0 {
i -= len(m.Address) i -= len(m.Overlay)
copy(dAtA[i:], m.Address) copy(dAtA[i:], m.Overlay)
i = encodeVarintHandshake(dAtA, i, uint64(len(m.Address))) i = encodeVarintHandshake(dAtA, i, uint64(len(m.Overlay)))
i--
dAtA[i] = 0x1a
}
if len(m.Signature) > 0 {
i -= len(m.Signature)
copy(dAtA[i:], m.Signature)
i = encodeVarintHandshake(dAtA, i, uint64(len(m.Signature)))
i--
dAtA[i] = 0x12
}
if len(m.Underlay) > 0 {
i -= len(m.Underlay)
copy(dAtA[i:], m.Underlay)
i = encodeVarintHandshake(dAtA, i, uint64(len(m.Underlay)))
i-- i--
dAtA[i] = 0xa dAtA[i] = 0xa
} }
...@@ -342,8 +461,8 @@ func (m *Syn) Size() (n int) { ...@@ -342,8 +461,8 @@ func (m *Syn) Size() (n int) {
} }
var l int var l int
_ = l _ = l
l = len(m.Address) if m.BzzAddress != nil {
if l > 0 { l = m.BzzAddress.Size()
n += 1 + l + sovHandshake(uint64(l)) n += 1 + l + sovHandshake(uint64(l))
} }
if m.NetworkID != 0 { if m.NetworkID != 0 {
...@@ -355,6 +474,19 @@ func (m *Syn) Size() (n int) { ...@@ -355,6 +474,19 @@ func (m *Syn) Size() (n int) {
return n return n
} }
func (m *Ack) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if m.BzzAddress != nil {
l = m.BzzAddress.Size()
n += 1 + l + sovHandshake(uint64(l))
}
return n
}
func (m *SynAck) Size() (n int) { func (m *SynAck) Size() (n int) {
if m == nil { if m == nil {
return 0 return 0
...@@ -372,13 +504,21 @@ func (m *SynAck) Size() (n int) { ...@@ -372,13 +504,21 @@ func (m *SynAck) Size() (n int) {
return n return n
} }
func (m *Ack) Size() (n int) { func (m *BzzAddress) Size() (n int) {
if m == nil { if m == nil {
return 0 return 0
} }
var l int var l int
_ = l _ = l
l = len(m.Address) l = len(m.Underlay)
if l > 0 {
n += 1 + l + sovHandshake(uint64(l))
}
l = len(m.Signature)
if l > 0 {
n += 1 + l + sovHandshake(uint64(l))
}
l = len(m.Overlay)
if l > 0 { if l > 0 {
n += 1 + l + sovHandshake(uint64(l)) n += 1 + l + sovHandshake(uint64(l))
} }
...@@ -422,9 +562,9 @@ func (m *Syn) Unmarshal(dAtA []byte) error { ...@@ -422,9 +562,9 @@ func (m *Syn) Unmarshal(dAtA []byte) error {
switch fieldNum { switch fieldNum {
case 1: case 1:
if wireType != 2 { if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Address", wireType) return fmt.Errorf("proto: wrong wireType = %d for field BzzAddress", wireType)
} }
var byteLen int var msglen int
for shift := uint(0); ; shift += 7 { for shift := uint(0); ; shift += 7 {
if shift >= 64 { if shift >= 64 {
return ErrIntOverflowHandshake return ErrIntOverflowHandshake
...@@ -434,24 +574,26 @@ func (m *Syn) Unmarshal(dAtA []byte) error { ...@@ -434,24 +574,26 @@ func (m *Syn) Unmarshal(dAtA []byte) error {
} }
b := dAtA[iNdEx] b := dAtA[iNdEx]
iNdEx++ iNdEx++
byteLen |= int(b&0x7F) << shift msglen |= int(b&0x7F) << shift
if b < 0x80 { if b < 0x80 {
break break
} }
} }
if byteLen < 0 { if msglen < 0 {
return ErrInvalidLengthHandshake return ErrInvalidLengthHandshake
} }
postIndex := iNdEx + byteLen postIndex := iNdEx + msglen
if postIndex < 0 { if postIndex < 0 {
return ErrInvalidLengthHandshake return ErrInvalidLengthHandshake
} }
if postIndex > l { if postIndex > l {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
m.Address = append(m.Address[:0], dAtA[iNdEx:postIndex]...) if m.BzzAddress == nil {
if m.Address == nil { m.BzzAddress = &BzzAddress{}
m.Address = []byte{} }
if err := m.BzzAddress.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
return err
} }
iNdEx = postIndex iNdEx = postIndex
case 2: case 2:
...@@ -517,6 +659,95 @@ func (m *Syn) Unmarshal(dAtA []byte) error { ...@@ -517,6 +659,95 @@ func (m *Syn) Unmarshal(dAtA []byte) error {
} }
return nil return nil
} }
func (m *Ack) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Ack: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Ack: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field BzzAddress", wireType)
}
var msglen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
msglen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if msglen < 0 {
return ErrInvalidLengthHandshake
}
postIndex := iNdEx + msglen
if postIndex < 0 {
return ErrInvalidLengthHandshake
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
if m.BzzAddress == nil {
m.BzzAddress = &BzzAddress{}
}
if err := m.BzzAddress.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {
return err
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *SynAck) Unmarshal(dAtA []byte) error { func (m *SynAck) Unmarshal(dAtA []byte) error {
l := len(dAtA) l := len(dAtA)
iNdEx := 0 iNdEx := 0
...@@ -642,7 +873,7 @@ func (m *SynAck) Unmarshal(dAtA []byte) error { ...@@ -642,7 +873,7 @@ func (m *SynAck) Unmarshal(dAtA []byte) error {
} }
return nil return nil
} }
func (m *Ack) Unmarshal(dAtA []byte) error { func (m *BzzAddress) Unmarshal(dAtA []byte) error {
l := len(dAtA) l := len(dAtA)
iNdEx := 0 iNdEx := 0
for iNdEx < l { for iNdEx < l {
...@@ -665,15 +896,83 @@ func (m *Ack) Unmarshal(dAtA []byte) error { ...@@ -665,15 +896,83 @@ func (m *Ack) Unmarshal(dAtA []byte) error {
fieldNum := int32(wire >> 3) fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7) wireType := int(wire & 0x7)
if wireType == 4 { if wireType == 4 {
return fmt.Errorf("proto: Ack: wiretype end group for non-group") return fmt.Errorf("proto: BzzAddress: wiretype end group for non-group")
} }
if fieldNum <= 0 { if fieldNum <= 0 {
return fmt.Errorf("proto: Ack: illegal tag %d (wire type %d)", fieldNum, wire) return fmt.Errorf("proto: BzzAddress: illegal tag %d (wire type %d)", fieldNum, wire)
} }
switch fieldNum { switch fieldNum {
case 1: case 1:
if wireType != 2 { if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Address", wireType) return fmt.Errorf("proto: wrong wireType = %d for field Underlay", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthHandshake
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthHandshake
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Underlay = append(m.Underlay[:0], dAtA[iNdEx:postIndex]...)
if m.Underlay == nil {
m.Underlay = []byte{}
}
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Signature", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthHandshake
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthHandshake
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Signature = append(m.Signature[:0], dAtA[iNdEx:postIndex]...)
if m.Signature == nil {
m.Signature = []byte{}
}
iNdEx = postIndex
case 3:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Overlay", wireType)
} }
var byteLen int var byteLen int
for shift := uint(0); ; shift += 7 { for shift := uint(0); ; shift += 7 {
...@@ -700,9 +999,9 @@ func (m *Ack) Unmarshal(dAtA []byte) error { ...@@ -700,9 +999,9 @@ func (m *Ack) Unmarshal(dAtA []byte) error {
if postIndex > l { if postIndex > l {
return io.ErrUnexpectedEOF return io.ErrUnexpectedEOF
} }
m.Address = append(m.Address[:0], dAtA[iNdEx:postIndex]...) m.Overlay = append(m.Overlay[:0], dAtA[iNdEx:postIndex]...)
if m.Address == nil { if m.Overlay == nil {
m.Address = []byte{} m.Overlay = []byte{}
} }
iNdEx = postIndex iNdEx = postIndex
default: default:
......
...@@ -9,16 +9,22 @@ package handshake; ...@@ -9,16 +9,22 @@ package handshake;
option go_package = "pb"; option go_package = "pb";
message Syn { message Syn {
bytes Address = 1; BzzAddress BzzAddress = 1;
uint64 NetworkID = 2; uint64 NetworkID = 2;
bool Light = 3; bool Light = 3;
} }
message Ack {
BzzAddress BzzAddress = 1;
}
message SynAck { message SynAck {
Syn Syn = 1; Syn Syn = 1;
Ack Ack = 2; Ack Ack = 2;
} }
message Ack { message BzzAddress {
bytes Address = 1; bytes Underlay = 1;
bytes Signature = 2;
bytes Overlay = 3;
} }
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"net" "net"
"github.com/ethersphere/bee/pkg/addressbook" "github.com/ethersphere/bee/pkg/addressbook"
beecrypto "github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/breaker" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/breaker"
...@@ -56,18 +57,16 @@ type Service struct { ...@@ -56,18 +57,16 @@ type Service struct {
type Options struct { type Options struct {
PrivateKey *ecdsa.PrivateKey PrivateKey *ecdsa.PrivateKey
Overlay swarm.Address
Addr string
DisableWS bool DisableWS bool
DisableQUIC bool DisableQUIC bool
NetworkID uint64
Addressbook addressbook.Putter Addressbook addressbook.Putter
Logger logging.Logger Logger logging.Logger
Tracer *tracing.Tracer Tracer *tracing.Tracer
} }
func New(ctx context.Context, o Options) (*Service, error) { func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay swarm.Address, addr string,
host, port, err := net.SplitHostPort(o.Addr) o Options) (*Service, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("address: %w", err) return nil, fmt.Errorf("address: %w", err)
} }
...@@ -155,14 +154,19 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -155,14 +154,19 @@ func New(ctx context.Context, o Options) (*Service, error) {
return nil, fmt.Errorf("autonat: %w", err) return nil, fmt.Errorf("autonat: %w", err)
} }
handshakeService, err := handshake.New(overlay, h.ID().String(), signer, networkID, o.Logger)
if err != nil {
return nil, fmt.Errorf("handshake service: %w", err)
}
peerRegistry := newPeerRegistry() peerRegistry := newPeerRegistry()
s := &Service{ s := &Service{
ctx: ctx, ctx: ctx,
host: h, host: h,
libp2pPeerstore: libp2pPeerstore, libp2pPeerstore: libp2pPeerstore,
metrics: newMetrics(), metrics: newMetrics(),
networkID: o.NetworkID, networkID: networkID,
handshakeService: handshake.New(o.Overlay, o.NetworkID, o.Logger), handshakeService: handshakeService,
peers: peerRegistry, peers: peerRegistry,
addrssbook: o.Addressbook, addrssbook: o.Addressbook,
logger: o.Logger, logger: o.Logger,
...@@ -182,21 +186,13 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -182,21 +186,13 @@ func New(ctx context.Context, o Options) (*Service, error) {
peerID := stream.Conn().RemotePeer() peerID := stream.Conn().RemotePeer()
i, err := s.handshakeService.Handle(NewStream(stream), peerID) i, err := s.handshakeService.Handle(NewStream(stream), peerID)
if err != nil { if err != nil {
if err == handshake.ErrNetworkIDIncompatible {
s.logger.Warningf("peer %s has a different network id.", peerID)
}
if err == handshake.ErrHandshakeDuplicate {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
}
s.logger.Debugf("handshake: handle %s: %v", peerID, err) s.logger.Debugf("handshake: handle %s: %v", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID) s.logger.Errorf("unable to handshake with peer %v", peerID)
_ = s.disconnect(peerID) _ = s.disconnect(peerID)
return return
} }
if exists := s.peers.addIfNotExists(stream.Conn(), i.Address); exists { if exists := s.peers.addIfNotExists(stream.Conn(), i.Overlay); exists {
_ = stream.Close() _ = stream.Close()
return return
} }
...@@ -210,7 +206,7 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -210,7 +206,7 @@ func New(ctx context.Context, o Options) (*Service, error) {
return return
} }
err = s.addrssbook.Put(i.Address, remoteMultiaddr) err = s.addrssbook.Put(i.Overlay, remoteMultiaddr)
if err != nil { if err != nil {
s.logger.Debugf("handshake: addressbook put error %s: %v", peerID, err) s.logger.Debugf("handshake: addressbook put error %s: %v", peerID, err)
s.logger.Errorf("unable to persist peer %v", peerID) s.logger.Errorf("unable to persist peer %v", peerID)
...@@ -219,13 +215,13 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -219,13 +215,13 @@ func New(ctx context.Context, o Options) (*Service, error) {
} }
if s.peerHandler != nil { if s.peerHandler != nil {
if err := s.peerHandler(ctx, i.Address); err != nil { if err := s.peerHandler(ctx, i.Overlay); err != nil {
s.logger.Debugf("peerhandler error: %s: %v", peerID, err) s.logger.Debugf("peerhandler error: %s: %v", peerID, err)
} }
} }
s.metrics.HandledStreamCount.Inc() s.metrics.HandledStreamCount.Inc()
s.logger.Infof("peer %s connected", i.Address) s.logger.Infof("peer %s connected", i.Overlay)
}) })
h.Network().SetConnHandler(func(_ network.Conn) { h.Network().SetConnHandler(func(_ network.Conn) {
...@@ -335,12 +331,12 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -335,12 +331,12 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, fmt.Errorf("handshake: %w", err) return swarm.Address{}, fmt.Errorf("handshake: %w", err)
} }
if exists := s.peers.addIfNotExists(stream.Conn(), i.Address); exists { if exists := s.peers.addIfNotExists(stream.Conn(), i.Overlay); exists {
if err := helpers.FullClose(stream); err != nil { if err := helpers.FullClose(stream); err != nil {
return swarm.Address{}, err return swarm.Address{}, err
} }
return i.Address, nil return i.Overlay, nil
} }
if err := helpers.FullClose(stream); err != nil { if err := helpers.FullClose(stream); err != nil {
...@@ -348,8 +344,8 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -348,8 +344,8 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
} }
s.metrics.CreatedConnectionCount.Inc() s.metrics.CreatedConnectionCount.Inc()
s.logger.Infof("peer %s connected", i.Address) s.logger.Infof("peer %s connected", i.Overlay)
return i.Address, nil return i.Overlay, nil
} }
func (s *Service) Disconnect(overlay swarm.Address) error { func (s *Service) Disconnect(overlay swarm.Address) error {
......
...@@ -23,45 +23,33 @@ import ( ...@@ -23,45 +23,33 @@ import (
) )
// newService constructs a new libp2p service. // newService constructs a new libp2p service.
func newService(t *testing.T, o libp2p.Options) (s *libp2p.Service, overlay swarm.Address, cleanup func()) { func newService(t *testing.T, networkID uint64, o libp2p.Options) (s *libp2p.Service, overlay swarm.Address, cleanup func()) {
t.Helper() t.Helper()
if o.PrivateKey == nil { privateKey, err := crypto.GenerateSecp256k1Key()
var err error if err != nil {
o.PrivateKey, err = crypto.GenerateSecp256k1Key() t.Fatal(err)
if err != nil {
t.Fatal(err)
}
} }
if o.Overlay.IsZero() { overlay = crypto.NewOverlayAddress(privateKey.PublicKey, networkID)
var err error
swarmPK, err := crypto.GenerateSecp256k1Key() addr := ":0"
if err != nil {
t.Fatal(err)
}
o.Overlay = crypto.NewOverlayAddress(swarmPK.PublicKey, o.NetworkID)
}
if o.Logger == nil { if o.Logger == nil {
o.Logger = logging.New(ioutil.Discard, 0) o.Logger = logging.New(ioutil.Discard, 0)
} }
if o.Addr == "" {
o.Addr = ":0"
}
if o.Addressbook == nil { if o.Addressbook == nil {
statestore := mock.NewStateStore() statestore := mock.NewStateStore()
o.Addressbook = addressbook.New(statestore) o.Addressbook = addressbook.New(statestore)
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s, err := libp2p.New(ctx, o) s, err = libp2p.New(ctx, crypto.NewDefaultSigner(privateKey), networkID, overlay, addr, o)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return s, o.Overlay, func() { return s, overlay, func() {
cancel() cancel()
s.Close() s.Close()
} }
......
...@@ -18,10 +18,10 @@ func TestNewStream(t *testing.T) { ...@@ -18,10 +18,10 @@ func TestNewStream(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error { if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
...@@ -49,10 +49,10 @@ func TestNewStream_errNotSupported(t *testing.T) { ...@@ -49,10 +49,10 @@ func TestNewStream_errNotSupported(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -86,10 +86,10 @@ func TestNewStream_semanticVersioning(t *testing.T) { ...@@ -86,10 +86,10 @@ func TestNewStream_semanticVersioning(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -147,10 +147,10 @@ func TestDisconnectError(t *testing.T) { ...@@ -147,10 +147,10 @@ func TestDisconnectError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error { if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
......
...@@ -34,10 +34,10 @@ func TestTracing(t *testing.T) { ...@@ -34,10 +34,10 @@ func TestTracing(t *testing.T) {
} }
defer closer2.Close() defer closer2.Close()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
var handledTracingSpan string var handledTracingSpan string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment