Commit 6fa682f5 authored by Petar Radovic's avatar Petar Radovic

disconnect on diff network id and repeated handshake

parent d0f578de
......@@ -5,6 +5,7 @@
package handshake
import (
"errors"
"fmt"
"github.com/ethersphere/bee/pkg/logging"
......@@ -20,6 +21,10 @@ const (
StreamVersion = "1.0.0"
)
// ErrNetworkIDIncompatible should be returned by handshake handlers if
// response from the other peer does not have valid networkID.
var ErrNetworkIDIncompatible = errors.New("incompatible networkID")
type Service struct {
overlay swarm.Address
networkID int32
......@@ -48,6 +53,10 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("read message: %w", err)
}
if resp.Syn.NetworkID != s.networkID {
return nil, ErrNetworkIDIncompatible
}
if err := w.WriteMsg(&pb.Ack{Address: resp.Syn.Address}); err != nil {
return nil, fmt.Errorf("ack: write message: %w", err)
}
......@@ -72,6 +81,10 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("read message: %w", err)
}
if req.NetworkID != s.networkID {
return nil, ErrNetworkIDIncompatible
}
if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{
Address: s.overlay.Bytes(),
......
......@@ -31,7 +31,7 @@ func TestHandshake(t *testing.T) {
t.Run("OK", func(t *testing.T) {
expectedInfo := Info{
Address: node2Addr,
NetworkID: 1,
NetworkID: 0,
Light: false,
}
......@@ -58,7 +58,6 @@ func TestHandshake(t *testing.T) {
}
testInfo(t, *res, expectedInfo)
if err := r.ReadMsg(&pb.Ack{}); err != nil {
t.Fatal(err)
}
......@@ -99,7 +98,7 @@ func TestHandshake(t *testing.T) {
expectedErr := fmt.Errorf("ack: write message: %w", testErr)
expectedInfo := Info{
Address: node2Addr,
NetworkID: 1,
NetworkID: 0,
Light: false,
}
......@@ -130,6 +129,40 @@ func TestHandshake(t *testing.T) {
t.Fatal("handshake returned non-nil res")
}
})
t.Run("ERROR - networkID mismatch ", func(t *testing.T) {
node2Info := Info{
Address: node2Addr,
NetworkID: 2,
Light: false,
}
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{
Address: node2Info.Address.Bytes(),
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
},
Ack: &pb.Ack{Address: info.Address.Bytes()},
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handshake(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != ErrNetworkIDIncompatible {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
}
func TestHandle(t *testing.T) {
......@@ -147,7 +180,7 @@ func TestHandle(t *testing.T) {
t.Run("OK", func(t *testing.T) {
node2Info := Info{
Address: node2Addr,
NetworkID: 1,
NetworkID: 0,
Light: false,
}
......@@ -234,7 +267,7 @@ func TestHandle(t *testing.T) {
expectedErr := fmt.Errorf("ack: read message: %w", testErr)
node2Info := Info{
Address: node2Addr,
NetworkID: 1,
NetworkID: 0,
Light: false,
}
......@@ -261,6 +294,37 @@ func TestHandle(t *testing.T) {
t.Fatal("handshake returned non-nil res")
}
})
t.Run("ERROR - networkID mismatch ", func(t *testing.T) {
node2Info := Info{
Address: node2Addr,
NetworkID: 2,
Light: false,
}
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.Syn{
Address: node2Info.Address.Bytes(),
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handle(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != ErrNetworkIDIncompatible {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
}
// testInfo validates if two Info instances are equal.
......
......@@ -203,18 +203,24 @@ func New(ctx context.Context, o Options) (*Service, error) {
peerID := stream.Conn().RemotePeer()
i, err := s.handshakeService.Handle(stream)
if err != nil {
if err == handshake.ErrNetworkIDIncompatible {
s.logger.Warningf("peer %s has a different network id.", peerID)
}
s.logger.Debugf("handshake: handle %s: %w", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID)
// todo: test connection close and refactor
_ = stream.Conn().Close()
_ = s.host.Network().ClosePeer(peerID)
return
}
if i.NetworkID != s.networkID {
s.logger.Warningf("peer %s has a different network id %v", peerID, i.NetworkID)
// todo: test connection close and refactor
_ = stream.Conn().Close()
if peerID, found := s.peers.peerID(i.Address); found {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
// disconnect if handshake was performed for already existing peer
_ = s.Disconnect(i.Address)
return
}
s.peers.add(peerID, i.Address)
s.metrics.HandledStreamCount.Inc()
s.logger.Infof("peer %s connected", i.Address)
......@@ -254,7 +260,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
if !found {
// todo: this should never happen, should we disconnect in this case?
// todo: test connection close and refactor
_ = stream.Conn().Close()
_ = s.host.Network().ClosePeer(peerID)
s.logger.Errorf("overlay address for peer %q not found", peerID)
return
}
......@@ -264,8 +270,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
var e *p2p.DisconnectError
if errors.Is(err, e) {
// todo: test connection close and refactor
s.peers.remove(peerID)
_ = stream.Conn().Close()
_ = s.Disconnect(overlay)
}
s.logger.Debugf("handle protocol %s: stream %s/%s: peer %s: %w", p.Name, ss.Name, ss.Version, overlay, err)
......@@ -303,16 +308,26 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
stream, err := s.newStreamForPeerID(ctx, info.ID, handshake.ProtocolName, handshake.StreamName, handshake.StreamVersion)
if err != nil {
_ = s.host.Network().ClosePeer(info.ID)
return swarm.Address{}, fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
i, err := s.handshakeService.Handshake(stream)
if err != nil {
if err == handshake.ErrNetworkIDIncompatible {
s.logger.Warningf("peer %s has a different network id.", info.ID)
}
_ = s.host.Network().ClosePeer(info.ID)
return swarm.Address{}, err
}
if i.NetworkID != s.networkID {
return swarm.Address{}, fmt.Errorf("invalid network id %v", i.NetworkID)
if peerID, found := s.peers.peerID(i.Address); found {
s.logger.Warningf("handshake happened for already connected peer %s", peerID)
// disconnect if handshake was performed for already existing peer
_ = s.Disconnect(i.Address)
return
}
s.peers.add(info.ID, i.Address)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment