Commit a8afbded authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

Connect from both sides at the same time (#65)

* double connect

* handshake track received handshakes per peerID
parent 4635724d
...@@ -23,7 +23,7 @@ const ( ...@@ -23,7 +23,7 @@ const (
protocolName = "hive" protocolName = "hive"
protocolVersion = "1.0.0" protocolVersion = "1.0.0"
peersStreamName = "peers" peersStreamName = "peers"
messageTimeout = 5 * time.Second // maximum allowed time for a message to be read or written. messageTimeout = 1 * time.Minute // maximum allowed time for a message to be read or written.
maxBatchSize = 50 maxBatchSize = 50
) )
......
...@@ -13,6 +13,9 @@ import ( ...@@ -13,6 +13,9 @@ import (
// peer is not found. // peer is not found.
var ErrPeerNotFound = errors.New("peer not found") var ErrPeerNotFound = errors.New("peer not found")
// ErrAlreadyConnected is returned if connect was called for already connected node
var ErrAlreadyConnected = errors.New("already connected")
// DisconnectError is an error that is specifically handled inside p2p. If returned by specific protocol // DisconnectError is an error that is specifically handled inside p2p. If returned by specific protocol
// handler it causes peer disconnect. // handler it causes peer disconnect.
type DisconnectError struct { type DisconnectError struct {
......
...@@ -11,6 +11,8 @@ import ( ...@@ -11,6 +11,8 @@ import (
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p" "github.com/ethersphere/bee/pkg/p2p/libp2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake"
libp2ppeer "github.com/libp2p/go-libp2p-core/peer"
) )
func TestAddresses(t *testing.T) { func TestAddresses(t *testing.T) {
...@@ -73,12 +75,12 @@ func TestDoubleConnect(t *testing.T) { ...@@ -73,12 +75,12 @@ func TestDoubleConnect(t *testing.T) {
expectPeers(t, s2, overlay1) expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2) expectPeersEventually(t, s1, overlay2)
if _, err := s2.Connect(ctx, addr); err == nil { if _, err := s2.Connect(ctx, addr); !errors.Is(err, p2p.ErrAlreadyConnected) {
t.Fatal("second connect attempt should result with an error") t.Fatalf("expected %s error, got %s error", p2p.ErrAlreadyConnected, err)
} }
expectPeers(t, s2) expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1) expectPeers(t, s1, overlay2)
} }
func TestDoubleDisconnect(t *testing.T) { func TestDoubleDisconnect(t *testing.T) {
...@@ -116,45 +118,6 @@ func TestDoubleDisconnect(t *testing.T) { ...@@ -116,45 +118,6 @@ func TestDoubleDisconnect(t *testing.T) {
expectPeersEventually(t, s1) expectPeersEventually(t, s1)
} }
func TestReconnectAfterDoubleConnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1)
overlay, err := s2.Connect(ctx, addr)
if err != nil {
t.Fatal(err)
}
expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2)
if _, err := s2.Connect(ctx, addr); err == nil {
t.Fatal("second connect attempt should result with an error")
}
expectPeers(t, s2)
expectPeersEventually(t, s1)
overlay, err = s2.Connect(ctx, addr)
if err != nil {
t.Fatal(err)
}
if !overlay.Equal(overlay1) {
t.Errorf("got overlay %s, want %s", overlay, overlay1)
}
expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2)
}
func TestMultipleConnectDisconnect(t *testing.T) { func TestMultipleConnectDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
...@@ -252,8 +215,15 @@ func TestDoubleConnectOnAllAddresses(t *testing.T) { ...@@ -252,8 +215,15 @@ func TestDoubleConnectOnAllAddresses(t *testing.T) {
expectPeers(t, s2, overlay1) expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2) expectPeersEventually(t, s1, overlay2)
if _, err := s2.Connect(ctx, addr); err == nil { if _, err := s2.Connect(ctx, addr); !errors.Is(err, p2p.ErrAlreadyConnected) {
t.Fatal("second connect attempt should result with an error") t.Fatalf("expected %s error, got %s error", p2p.ErrAlreadyConnected, err)
}
expectPeers(t, s2, overlay1)
expectPeers(t, s1, overlay2)
if err := s2.Disconnect(overlay1); err != nil {
t.Fatal(err)
} }
expectPeers(t, s2) expectPeers(t, s2)
...@@ -308,3 +278,42 @@ func TestConnectWithDisabledQUICAndWSTransports(t *testing.T) { ...@@ -308,3 +278,42 @@ func TestConnectWithDisabledQUICAndWSTransports(t *testing.T) {
expectPeers(t, s2, overlay1) expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2) expectPeersEventually(t, s1, overlay2)
} }
// TestConnectRepeatHandshake tests if handshake was attempted more then once by the same peer
func TestConnectRepeatHandshake(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s1, overlay1, cleanup1 := newService(t, libp2p.Options{NetworkID: 1})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, libp2p.Options{NetworkID: 1})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1)
_, err := s2.Connect(ctx, addr)
if err != nil {
t.Fatal(err)
}
expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2)
info, err := libp2ppeer.AddrInfoFromP2pAddr(addr)
if err != nil {
t.Fatal(err)
}
stream, err := s2.NewStreamForPeerID(info.ID, handshake.ProtocolName, handshake.ProtocolVersion, handshake.StreamName)
if err != nil {
t.Fatal(err)
}
if _, err := s2.HandshakeService().Handshake(libp2p.NewStream(stream)); err == nil {
t.Fatalf("expected stream error")
}
expectPeersEventually(t, s2)
expectPeersEventually(t, s1)
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package libp2p
import (
"context"
handshake "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake"
"github.com/libp2p/go-libp2p-core/network"
libp2ppeer "github.com/libp2p/go-libp2p-core/peer"
)
func (s *Service) HandshakeService() *handshake.Service {
return s.handshakeService
}
func (s *Service) NewStreamForPeerID(peerID libp2ppeer.ID, protocolName, protocolVersion, streamName string) (network.Stream, error) {
return s.newStreamForPeerID(context.Background(), peerID, protocolName, protocolVersion, streamName)
}
...@@ -7,6 +7,7 @@ package handshake ...@@ -7,6 +7,7 @@ package handshake
import ( import (
"errors" "errors"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
...@@ -14,6 +15,9 @@ import ( ...@@ -14,6 +15,9 @@ import (
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/libp2p/go-libp2p-core/network"
libp2ppeer "github.com/libp2p/go-libp2p-core/peer"
) )
const ( const (
...@@ -37,24 +41,27 @@ type PeerFinder interface { ...@@ -37,24 +41,27 @@ type PeerFinder interface {
} }
type Service struct { type Service struct {
peerFinder PeerFinder overlay swarm.Address
overlay swarm.Address networkID int32
networkID int32 receivedHandshakes map[libp2ppeer.ID]struct{}
logger logging.Logger receivedHandshakesMu sync.Mutex
logger logging.Logger
network.Notifiee // handhsake service can be the receiver for network.Notify
} }
func New(peerFinder PeerFinder, overlay swarm.Address, networkID int32, logger logging.Logger) *Service { func New(overlay swarm.Address, networkID int32, logger logging.Logger) *Service {
return &Service{ return &Service{
peerFinder: peerFinder, overlay: overlay,
overlay: overlay, networkID: networkID,
networkID: networkID, receivedHandshakes: make(map[libp2ppeer.ID]struct{}),
logger: logger, logger: logger,
Notifiee: new(network.NoopNotifiee),
} }
} }
func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{
Address: s.overlay.Bytes(), Address: s.overlay.Bytes(),
NetworkID: s.networkID, NetworkID: s.networkID,
...@@ -68,10 +75,6 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { ...@@ -68,10 +75,6 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
} }
address := swarm.NewAddress(resp.Syn.Address) address := swarm.NewAddress(resp.Syn.Address)
if s.peerFinder.Exists(address) {
return nil, ErrHandshakeDuplicate
}
if resp.Syn.NetworkID != s.networkID { if resp.Syn.NetworkID != s.networkID {
return nil, ErrNetworkIDIncompatible return nil, ErrNetworkIDIncompatible
} }
...@@ -91,9 +94,17 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { ...@@ -91,9 +94,17 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
}, nil }, nil
} }
func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) { func (s *Service) Handle(stream p2p.Stream, peerID libp2ppeer.ID) (i *Info, err error) {
w, r := protobuf.NewWriterAndReader(stream)
defer stream.Close() defer stream.Close()
s.receivedHandshakesMu.Lock()
if _, exists := s.receivedHandshakes[peerID]; exists {
s.receivedHandshakesMu.Unlock()
return nil, ErrHandshakeDuplicate
}
s.receivedHandshakes[peerID] = struct{}{}
s.receivedHandshakesMu.Unlock()
w, r := protobuf.NewWriterAndReader(stream)
var req pb.Syn var req pb.Syn
if err := r.ReadMsgWithTimeout(messageTimeout, &req); err != nil { if err := r.ReadMsgWithTimeout(messageTimeout, &req); err != nil {
...@@ -101,10 +112,6 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) { ...@@ -101,10 +112,6 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
} }
address := swarm.NewAddress(req.Address) address := swarm.NewAddress(req.Address)
if s.peerFinder.Exists(address) {
return nil, ErrHandshakeDuplicate
}
if req.NetworkID != s.networkID { if req.NetworkID != s.networkID {
return nil, ErrNetworkIDIncompatible return nil, ErrNetworkIDIncompatible
} }
...@@ -132,6 +139,12 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) { ...@@ -132,6 +139,12 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
}, nil }, nil
} }
func (s *Service) Disconnected(_ network.Network, c network.Conn) {
s.receivedHandshakesMu.Lock()
defer s.receivedHandshakesMu.Unlock()
delete(s.receivedHandshakes, c.RemotePeer())
}
type Info struct { type Info struct {
Address swarm.Address Address swarm.Address
NetworkID int32 NetworkID int32
......
...@@ -16,6 +16,9 @@ import ( ...@@ -16,6 +16,9 @@ import (
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake/pb"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/libp2p/go-libp2p-core/peer"
ma "github.com/multiformats/go-multiaddr"
) )
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
...@@ -28,8 +31,7 @@ func TestHandshake(t *testing.T) { ...@@ -28,8 +31,7 @@ func TestHandshake(t *testing.T) {
Light: false, Light: false,
} }
peerFinderMock := &mock.PeerFinder{} handshakeService := New(info.Address, info.NetworkID, logger)
handshakeService := New(peerFinderMock, info.Address, info.NetworkID, logger)
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
expectedInfo := Info{ expectedInfo := Info{
...@@ -66,7 +68,7 @@ func TestHandshake(t *testing.T) { ...@@ -66,7 +68,7 @@ func TestHandshake(t *testing.T) {
} }
}) })
t.Run("ERROR - Syn write error ", func(t *testing.T) { t.Run("ERROR - Syn write error", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write syn message: %w", testErr) expectedErr := fmt.Errorf("write syn message: %w", testErr)
stream := &mock.Stream{} stream := &mock.Stream{}
...@@ -81,7 +83,7 @@ func TestHandshake(t *testing.T) { ...@@ -81,7 +83,7 @@ func TestHandshake(t *testing.T) {
} }
}) })
t.Run("ERROR - Syn read error ", func(t *testing.T) { t.Run("ERROR - Syn read error", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("read synack message: %w", testErr) expectedErr := fmt.Errorf("read synack message: %w", testErr)
stream := mock.NewStream(nil, &bytes.Buffer{}) stream := mock.NewStream(nil, &bytes.Buffer{})
...@@ -96,7 +98,7 @@ func TestHandshake(t *testing.T) { ...@@ -96,7 +98,7 @@ func TestHandshake(t *testing.T) {
} }
}) })
t.Run("ERROR - ack write error ", func(t *testing.T) { t.Run("ERROR - ack write error", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write ack message: %w", testErr) expectedErr := fmt.Errorf("write ack message: %w", testErr)
expectedInfo := Info{ expectedInfo := Info{
...@@ -133,7 +135,7 @@ func TestHandshake(t *testing.T) { ...@@ -133,7 +135,7 @@ func TestHandshake(t *testing.T) {
} }
}) })
t.Run("ERROR - networkID mismatch ", func(t *testing.T) { t.Run("ERROR - networkID mismatch", func(t *testing.T) {
node2Info := Info{ node2Info := Info{
Address: node2Addr, Address: node2Addr,
NetworkID: 2, NetworkID: 2,
...@@ -166,46 +168,20 @@ func TestHandshake(t *testing.T) { ...@@ -166,46 +168,20 @@ func TestHandshake(t *testing.T) {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err) t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
} }
}) })
t.Run("ERROR - duplicate handshake ", func(t *testing.T) {
node2Info := Info{
Address: node2Addr,
NetworkID: 0,
Light: false,
}
peerFinderMock.SetFound(true)
var buffer1 bytes.Buffer
var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2)
stream2 := mock.NewStream(&buffer2, &buffer1)
w, _ := protobuf.NewWriterAndReader(stream2)
if err := w.WriteMsg(&pb.SynAck{
Syn: &pb.Syn{
Address: node2Info.Address.Bytes(),
NetworkID: node2Info.NetworkID,
Light: node2Info.Light,
},
Ack: &pb.Ack{Address: info.Address.Bytes()},
}); err != nil {
t.Fatal(err)
}
res, err := handshakeService.Handshake(stream1)
if res != nil {
t.Fatal("res should be nil")
}
if err != ErrHandshakeDuplicate {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err)
}
})
} }
func TestHandle(t *testing.T) { func TestHandle(t *testing.T) {
node1Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c") node1Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c")
node2Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59b") node2Addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59b")
multiaddress, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/7070/p2p/16Uiu2HAkx8ULY8cTXhdVAcMmLcH9AsTKz6uBQ7DPLKRjMLgBVYkS")
if err != nil {
t.Fatal(err)
}
info, err := peer.AddrInfoFromP2pAddr(multiaddress)
if err != nil {
t.Fatal(err)
}
nodeInfo := Info{ nodeInfo := Info{
Address: node1Addr, Address: node1Addr,
NetworkID: 0, NetworkID: 0,
...@@ -213,10 +189,8 @@ func TestHandle(t *testing.T) { ...@@ -213,10 +189,8 @@ func TestHandle(t *testing.T) {
} }
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
peerFinderMock := &mock.PeerFinder{}
handshakeService := New(peerFinderMock, nodeInfo.Address, nodeInfo.NetworkID, logger)
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger)
node2Info := Info{ node2Info := Info{
Address: node2Addr, Address: node2Addr,
NetworkID: 0, NetworkID: 0,
...@@ -241,7 +215,7 @@ func TestHandle(t *testing.T) { ...@@ -241,7 +215,7 @@ func TestHandle(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1) res, err := handshakeService.Handle(stream1, info.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -262,11 +236,12 @@ func TestHandle(t *testing.T) { ...@@ -262,11 +236,12 @@ func TestHandle(t *testing.T) {
}) })
t.Run("ERROR - read error ", func(t *testing.T) { t.Run("ERROR - read error ", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger)
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("read syn message: %w", testErr) expectedErr := fmt.Errorf("read syn message: %w", testErr)
stream := &mock.Stream{} stream := &mock.Stream{}
stream.SetReadErr(testErr, 0) stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handle(stream) res, err := handshakeService.Handle(stream, info.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -277,6 +252,7 @@ func TestHandle(t *testing.T) { ...@@ -277,6 +252,7 @@ func TestHandle(t *testing.T) {
}) })
t.Run("ERROR - write error ", func(t *testing.T) { t.Run("ERROR - write error ", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger)
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write synack message: %w", testErr) expectedErr := fmt.Errorf("write synack message: %w", testErr)
var buffer bytes.Buffer var buffer bytes.Buffer
...@@ -291,7 +267,7 @@ func TestHandle(t *testing.T) { ...@@ -291,7 +267,7 @@ func TestHandle(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream) res, err := handshakeService.Handle(stream, info.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -302,6 +278,7 @@ func TestHandle(t *testing.T) { ...@@ -302,6 +278,7 @@ func TestHandle(t *testing.T) {
}) })
t.Run("ERROR - ack read error ", func(t *testing.T) { t.Run("ERROR - ack read error ", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger)
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("read ack message: %w", testErr) expectedErr := fmt.Errorf("read ack message: %w", testErr)
node2Info := Info{ node2Info := Info{
...@@ -324,7 +301,7 @@ func TestHandle(t *testing.T) { ...@@ -324,7 +301,7 @@ func TestHandle(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1) res, err := handshakeService.Handle(stream1, info.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -335,6 +312,7 @@ func TestHandle(t *testing.T) { ...@@ -335,6 +312,7 @@ func TestHandle(t *testing.T) {
}) })
t.Run("ERROR - networkID mismatch ", func(t *testing.T) { t.Run("ERROR - networkID mismatch ", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger)
node2Info := Info{ node2Info := Info{
Address: node2Addr, Address: node2Addr,
NetworkID: 2, NetworkID: 2,
...@@ -355,7 +333,7 @@ func TestHandle(t *testing.T) { ...@@ -355,7 +333,7 @@ func TestHandle(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1) res, err := handshakeService.Handle(stream1, info.ID)
if res != nil { if res != nil {
t.Fatal("res should be nil") t.Fatal("res should be nil")
} }
...@@ -365,14 +343,14 @@ func TestHandle(t *testing.T) { ...@@ -365,14 +343,14 @@ func TestHandle(t *testing.T) {
} }
}) })
t.Run("ERROR - duplicate handshake msg", func(t *testing.T) { t.Run("ERROR - duplicate handshake", func(t *testing.T) {
handshakeService := New(nodeInfo.Address, nodeInfo.NetworkID, logger)
node2Info := Info{ node2Info := Info{
Address: node2Addr, Address: node2Addr,
NetworkID: 0, NetworkID: 0,
Light: false, Light: false,
} }
peerFinderMock.SetFound(true)
var buffer1 bytes.Buffer var buffer1 bytes.Buffer
var buffer2 bytes.Buffer var buffer2 bytes.Buffer
stream1 := mock.NewStream(&buffer1, &buffer2) stream1 := mock.NewStream(&buffer1, &buffer2)
...@@ -387,13 +365,32 @@ func TestHandle(t *testing.T) { ...@@ -387,13 +365,32 @@ func TestHandle(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1) if err := w.WriteMsg(&pb.Ack{Address: node2Info.Address.Bytes()}); err != nil {
if res != nil { t.Fatal(err)
t.Fatal("res should be nil") }
res, err := handshakeService.Handle(stream1, info.ID)
if err != nil {
t.Fatal(err)
} }
testInfo(t, *res, node2Info)
_, r := protobuf.NewWriterAndReader(stream2)
var got pb.SynAck
if err := r.ReadMsg(&got); err != nil {
t.Fatal(err)
}
testInfo(t, nodeInfo, Info{
Address: swarm.NewAddress(got.Syn.Address),
NetworkID: got.Syn.NetworkID,
Light: got.Syn.Light,
})
_, err = handshakeService.Handle(stream1, info.ID)
if err != ErrHandshakeDuplicate { if err != ErrHandshakeDuplicate {
t.Fatalf("expected %s, got %s", ErrNetworkIDIncompatible, err) t.Fatalf("expected %s err, got %s err", ErrHandshakeDuplicate, err)
} }
}) })
} }
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import "github.com/ethersphere/bee/pkg/swarm"
// todo: implement peer registry mocks, export appropriate interface and move those in libp2p so it can be used in handshake
type PeerFinder struct {
found bool
}
func (p *PeerFinder) SetFound(found bool) {
p.found = found
}
func (p *PeerFinder) Exists(overlay swarm.Address) (found bool) {
return p.found
}
...@@ -158,7 +158,7 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -158,7 +158,7 @@ func New(ctx context.Context, o Options) (*Service, error) {
libp2pPeerstore: libp2pPeerstore, libp2pPeerstore: libp2pPeerstore,
metrics: newMetrics(), metrics: newMetrics(),
networkID: o.NetworkID, networkID: o.NetworkID,
handshakeService: handshake.New(peerRegistry, o.Overlay, o.NetworkID, o.Logger), handshakeService: handshake.New(o.Overlay, o.NetworkID, o.Logger),
peers: peerRegistry, peers: peerRegistry,
addrssbook: o.Addressbook, addrssbook: o.Addressbook,
logger: o.Logger, logger: o.Logger,
...@@ -173,9 +173,10 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -173,9 +173,10 @@ func New(ctx context.Context, o Options) (*Service, error) {
return nil, fmt.Errorf("protocol version match %s: %w", id, err) return nil, fmt.Errorf("protocol version match %s: %w", id, err)
} }
// handshake
s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) { s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) {
peerID := stream.Conn().RemotePeer() peerID := stream.Conn().RemotePeer()
i, err := s.handshakeService.Handle(newStream(stream)) i, err := s.handshakeService.Handle(NewStream(stream), peerID)
if err != nil { if err != nil {
if err == handshake.ErrNetworkIDIncompatible { if err == handshake.ErrNetworkIDIncompatible {
s.logger.Warningf("peer %s has a different network id.", peerID) s.logger.Warningf("peer %s has a different network id.", peerID)
...@@ -187,12 +188,14 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -187,12 +188,14 @@ func New(ctx context.Context, o Options) (*Service, error) {
s.logger.Debugf("handshake: handle %s: %v", peerID, err) s.logger.Debugf("handshake: handle %s: %v", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID) s.logger.Errorf("unable to handshake with peer %v", peerID)
// todo: test connection close and refactor
_ = s.disconnect(peerID) _ = s.disconnect(peerID)
return return
} }
s.peers.add(stream.Conn(), i.Address) if exists := s.peers.addIfNotExists(stream.Conn(), i.Address); exists {
return
}
remoteMultiaddr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", stream.Conn().RemoteMultiaddr().String(), peerID.Pretty())) remoteMultiaddr, err := ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", stream.Conn().RemoteMultiaddr().String(), peerID.Pretty()))
if err != nil { if err != nil {
s.logger.Debugf("multiaddr error: handle %s: %v", peerID, err) s.logger.Debugf("multiaddr error: handle %s: %v", peerID, err)
...@@ -216,7 +219,8 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -216,7 +219,8 @@ func New(ctx context.Context, o Options) (*Service, error) {
s.metrics.HandledConnectionCount.Inc() s.metrics.HandledConnectionCount.Inc()
}) })
h.Network().Notify(peerRegistry) // update peer registry on network events h.Network().Notify(peerRegistry) // update peer registry on network events
h.Network().Notify(s.handshakeService) // update handshake service on network events
return s, nil return s, nil
} }
...@@ -297,6 +301,10 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -297,6 +301,10 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, err return swarm.Address{}, err
} }
if _, found := s.peers.overlay(info.ID); found {
return swarm.Address{}, p2p.ErrAlreadyConnected
}
if err := s.host.Connect(ctx, *info); err != nil { if err := s.host.Connect(ctx, *info); err != nil {
return swarm.Address{}, err return swarm.Address{}, err
} }
...@@ -307,7 +315,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -307,7 +315,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, err return swarm.Address{}, err
} }
i, err := s.handshakeService.Handshake(newStream(stream)) i, err := s.handshakeService.Handshake(NewStream(stream))
if err != nil { if err != nil {
_ = s.disconnect(info.ID) _ = s.disconnect(info.ID)
return swarm.Address{}, fmt.Errorf("handshake: %w", err) return swarm.Address{}, fmt.Errorf("handshake: %w", err)
...@@ -317,7 +325,10 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -317,7 +325,10 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, err return swarm.Address{}, err
} }
s.peers.add(stream.Conn(), i.Address) if exists := s.peers.addIfNotExists(stream.Conn(), i.Address); exists {
return i.Address, nil
}
s.metrics.CreatedConnectionCount.Inc() s.metrics.CreatedConnectionCount.Inc()
s.logger.Infof("peer %s connected", i.Address) s.logger.Infof("peer %s connected", i.Address)
return i.Address, nil return i.Address, nil
...@@ -328,6 +339,7 @@ func (s *Service) Disconnect(overlay swarm.Address) error { ...@@ -328,6 +339,7 @@ func (s *Service) Disconnect(overlay swarm.Address) error {
if !found { if !found {
return p2p.ErrPeerNotFound return p2p.ErrPeerNotFound
} }
return s.disconnect(peerID) return s.disconnect(peerID)
} }
...@@ -335,6 +347,7 @@ func (s *Service) disconnect(peerID libp2ppeer.ID) error { ...@@ -335,6 +347,7 @@ func (s *Service) disconnect(peerID libp2ppeer.ID) error {
if err := s.host.Network().ClosePeer(peerID); err != nil { if err := s.host.Network().ClosePeer(peerID); err != nil {
return err return err
} }
s.peers.remove(peerID) s.peers.remove(peerID)
return nil return nil
} }
......
...@@ -63,31 +63,37 @@ func (r *peerRegistry) Disconnected(_ network.Network, c network.Conn) { ...@@ -63,31 +63,37 @@ func (r *peerRegistry) Disconnected(_ network.Network, c network.Conn) {
} }
func (r *peerRegistry) peers() []p2p.Peer { func (r *peerRegistry) peers() []p2p.Peer {
r.mu.Lock() r.mu.RLock()
peers := make([]p2p.Peer, 0, len(r.overlays)) peers := make([]p2p.Peer, 0, len(r.overlays))
for _, a := range r.overlays { for _, a := range r.overlays {
peers = append(peers, p2p.Peer{ peers = append(peers, p2p.Peer{
Address: a, Address: a,
}) })
} }
r.mu.Unlock() r.mu.RUnlock()
sort.Slice(peers, func(i, j int) bool { sort.Slice(peers, func(i, j int) bool {
return bytes.Compare(peers[i].Address.Bytes(), peers[j].Address.Bytes()) == -1 return bytes.Compare(peers[i].Address.Bytes(), peers[j].Address.Bytes()) == -1
}) })
return peers return peers
} }
func (r *peerRegistry) add(c network.Conn, overlay swarm.Address) { func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address) (exists bool) {
peerID := c.RemotePeer() peerID := c.RemotePeer()
r.mu.Lock() r.mu.Lock()
r.underlays[overlay.ByteString()] = peerID defer r.mu.Unlock()
r.overlays[peerID] = overlay
if _, ok := r.connections[peerID]; !ok { if _, ok := r.connections[peerID]; !ok {
r.connections[peerID] = make(map[network.Conn]struct{}) r.connections[peerID] = make(map[network.Conn]struct{})
} }
r.connections[peerID][c] = struct{}{} r.connections[peerID][c] = struct{}{}
r.mu.Unlock()
if _, exists := r.underlays[overlay.ByteString()]; !exists {
r.underlays[overlay.ByteString()] = peerID
r.overlays[peerID] = overlay
return false
}
return true
} }
func (r *peerRegistry) peerID(overlay swarm.Address) (peerID libp2ppeer.ID, found bool) { func (r *peerRegistry) peerID(overlay swarm.Address) (peerID libp2ppeer.ID, found bool) {
......
...@@ -10,15 +10,20 @@ import ( ...@@ -10,15 +10,20 @@ import (
"github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/network"
) )
var _ p2p.Stream = (*stream)(nil)
type stream struct { type stream struct {
network.Stream network.Stream
headers map[string][]byte headers map[string][]byte
} }
func newStream(s network.Stream) *stream { func NewStream(s network.Stream) p2p.Stream {
return &stream{Stream: s} return &stream{Stream: s}
} }
func newStream(s network.Stream) *stream {
return &stream{Stream: s}
}
func (s *stream) Headers() p2p.Headers { func (s *stream) Headers() p2p.Headers {
return s.headers return s.headers
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment