Commit 2db9d81f authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

Signing in handshake (#196)

* signing in handshake & Signer interface
parent 70ff4fa9
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package crypto
import (
"crypto/ecdsa"
"github.com/btcsuite/btcd/btcec"
)
type Signer interface {
Sign(data []byte) ([]byte, error)
PublicKey() (*ecdsa.PublicKey, error)
}
// Recover verifies signature with the data base provided.
// It is using `btcec.RecoverCompact` function
func Recover(signature, data []byte) (*ecdsa.PublicKey, error) {
p, _, err := btcec.RecoverCompact(btcec.S256(), signature, data)
return (*ecdsa.PublicKey)(p), err
}
type defaultSigner struct {
key *ecdsa.PrivateKey
}
func NewDefaultSigner(key *ecdsa.PrivateKey) Signer {
return &defaultSigner{
key: key,
}
}
func (d *defaultSigner) PublicKey() (*ecdsa.PublicKey, error) {
return &d.key.PublicKey, nil
}
func (d *defaultSigner) Sign(data []byte) (signature []byte, err error) {
return btcec.SignCompact(btcec.S256(), (*btcec.PrivateKey)(d.key), data, true)
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package crypto_test
import (
"testing"
"github.com/ethersphere/bee/pkg/crypto"
)
func TestDefaultSigner(t *testing.T) {
testBytes := []byte("test string")
privKey, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
signer := crypto.NewDefaultSigner(privKey)
signature, err := signer.Sign(testBytes)
if err != nil {
t.Fatal(err)
}
t.Run("OK - sign & recover", func(t *testing.T) {
pubKey, err := crypto.Recover(signature, testBytes)
if err != nil {
t.Fatal(err)
}
if pubKey.X.Cmp(privKey.PublicKey.X) != 0 || pubKey.Y.Cmp(privKey.PublicKey.Y) != 0 {
t.Fatalf("wanted %v but got %v", pubKey, &privKey.PublicKey)
}
})
t.Run("OK - recover with invalid data", func(t *testing.T) {
pubKey, err := crypto.Recover(signature, []byte("invalid"))
if err != nil {
t.Fatal(err)
}
if pubKey.X.Cmp(privKey.PublicKey.X) == 0 && pubKey.Y.Cmp(privKey.PublicKey.Y) == 0 {
t.Fatal("expected different public key")
}
})
}
...@@ -134,13 +134,10 @@ func NewBee(o Options) (*Bee, error) { ...@@ -134,13 +134,10 @@ func NewBee(o Options) (*Bee, error) {
b.stateStoreCloser = stateStore b.stateStoreCloser = stateStore
addressbook := addressbook.New(stateStore) addressbook := addressbook.New(stateStore)
p2ps, err := libp2p.New(p2pCtx, libp2p.Options{ p2ps, err := libp2p.New(p2pCtx, crypto.NewDefaultSigner(swarmPrivateKey), o.NetworkID, address, o.Addr, libp2p.Options{
PrivateKey: libp2pPrivateKey, PrivateKey: libp2pPrivateKey,
Overlay: address,
Addr: o.Addr,
DisableWS: o.DisableWS, DisableWS: o.DisableWS,
DisableQUIC: o.DisableQUIC, DisableQUIC: o.DisableQUIC,
NetworkID: o.NetworkID,
Addressbook: addressbook, Addressbook: addressbook,
Logger: logger, Logger: logger,
Tracer: tracer, Tracer: tracer,
......
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
) )
func TestAddresses(t *testing.T) { func TestAddresses(t *testing.T) {
s, _, cleanup := newService(t, libp2p.Options{NetworkID: 1}) s, _, cleanup := newService(t, 1, libp2p.Options{})
defer cleanup() defer cleanup()
addrs, err := s.Addresses() addrs, err := s.Addresses()
...@@ -32,10 +32,10 @@ func TestConnectDisconnect(t *testing.T) { ...@@ -32,10 +32,10 @@ func TestConnectDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -60,10 +60,10 @@ func TestDoubleConnect(t *testing.T) { ...@@ -60,10 +60,10 @@ func TestDoubleConnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -87,10 +87,10 @@ func TestDoubleDisconnect(t *testing.T) { ...@@ -87,10 +87,10 @@ func TestDoubleDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -122,10 +122,10 @@ func TestMultipleConnectDisconnect(t *testing.T) { ...@@ -122,10 +122,10 @@ func TestMultipleConnectDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -165,10 +165,10 @@ func TestConnectDisconnectOnAllAddresses(t *testing.T) { ...@@ -165,10 +165,10 @@ func TestConnectDisconnectOnAllAddresses(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addrs, err := s1.Addresses() addrs, err := s1.Addresses()
...@@ -197,10 +197,10 @@ func TestDoubleConnectOnAllAddresses(t *testing.T) { ...@@ -197,10 +197,10 @@ func TestDoubleConnectOnAllAddresses(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addrs, err := s1.Addresses() addrs, err := s1.Addresses()
...@@ -235,10 +235,10 @@ func TestDifferentNetworkIDs(t *testing.T) { ...@@ -235,10 +235,10 @@ func TestDifferentNetworkIDs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, _, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, _, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{NetworkID: 2}) s2, _, cleanup2 := newService(t, 2, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -255,15 +255,13 @@ func TestConnectWithDisabledQUICAndWSTransports(t *testing.T) { ...@@ -255,15 +255,13 @@ func TestConnectWithDisabledQUICAndWSTransports(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{ s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{
NetworkID: 1,
DisableQUIC: true, DisableQUIC: true,
DisableWS: true, DisableWS: true,
}) })
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{ s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{
NetworkID: 1,
DisableQUIC: true, DisableQUIC: true,
DisableWS: true, DisableWS: true,
}) })
...@@ -284,10 +282,10 @@ func TestConnectRepeatHandshake(t *testing.T) { ...@@ -284,10 +282,10 @@ func TestConnectRepeatHandshake(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
......
...@@ -23,10 +23,10 @@ func TestHeaders(t *testing.T) { ...@@ -23,10 +23,10 @@ func TestHeaders(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
var gotHeaders p2p.Headers var gotHeaders p2p.Headers
...@@ -72,10 +72,10 @@ func TestHeaders_empty(t *testing.T) { ...@@ -72,10 +72,10 @@ func TestHeaders_empty(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
var gotHeaders p2p.Headers var gotHeaders p2p.Headers
...@@ -130,10 +130,10 @@ func TestHeadler(t *testing.T) { ...@@ -130,10 +130,10 @@ func TestHeadler(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
var gotReceivedHeaders p2p.Headers var gotReceivedHeaders p2p.Headers
......
...@@ -5,11 +5,14 @@ ...@@ -5,11 +5,14 @@
package handshake package handshake
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"strconv"
"sync" "sync"
"time" "time"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb"
...@@ -27,13 +30,19 @@ const ( ...@@ -27,13 +30,19 @@ const (
messageTimeout = 5 * time.Second // maximum allowed time for a message to be read or written. messageTimeout = 5 * time.Second // maximum allowed time for a message to be read or written.
) )
// ErrNetworkIDIncompatible should be returned by handshake handlers if var (
// response from the other peer does not have valid networkID. // ErrNetworkIDIncompatible is returned if response from the other peer does not have valid networkID.
var ErrNetworkIDIncompatible = errors.New("incompatible network ID") ErrNetworkIDIncompatible = errors.New("incompatible network ID")
// ErrHandshakeDuplicate should be returned by handshake handlers if // ErrHandshakeDuplicate is returned if the handshake response has been received by an already processed peer.
// the handshake response has been received by an already processed peer. ErrHandshakeDuplicate = errors.New("duplicate handshake")
var ErrHandshakeDuplicate = errors.New("duplicate handshake")
// ErrInvalidSignature is returned if peer info was received with invalid signature
ErrInvalidSignature = errors.New("invalid signature")
// ErrInvalidAck is returned if ack does not match the syn provided
ErrInvalidAck = errors.New("invalid ack")
)
// PeerFinder has the information if the peer already exists in swarm. // PeerFinder has the information if the peer already exists in swarm.
type PeerFinder interface { type PeerFinder interface {
...@@ -42,6 +51,9 @@ type PeerFinder interface { ...@@ -42,6 +51,9 @@ type PeerFinder interface {
type Service struct { type Service struct {
overlay swarm.Address overlay swarm.Address
underlay []byte
signature []byte
signer crypto.Signer
networkID uint64 networkID uint64
receivedHandshakes map[libp2ppeer.ID]struct{} receivedHandshakes map[libp2ppeer.ID]struct{}
receivedHandshakesMu sync.Mutex receivedHandshakesMu sync.Mutex
...@@ -50,20 +62,32 @@ type Service struct { ...@@ -50,20 +62,32 @@ type Service struct {
network.Notifiee // handhsake service can be the receiver for network.Notify network.Notifiee // handhsake service can be the receiver for network.Notify
} }
func New(overlay swarm.Address, networkID uint64, logger logging.Logger) *Service { func New(overlay swarm.Address, underlay string, signer crypto.Signer, networkID uint64, logger logging.Logger) (*Service, error) {
signature, err := signer.Sign([]byte(underlay + strconv.FormatUint(networkID, 10)))
if err != nil {
return nil, err
}
return &Service{ return &Service{
overlay: overlay, overlay: overlay,
underlay: []byte(underlay),
signature: signature,
signer: signer,
networkID: networkID, networkID: networkID,
receivedHandshakes: make(map[libp2ppeer.ID]struct{}), receivedHandshakes: make(map[libp2ppeer.ID]struct{}),
logger: logger, logger: logger,
Notifiee: new(network.NoopNotifiee), Notifiee: new(network.NoopNotifiee),
} }, nil
} }
func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{
Address: s.overlay.Bytes(), BzzAddress: &pb.BzzAddress{
Underlay: s.underlay,
Signature: s.signature,
Overlay: s.overlay.Bytes(),
},
NetworkID: s.networkID, NetworkID: s.networkID,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write syn message: %w", err) return nil, fmt.Errorf("write syn message: %w", err)
...@@ -74,21 +98,25 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { ...@@ -74,21 +98,25 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("read synack message: %w", err) return nil, fmt.Errorf("read synack message: %w", err)
} }
address := swarm.NewAddress(resp.Syn.Address) if err := s.checkAck(resp.Ack); err != nil {
if resp.Syn.NetworkID != s.networkID { return nil, err
return nil, ErrNetworkIDIncompatible }
if err := s.checkSyn(resp.Syn); err != nil {
return nil, err
} }
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{
Address: resp.Syn.Address, BzzAddress: resp.Syn.BzzAddress,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write ack message: %w", err) return nil, fmt.Errorf("write ack message: %w", err)
} }
s.logger.Tracef("handshake finished for peer %s", address) s.logger.Tracef("handshake finished for peer %s", swarm.NewAddress(resp.Syn.BzzAddress.Overlay).String())
return &Info{ return &Info{
Address: address, Overlay: swarm.NewAddress(resp.Syn.BzzAddress.Overlay),
Underlay: resp.Syn.BzzAddress.Underlay,
NetworkID: resp.Syn.NetworkID, NetworkID: resp.Syn.NetworkID,
Light: resp.Syn.Light, Light: resp.Syn.Light,
}, nil }, nil
...@@ -110,17 +138,20 @@ func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err ...@@ -110,17 +138,20 @@ func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err
return nil, fmt.Errorf("read syn message: %w", err) return nil, fmt.Errorf("read syn message: %w", err)
} }
address := swarm.NewAddress(req.Address) if err := s.checkSyn(&req); err != nil {
if req.NetworkID != s.networkID { return nil, err
return nil, ErrNetworkIDIncompatible
} }
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{
Syn: &pb.Syn{ Syn: &pb.Syn{
Address: s.overlay.Bytes(), BzzAddress: &pb.BzzAddress{
Underlay: s.underlay,
Signature: s.signature,
Overlay: s.overlay.Bytes(),
},
NetworkID: s.networkID, NetworkID: s.networkID,
}, },
Ack: &pb.Ack{Address: req.Address}, Ack: &pb.Ack{BzzAddress: req.BzzAddress},
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write synack message: %w", err) return nil, fmt.Errorf("write synack message: %w", err)
} }
...@@ -130,9 +161,14 @@ func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err ...@@ -130,9 +161,14 @@ func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err
return nil, fmt.Errorf("read ack message: %w", err) return nil, fmt.Errorf("read ack message: %w", err)
} }
s.logger.Tracef("handshake finished for peer %s", address) if err := s.checkAck(&ack); err != nil {
return nil, err
}
s.logger.Tracef("handshake finished for peer %s", req.BzzAddress.Overlay)
return &Info{ return &Info{
Address: address, Overlay: swarm.NewAddress(req.BzzAddress.Overlay),
Underlay: req.BzzAddress.Underlay,
NetworkID: req.NetworkID, NetworkID: req.NetworkID,
Light: req.Light, Light: req.Light,
}, nil }, nil
...@@ -144,8 +180,37 @@ func (s *Service) Disconnected(_ network.Network, c network.Conn) { ...@@ -144,8 +180,37 @@ func (s *Service) Disconnected(_ network.Network, c network.Conn) {
delete(s.receivedHandshakes, c.RemotePeer()) delete(s.receivedHandshakes, c.RemotePeer())
} }
func (s *Service) checkSyn(syn *pb.Syn) error {
if syn.NetworkID != s.networkID {
return ErrNetworkIDIncompatible
}
recoveredPK, err := crypto.Recover(syn.BzzAddress.Signature, append(syn.BzzAddress.Underlay, strconv.FormatUint(syn.NetworkID, 10)...))
if err != nil {
return ErrInvalidSignature
}
recoveredOverlay := crypto.NewOverlayAddress(*recoveredPK, syn.NetworkID)
if !bytes.Equal(recoveredOverlay.Bytes(), syn.BzzAddress.Overlay) {
return ErrInvalidSignature
}
return nil
}
func (s *Service) checkAck(ack *pb.Ack) error {
if !bytes.Equal(ack.BzzAddress.Overlay, s.overlay.Bytes()) ||
!bytes.Equal(ack.BzzAddress.Underlay, s.underlay) ||
!bytes.Equal(ack.BzzAddress.Signature, s.signature) {
return ErrInvalidAck
}
return nil
}
type Info struct { type Info struct {
Address swarm.Address Overlay swarm.Address
Underlay []byte
NetworkID uint64 NetworkID uint64
Light bool Light bool
} }
...@@ -9,16 +9,22 @@ package handshake; ...@@ -9,16 +9,22 @@ package handshake;
option go_package = "pb"; option go_package = "pb";
message Syn { message Syn {
bytes Address = 1; BzzAddress BzzAddress = 1;
uint64 NetworkID = 2; uint64 NetworkID = 2;
bool Light = 3; bool Light = 3;
} }
message Ack {
BzzAddress BzzAddress = 1;
}
message SynAck { message SynAck {
Syn Syn = 1; Syn Syn = 1;
Ack Ack = 2; Ack Ack = 2;
} }
message Ack { message BzzAddress {
bytes Address = 1; bytes Underlay = 1;
bytes Signature = 2;
bytes Overlay = 3;
} }
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"net" "net"
"github.com/ethersphere/bee/pkg/addressbook" "github.com/ethersphere/bee/pkg/addressbook"
beecrypto "github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/breaker" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/breaker"
...@@ -56,18 +57,16 @@ type Service struct { ...@@ -56,18 +57,16 @@ type Service struct {
type Options struct { type Options struct {
PrivateKey *ecdsa.PrivateKey PrivateKey *ecdsa.PrivateKey
Overlay swarm.Address
Addr string
DisableWS bool DisableWS bool
DisableQUIC bool DisableQUIC bool
NetworkID uint64
Addressbook addressbook.Putter Addressbook addressbook.Putter
Logger logging.Logger Logger logging.Logger
Tracer *tracing.Tracer Tracer *tracing.Tracer
} }
func New(ctx context.Context, o Options) (*Service, error) { func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay swarm.Address, addr string,
host, port, err := net.SplitHostPort(o.Addr) o Options) (*Service, error) {
host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("address: %w", err) return nil, fmt.Errorf("address: %w", err)
} }
...@@ -155,14 +154,19 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -155,14 +154,19 @@ func New(ctx context.Context, o Options) (*Service, error) {
return nil, fmt.Errorf("autonat: %w", err) return nil, fmt.Errorf("autonat: %w", err)
} }
handshakeService, err := handshake.New(overlay, h.ID().String(), signer, networkID, o.Logger)
if err != nil {
return nil, fmt.Errorf("handshake service: %w", err)
}
peerRegistry := newPeerRegistry() peerRegistry := newPeerRegistry()
s := &Service{ s := &Service{
ctx: ctx, ctx: ctx,
host: h, host: h,
libp2pPeerstore: libp2pPeerstore, libp2pPeerstore: libp2pPeerstore,
metrics: newMetrics(), metrics: newMetrics(),
networkID: o.NetworkID, networkID: networkID,
handshakeService: handshake.New(o.Overlay, o.NetworkID, o.Logger), handshakeService: handshakeService,
peers: peerRegistry, peers: peerRegistry,
addrssbook: o.Addressbook, addrssbook: o.Addressbook,
logger: o.Logger, logger: o.Logger,
...@@ -182,21 +186,13 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -182,21 +186,13 @@ func New(ctx context.Context, o Options) (*Service, error) {
peerID := stream.Conn().RemotePeer() peerID := stream.Conn().RemotePeer()
i, err := s.handshakeService.Handle(NewStream(stream), peerID) i, err := s.handshakeService.Handle(NewStream(stream), peerID)
if err != nil { if err != nil {
if err == handshake.ErrNetworkIDIncompatible {
s.logger.Warningf("peer %s has a different network id.", peerID)
}
if err == handshake.ErrHandshakeDuplicate {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
}
s.logger.Debugf("handshake: handle %s: %v", peerID, err) s.logger.Debugf("handshake: handle %s: %v", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID) s.logger.Errorf("unable to handshake with peer %v", peerID)
_ = s.disconnect(peerID) _ = s.disconnect(peerID)
return return
} }
if exists := s.peers.addIfNotExists(stream.Conn(), i.Address); exists { if exists := s.peers.addIfNotExists(stream.Conn(), i.Overlay); exists {
_ = stream.Close() _ = stream.Close()
return return
} }
...@@ -210,7 +206,7 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -210,7 +206,7 @@ func New(ctx context.Context, o Options) (*Service, error) {
return return
} }
err = s.addrssbook.Put(i.Address, remoteMultiaddr) err = s.addrssbook.Put(i.Overlay, remoteMultiaddr)
if err != nil { if err != nil {
s.logger.Debugf("handshake: addressbook put error %s: %v", peerID, err) s.logger.Debugf("handshake: addressbook put error %s: %v", peerID, err)
s.logger.Errorf("unable to persist peer %v", peerID) s.logger.Errorf("unable to persist peer %v", peerID)
...@@ -219,13 +215,13 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -219,13 +215,13 @@ func New(ctx context.Context, o Options) (*Service, error) {
} }
if s.peerHandler != nil { if s.peerHandler != nil {
if err := s.peerHandler(ctx, i.Address); err != nil { if err := s.peerHandler(ctx, i.Overlay); err != nil {
s.logger.Debugf("peerhandler error: %s: %v", peerID, err) s.logger.Debugf("peerhandler error: %s: %v", peerID, err)
} }
} }
s.metrics.HandledStreamCount.Inc() s.metrics.HandledStreamCount.Inc()
s.logger.Infof("peer %s connected", i.Address) s.logger.Infof("peer %s connected", i.Overlay)
}) })
h.Network().SetConnHandler(func(_ network.Conn) { h.Network().SetConnHandler(func(_ network.Conn) {
...@@ -335,12 +331,12 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -335,12 +331,12 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, fmt.Errorf("handshake: %w", err) return swarm.Address{}, fmt.Errorf("handshake: %w", err)
} }
if exists := s.peers.addIfNotExists(stream.Conn(), i.Address); exists { if exists := s.peers.addIfNotExists(stream.Conn(), i.Overlay); exists {
if err := helpers.FullClose(stream); err != nil { if err := helpers.FullClose(stream); err != nil {
return swarm.Address{}, err return swarm.Address{}, err
} }
return i.Address, nil return i.Overlay, nil
} }
if err := helpers.FullClose(stream); err != nil { if err := helpers.FullClose(stream); err != nil {
...@@ -348,8 +344,8 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -348,8 +344,8 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
} }
s.metrics.CreatedConnectionCount.Inc() s.metrics.CreatedConnectionCount.Inc()
s.logger.Infof("peer %s connected", i.Address) s.logger.Infof("peer %s connected", i.Overlay)
return i.Address, nil return i.Overlay, nil
} }
func (s *Service) Disconnect(overlay swarm.Address) error { func (s *Service) Disconnect(overlay swarm.Address) error {
......
...@@ -23,45 +23,33 @@ import ( ...@@ -23,45 +23,33 @@ import (
) )
// newService constructs a new libp2p service. // newService constructs a new libp2p service.
func newService(t *testing.T, o libp2p.Options) (s *libp2p.Service, overlay swarm.Address, cleanup func()) { func newService(t *testing.T, networkID uint64, o libp2p.Options) (s *libp2p.Service, overlay swarm.Address, cleanup func()) {
t.Helper() t.Helper()
if o.PrivateKey == nil { privateKey, err := crypto.GenerateSecp256k1Key()
var err error if err != nil {
o.PrivateKey, err = crypto.GenerateSecp256k1Key() t.Fatal(err)
if err != nil {
t.Fatal(err)
}
} }
if o.Overlay.IsZero() { overlay = crypto.NewOverlayAddress(privateKey.PublicKey, networkID)
var err error
swarmPK, err := crypto.GenerateSecp256k1Key() addr := ":0"
if err != nil {
t.Fatal(err)
}
o.Overlay = crypto.NewOverlayAddress(swarmPK.PublicKey, o.NetworkID)
}
if o.Logger == nil { if o.Logger == nil {
o.Logger = logging.New(ioutil.Discard, 0) o.Logger = logging.New(ioutil.Discard, 0)
} }
if o.Addr == "" {
o.Addr = ":0"
}
if o.Addressbook == nil { if o.Addressbook == nil {
statestore := mock.NewStateStore() statestore := mock.NewStateStore()
o.Addressbook = addressbook.New(statestore) o.Addressbook = addressbook.New(statestore)
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s, err := libp2p.New(ctx, o) s, err = libp2p.New(ctx, crypto.NewDefaultSigner(privateKey), networkID, overlay, addr, o)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return s, o.Overlay, func() { return s, overlay, func() {
cancel() cancel()
s.Close() s.Close()
} }
......
...@@ -18,10 +18,10 @@ func TestNewStream(t *testing.T) { ...@@ -18,10 +18,10 @@ func TestNewStream(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error { if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
...@@ -49,10 +49,10 @@ func TestNewStream_errNotSupported(t *testing.T) { ...@@ -49,10 +49,10 @@ func TestNewStream_errNotSupported(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -86,10 +86,10 @@ func TestNewStream_semanticVersioning(t *testing.T) { ...@@ -86,10 +86,10 @@ func TestNewStream_semanticVersioning(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -147,10 +147,10 @@ func TestDisconnectError(t *testing.T) { ...@@ -147,10 +147,10 @@ func TestDisconnectError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1}) s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error { if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
......
...@@ -34,10 +34,10 @@ func TestTracing(t *testing.T) { ...@@ -34,10 +34,10 @@ func TestTracing(t *testing.T) {
} }
defer closer2.Close() defer closer2.Close()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{}) s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{})
defer cleanup1() defer cleanup1()
s2, _, cleanup2 := newService(t, libp2p.Options{}) s2, _, cleanup2 := newService(t, 1, libp2p.Options{})
defer cleanup2() defer cleanup2()
var handledTracingSpan string var handledTracingSpan string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment