Commit 942853ec authored by Petar Radovic's avatar Petar Radovic

check for duplicate handshake msg in protocol

parent d6344c2f
......@@ -24,18 +24,25 @@ const (
// ErrNetworkIDIncompatible should be returned by handshake handlers if
// response from the other peer does not have valid networkID.
var ErrNetworkIDIncompatible = errors.New("incompatible networkID")
var ErrHandshakeDuplicate = errors.New("duplicate handshake")
type PeerFinder interface {
Exists(overlay swarm.Address) (found bool)
}
type Service struct {
overlay swarm.Address
networkID int32
logger logging.Logger
peerFinder PeerFinder
overlay swarm.Address
networkID int32
logger logging.Logger
}
func New(overlay swarm.Address, networkID int32, logger logging.Logger) *Service {
func New(peerFinder PeerFinder, overlay swarm.Address, networkID int32, logger logging.Logger) *Service {
return &Service{
overlay: overlay,
networkID: networkID,
logger: logger,
peerFinder: peerFinder,
overlay: overlay,
networkID: networkID,
logger: logger,
}
}
......@@ -53,6 +60,11 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("read message: %w", err)
}
address := swarm.NewAddress(resp.Syn.Address)
if s.peerFinder.Exists(address) {
return nil, ErrHandshakeDuplicate
}
if resp.Syn.NetworkID != s.networkID {
return nil, ErrNetworkIDIncompatible
}
......@@ -61,8 +73,6 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("ack: write message: %w", err)
}
address := swarm.NewAddress(resp.Syn.Address)
s.logger.Tracef("handshake finished for peer %s", address)
return &Info{
......@@ -81,6 +91,11 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("read message: %w", err)
}
address := swarm.NewAddress(req.Address)
if s.peerFinder.Exists(address) {
return nil, ErrHandshakeDuplicate
}
if req.NetworkID != s.networkID {
return nil, ErrNetworkIDIncompatible
}
......@@ -100,8 +115,6 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("ack: read message: %w", err)
}
address := swarm.NewAddress(req.Address)
s.logger.Tracef("handshake finished for peer %s", address)
return &Info{
Address: address,
......
......@@ -26,7 +26,9 @@ func TestHandshake(t *testing.T) {
NetworkID: 0,
Light: false,
}
handshakeService := New(info.Address, info.NetworkID, logger)
peerFinderMock := &mock.PeerFinderMock{}
handshakeService := New(peerFinderMock, info.Address, info.NetworkID, logger)
t.Run("OK", func(t *testing.T) {
expectedInfo := Info{
......@@ -163,6 +165,41 @@ func TestHandshake(t *testing.T) {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
t.Run("ERROR - duplicate handshake ", func(t *testing.T) {
node2Info := Info{
Address: node2Addr,
NetworkID: 0,
Light: false,
}
peerFinderMock.SetFound(true)
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{
Address: node2Info.Address.Bytes(),
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
},
Ack: &pb.Ack{Address: info.Address.Bytes()},
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handshake(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != ErrHandshakeDuplicate {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
}
func TestHandle(t *testing.T) {
......@@ -175,7 +212,8 @@ func TestHandle(t *testing.T) {
}
logger := logging.New(ioutil.Discard, 0)
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger)
peerFinderMock := &mock.PeerFinderMock{}
handshakeService := New(peerFinderMock, nodeInfo.Address, nodeInfo.NetworkID, logger)
t.Run("OK", func(t *testing.T) {
node2Info := Info{
......@@ -325,6 +363,38 @@ func TestHandle(t *testing.T) {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
t.Run("ERROR - duplicate handshake msg", func(t *testing.T) {
node2Info := Info{
Address: node2Addr,
NetworkID: 0,
Light: false,
}
peerFinderMock.SetFound(true)
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.Syn{
Address: node2Info.Address.Bytes(),
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handle(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != ErrHandshakeDuplicate {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
}
// testInfo validates if two Info instances are equal.
......
package mock
import "github.com/ethersphere/bee/pkg/swarm"
type PeerFinderMock struct {
found bool
}
func (p *PeerFinderMock) SetFound(found bool) {
p.found = found
}
func (p *PeerFinderMock) Exists(overlay swarm.Address) (found bool) {
return p.found
}
......@@ -23,13 +23,11 @@ func NewStream(readBuffer, writeBuffer *bytes.Buffer) *StreamMock {
func (s *StreamMock) SetReadErr(err error, checkmark int) {
s.readError = err
s.readErrCheckmark = checkmark
}
func (s *StreamMock) SetWriteErr(err error, checkmark int) {
s.writeError = err
s.writeErrCheckmark = checkmark
}
func (s *StreamMock) Read(p []byte) (n int, err error) {
......
......@@ -182,12 +182,13 @@ func New(ctx context.Context, o Options) (*Service, error) {
return nil, fmt.Errorf("autonat: %w", err)
}
peerRegistry := newPeerRegistry()
s := &Service{
host: h,
metrics: newMetrics(),
networkID: o.NetworkID,
handshakeService: handshake.New(o.Overlay, o.NetworkID, o.Logger),
peers: newPeerRegistry(),
handshakeService: handshake.New(peerRegistry, o.Overlay, o.NetworkID, o.Logger),
peers: peerRegistry,
logger: o.Logger,
}
......@@ -207,6 +208,10 @@ func New(ctx context.Context, o Options) (*Service, error) {
s.logger.Warningf("peer %s has a different network id.", peerID)
}
if err == handshake.ErrHandshakeDuplicate {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
}
s.logger.Debugf("handshake: handle %s: %w", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID)
// todo: test connection close and refactor
......@@ -214,13 +219,6 @@ func New(ctx context.Context, o Options) (*Service, error) {
return
}
if peerID, found := s.peers.peerID(i.Address); found {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
// disconnect if handshake was performed for already existing peer
_ = s.Disconnect(i.Address)
return
}
s.peers.add(peerID, i.Address)
s.metrics.HandledStreamCount.Inc()
s.logger.Infof("peer %s connected", i.Address)
......@@ -319,13 +317,6 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, err
}
if peerID, found := s.peers.peerID(i.Address); found {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
// disconnect if handshake was performed for already existing peer
_ = s.Disconnect(i.Address)
return
}
s.peers.add(info.ID, i.Address)
s.metrics.CreatedConnectionCount.Inc()
s.logger.Infof("peer %s connected", i.Address)
......
......@@ -24,6 +24,11 @@ func newPeerRegistry() *peerRegistry {
}
}
func (r *peerRegistry) Exists(overlay swarm.Address) (found bool) {
_, found = r.peerID(overlay)
return found
}
func (r *peerRegistry) add(peerID libp2ppeer.ID, overlay swarm.Address) {
r.mu.Lock()
r.peers[encodePeersKey(overlay)] = peerID
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment