Commit 942853ec authored by Petar Radovic's avatar Petar Radovic

check for duplicate handshake msg in protocol

parent d6344c2f
...@@ -24,18 +24,25 @@ const ( ...@@ -24,18 +24,25 @@ const (
// ErrNetworkIDIncompatible should be returned by handshake handlers if // ErrNetworkIDIncompatible should be returned by handshake handlers if
// response from the other peer does not have valid networkID. // response from the other peer does not have valid networkID.
var ErrNetworkIDIncompatible = errors.New("incompatible networkID") var ErrNetworkIDIncompatible = errors.New("incompatible networkID")
var ErrHandshakeDuplicate = errors.New("duplicate handshake")
type PeerFinder interface {
Exists(overlay swarm.Address) (found bool)
}
type Service struct { type Service struct {
overlay swarm.Address peerFinder PeerFinder
networkID int32 overlay swarm.Address
logger logging.Logger networkID int32
logger logging.Logger
} }
func New(overlay swarm.Address, networkID int32, logger logging.Logger) *Service { func New(peerFinder PeerFinder, overlay swarm.Address, networkID int32, logger logging.Logger) *Service {
return &Service{ return &Service{
overlay: overlay, peerFinder: peerFinder,
networkID: networkID, overlay: overlay,
logger: logger, networkID: networkID,
logger: logger,
} }
} }
...@@ -53,6 +60,11 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { ...@@ -53,6 +60,11 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("read message: %w", err) return nil, fmt.Errorf("read message: %w", err)
} }
address := swarm.NewAddress(resp.Syn.Address)
if s.peerFinder.Exists(address) {
return nil, ErrHandshakeDuplicate
}
if resp.Syn.NetworkID != s.networkID { if resp.Syn.NetworkID != s.networkID {
return nil, ErrNetworkIDIncompatible return nil, ErrNetworkIDIncompatible
} }
...@@ -61,8 +73,6 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { ...@@ -61,8 +73,6 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("ack: write message: %w", err) return nil, fmt.Errorf("ack: write message: %w", err)
} }
address := swarm.NewAddress(resp.Syn.Address)
s.logger.Tracef("handshake finished for peer %s", address) s.logger.Tracef("handshake finished for peer %s", address)
return &Info{ return &Info{
...@@ -81,6 +91,11 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) { ...@@ -81,6 +91,11 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("read message: %w", err) return nil, fmt.Errorf("read message: %w", err)
} }
address := swarm.NewAddress(req.Address)
if s.peerFinder.Exists(address) {
return nil, ErrHandshakeDuplicate
}
if req.NetworkID != s.networkID { if req.NetworkID != s.networkID {
return nil, ErrNetworkIDIncompatible return nil, ErrNetworkIDIncompatible
} }
...@@ -100,8 +115,6 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) { ...@@ -100,8 +115,6 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("ack: read message: %w", err) return nil, fmt.Errorf("ack: read message: %w", err)
} }
address := swarm.NewAddress(req.Address)
s.logger.Tracef("handshake finished for peer %s", address) s.logger.Tracef("handshake finished for peer %s", address)
return &Info{ return &Info{
Address: address, Address: address,
......
...@@ -26,7 +26,9 @@ func TestHandshake(t *testing.T) { ...@@ -26,7 +26,9 @@ func TestHandshake(t *testing.T) {
NetworkID: 0, NetworkID: 0,
Light: false, Light: false,
} }
handshakeService := New(info.Address, info.NetworkID, logger)
peerFinderMock := &mock.PeerFinderMock{}
handshakeService := New(peerFinderMock, info.Address, info.NetworkID, logger)
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
expectedInfo := Info{ expectedInfo := Info{
...@@ -163,6 +165,41 @@ func TestHandshake(t *testing.T) { ...@@ -163,6 +165,41 @@ func TestHandshake(t *testing.T) {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err) t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
} }
}) })
t.Run("ERROR - duplicate handshake ", func(t *testing.T) {
node2Info := Info{
Address: node2Addr,
NetworkID: 0,
Light: false,
}
peerFinderMock.SetFound(true)
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{
Address: node2Info.Address.Bytes(),
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
},
Ack: &pb.Ack{Address: info.Address.Bytes()},
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handshake(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != ErrHandshakeDuplicate {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
} }
func TestHandle(t *testing.T) { func TestHandle(t *testing.T) {
...@@ -175,7 +212,8 @@ func TestHandle(t *testing.T) { ...@@ -175,7 +212,8 @@ func TestHandle(t *testing.T) {
} }
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger) peerFinderMock := &mock.PeerFinderMock{}
handshakeService := New(peerFinderMock, nodeInfo.Address, nodeInfo.NetworkID, logger)
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
node2Info := Info{ node2Info := Info{
...@@ -325,6 +363,38 @@ func TestHandle(t *testing.T) { ...@@ -325,6 +363,38 @@ func TestHandle(t *testing.T) {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err) t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
} }
}) })
t.Run("ERROR - duplicate handshake msg", func(t *testing.T) {
node2Info := Info{
Address: node2Addr,
NetworkID: 0,
Light: false,
}
peerFinderMock.SetFound(true)
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.Syn{
Address: node2Info.Address.Bytes(),
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handle(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != ErrHandshakeDuplicate {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
} }
// testInfo validates if two Info instances are equal. // testInfo validates if two Info instances are equal.
......
package mock
import "github.com/ethersphere/bee/pkg/swarm"
type PeerFinderMock struct {
found bool
}
func (p *PeerFinderMock) SetFound(found bool) {
p.found = found
}
func (p *PeerFinderMock) Exists(overlay swarm.Address) (found bool) {
return p.found
}
...@@ -23,13 +23,11 @@ func NewStream(readBuffer, writeBuffer *bytes.Buffer) *StreamMock { ...@@ -23,13 +23,11 @@ func NewStream(readBuffer, writeBuffer *bytes.Buffer) *StreamMock {
func (s *StreamMock) SetReadErr(err error, checkmark int) { func (s *StreamMock) SetReadErr(err error, checkmark int) {
s.readError = err s.readError = err
s.readErrCheckmark = checkmark s.readErrCheckmark = checkmark
} }
func (s *StreamMock) SetWriteErr(err error, checkmark int) { func (s *StreamMock) SetWriteErr(err error, checkmark int) {
s.writeError = err s.writeError = err
s.writeErrCheckmark = checkmark s.writeErrCheckmark = checkmark
} }
func (s *StreamMock) Read(p []byte) (n int, err error) { func (s *StreamMock) Read(p []byte) (n int, err error) {
......
...@@ -182,12 +182,13 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -182,12 +182,13 @@ func New(ctx context.Context, o Options) (*Service, error) {
return nil, fmt.Errorf("autonat: %w", err) return nil, fmt.Errorf("autonat: %w", err)
} }
peerRegistry := newPeerRegistry()
s := &Service{ s := &Service{
host: h, host: h,
metrics: newMetrics(), metrics: newMetrics(),
networkID: o.NetworkID, networkID: o.NetworkID,
handshakeService: handshake.New(o.Overlay, o.NetworkID, o.Logger), handshakeService: handshake.New(peerRegistry, o.Overlay, o.NetworkID, o.Logger),
peers: newPeerRegistry(), peers: peerRegistry,
logger: o.Logger, logger: o.Logger,
} }
...@@ -207,6 +208,10 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -207,6 +208,10 @@ func New(ctx context.Context, o Options) (*Service, error) {
s.logger.Warningf("peer %s has a different network id.", peerID) s.logger.Warningf("peer %s has a different network id.", peerID)
} }
if err == handshake.ErrHandshakeDuplicate {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
}
s.logger.Debugf("handshake: handle %s: %w", peerID, err) s.logger.Debugf("handshake: handle %s: %w", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID) s.logger.Errorf("unable to handshake with peer %v", peerID)
// todo: test connection close and refactor // todo: test connection close and refactor
...@@ -214,13 +219,6 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -214,13 +219,6 @@ func New(ctx context.Context, o Options) (*Service, error) {
return return
} }
if peerID, found := s.peers.peerID(i.Address); found {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
// disconnect if handshake was performed for already existing peer
_ = s.Disconnect(i.Address)
return
}
s.peers.add(peerID, i.Address) s.peers.add(peerID, i.Address)
s.metrics.HandledStreamCount.Inc() s.metrics.HandledStreamCount.Inc()
s.logger.Infof("peer %s connected", i.Address) s.logger.Infof("peer %s connected", i.Address)
...@@ -319,13 +317,6 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -319,13 +317,6 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, err return swarm.Address{}, err
} }
if peerID, found := s.peers.peerID(i.Address); found {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
// disconnect if handshake was performed for already existing peer
_ = s.Disconnect(i.Address)
return
}
s.peers.add(info.ID, i.Address) s.peers.add(info.ID, i.Address)
s.metrics.CreatedConnectionCount.Inc() s.metrics.CreatedConnectionCount.Inc()
s.logger.Infof("peer %s connected", i.Address) s.logger.Infof("peer %s connected", i.Address)
......
...@@ -24,6 +24,11 @@ func newPeerRegistry() *peerRegistry { ...@@ -24,6 +24,11 @@ func newPeerRegistry() *peerRegistry {
} }
} }
func (r *peerRegistry) Exists(overlay swarm.Address) (found bool) {
_, found = r.peerID(overlay)
return found
}
func (r *peerRegistry) add(peerID libp2ppeer.ID, overlay swarm.Address) { func (r *peerRegistry) add(peerID libp2ppeer.ID, overlay swarm.Address) {
r.mu.Lock() r.mu.Lock()
r.peers[encodePeersKey(overlay)] = peerID r.peers[encodePeersKey(overlay)] = peerID
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment