Commit dea34546 authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

Granular notifiers in protocol spec (#722)

* granular notifiers in protocol spec

* remove debugapi toplogy notifier

* fix tests and return toploguy notifier but a bit different

* fix tests

* please linter

* remove unnecessary pointer

* protocolsmu

* protocolsmu

* tests
parent 006de07d
......@@ -35,14 +35,6 @@ func (s *server) peerConnectHandler(w http.ResponseWriter, r *http.Request) {
return
}
if err := s.TopologyDriver.Connected(r.Context(), bzzAddr.Overlay); err != nil {
_ = s.P2P.Disconnect(bzzAddr.Overlay)
s.Logger.Debugf("debug api: peer connect handler %s: %v", addr, err)
s.Logger.Errorf("unable to connect to peer %s", addr)
jsonhttp.InternalServerError(w, err)
return
}
jsonhttp.OK(w, peerConnectResponse{
Address: bzzAddr.Overlay.String(),
})
......
......@@ -440,8 +440,8 @@ func (k *Kad) AddPeers(ctx context.Context, addrs ...swarm.Address) error {
}
// Connected is called when a peer has dialed in.
func (k *Kad) Connected(ctx context.Context, addr swarm.Address) error {
if err := k.connected(ctx, addr); err != nil {
func (k *Kad) Connected(ctx context.Context, peer p2p.Peer) error {
if err := k.connected(ctx, peer.Address); err != nil {
return err
}
......@@ -476,12 +476,12 @@ func (k *Kad) connected(ctx context.Context, addr swarm.Address) error {
}
// Disconnected is called when peer disconnects.
func (k *Kad) Disconnected(addr swarm.Address) {
po := swarm.Proximity(k.base.Bytes(), addr.Bytes())
k.connectedPeers.Remove(addr, po)
func (k *Kad) Disconnected(peer p2p.Peer) {
po := swarm.Proximity(k.base.Bytes(), peer.Address.Bytes())
k.connectedPeers.Remove(peer.Address, po)
k.waitNextMu.Lock()
k.waitNext[addr.String()] = retryInfo{tryAfter: time.Now().Add(timeToRetry), failedAttempts: 0}
k.waitNext[peer.Address.String()] = retryInfo{tryAfter: time.Now().Add(timeToRetry), failedAttempts: 0}
k.waitNextMu.Unlock()
k.depthMu.Lock()
......
......@@ -288,7 +288,7 @@ func TestNotifierHooks(t *testing.T) {
}
// disconnect the peer, expect error
kad.Disconnected(peer)
kad.Disconnected(p2p.Peer{Address: peer})
_, err = kad.ClosestPeer(addr)
if !errors.Is(err, topology.ErrNotFound) {
t.Fatalf("expected topology.ErrNotFound but got %v", err)
......@@ -779,7 +779,7 @@ func p2pMock(ab addressbook.Interface, signer beeCrypto.Signer, counter, failedC
}
func removeOne(k *kademlia.Kad, peer swarm.Address) {
k.Disconnected(peer)
k.Disconnected(p2p.Peer{Address: peer})
}
const underlayBase = "/ip4/127.0.0.1/tcp/7070/dns/"
......@@ -798,7 +798,7 @@ func connectOne(t *testing.T, signer beeCrypto.Signer, k *kademlia.Kad, ab addre
if err := ab.Put(peer, *bzzAddr); err != nil {
t.Fatal(err)
}
_ = k.Connected(context.Background(), peer)
_ = k.Connected(context.Background(), p2p.Peer{Address: peer})
}
func addOne(t *testing.T, signer beeCrypto.Signer, k *kademlia.Kad, ab addressbook.Putter, peer swarm.Address) {
......
......@@ -261,7 +261,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
kad := kademlia.New(swarmAddress, addressbook, hive, p2ps, logger, kademlia.Options{Bootnodes: bootnodes, Standalone: o.Standalone})
b.topologyCloser = kad
hive.SetAddPeersHandler(kad.AddPeers)
p2ps.AddNotifier(kad)
p2ps.SetNotifier(kad)
addrs, err := p2ps.Addresses()
if err != nil {
return nil, fmt.Errorf("get server addresses: %w", err)
......
......@@ -17,7 +17,6 @@ import (
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake"
"github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
libp2ppeer "github.com/libp2p/go-libp2p-core/peer"
ma "github.com/multiformats/go-multiaddr"
)
......@@ -306,6 +305,48 @@ func TestConnectRepeatHandshake(t *testing.T) {
expectPeersEventually(t, s1)
}
func TestBlocklisting(t *testing.T) {
s1, overlay1 := newService(t, 1, libp2pServiceOpts{})
s2, overlay2 := newService(t, 1, libp2pServiceOpts{})
addr1 := serviceUnderlayAddress(t, s1)
addr2 := serviceUnderlayAddress(t, s2)
// s2 connects to s1, thus the notifier on s1 should be called on Connect
_, err := s2.Connect(context.Background(), addr1)
if err != nil {
t.Fatal(err)
}
expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2)
if err := s2.Blocklist(overlay1, 0); err != nil {
t.Fatal(err)
}
expectPeers(t, s2)
expectPeersEventually(t, s1)
// s2 connects to s1, thus the notifier on s1 should be called on Connect
_, err = s2.Connect(context.Background(), addr1)
if err == nil {
t.Fatal("expected error during connection, got nil")
}
expectPeers(t, s2)
expectPeersEventually(t, s1)
// s2 connects to s1, thus the notifier on s1 should be called on Connect
_, err = s1.Connect(context.Background(), addr2)
if err == nil {
t.Fatal("expected error during connection, got nil")
}
expectPeers(t, s1)
expectPeersEventually(t, s2)
}
func TestTopologyNotifier(t *testing.T) {
var (
mtx sync.Mutex
......@@ -313,44 +354,44 @@ func TestTopologyNotifier(t *testing.T) {
ab1, ab2 = addressbook.New(mock.NewStateStore()), addressbook.New(mock.NewStateStore())
n1connectedAddr swarm.Address
n1disconnectedAddr swarm.Address
n2connectedAddr swarm.Address
n2disconnectedAddr swarm.Address
n1connectedPeer p2p.Peer
n1disconnectedPeer p2p.Peer
n2connectedPeer p2p.Peer
n2disconnectedPeer p2p.Peer
n1c = func(_ context.Context, a swarm.Address) error {
n1c = func(_ context.Context, p p2p.Peer) error {
mtx.Lock()
defer mtx.Unlock()
expectZeroAddress(t, n1connectedAddr) // fail if set more than once
n1connectedAddr = a
expectZeroAddress(t, n1connectedPeer.Address) // fail if set more than once
n1connectedPeer = p
return nil
}
n1d = func(a swarm.Address) {
n1d = func(p p2p.Peer) {
mtx.Lock()
defer mtx.Unlock()
n1disconnectedAddr = a
n1disconnectedPeer = p
}
n2c = func(_ context.Context, a swarm.Address) error {
n2c = func(_ context.Context, p p2p.Peer) error {
mtx.Lock()
defer mtx.Unlock()
expectZeroAddress(t, n2connectedAddr) // fail if set more than once
n2connectedAddr = a
expectZeroAddress(t, n2connectedPeer.Address) // fail if set more than once
n2connectedPeer = p
return nil
}
n2d = func(a swarm.Address) {
n2d = func(p p2p.Peer) {
mtx.Lock()
defer mtx.Unlock()
n2disconnectedAddr = a
n2disconnectedPeer = p
}
)
notifier1 := mockNotifier(n1c, n1d)
s1, overlay1 := newService(t, 1, libp2pServiceOpts{Addressbook: ab1})
s1.AddNotifier(notifier1)
s1.SetNotifier(notifier1)
notifier2 := mockNotifier(n2c, n2d)
s2, overlay2 := newService(t, 1, libp2pServiceOpts{Addressbook: ab2})
s2.AddNotifier(notifier2)
s2.SetNotifier(notifier2)
addr := serviceUnderlayAddress(t, s1)
......@@ -364,10 +405,10 @@ func TestTopologyNotifier(t *testing.T) {
expectPeersEventually(t, s1, overlay2)
// expect that n1 notifee called with s2 overlay
waitAddrSet(t, &n1connectedAddr, &mtx, overlay2)
waitAddrSet(t, &n1connectedPeer.Address, &mtx, overlay2)
mtx.Lock()
expectZeroAddress(t, n1disconnectedAddr, n2connectedAddr, n2disconnectedAddr)
expectZeroAddress(t, n1disconnectedPeer.Address, n2connectedPeer.Address, n2disconnectedPeer.Address)
mtx.Unlock()
// check address book entries are there
......@@ -380,7 +421,7 @@ func TestTopologyNotifier(t *testing.T) {
expectPeers(t, s2)
expectPeersEventually(t, s1)
waitAddrSet(t, &n1disconnectedAddr, &mtx, overlay2)
waitAddrSet(t, &n1disconnectedPeer.Address, &mtx, overlay2)
// note that both n1disconnect and n2disconnect callbacks are called after just
// one disconnect. this is due to the fact the when the libp2p abstraction is explicitly
......@@ -388,7 +429,7 @@ func TestTopologyNotifier(t *testing.T) {
// peer disconnections can also result from components from outside the bound of the
// topology driver
mtx.Lock()
expectZeroAddress(t, n2connectedAddr)
expectZeroAddress(t, n2connectedPeer.Address)
mtx.Unlock()
addr2 := serviceUnderlayAddress(t, s2)
......@@ -400,7 +441,7 @@ func TestTopologyNotifier(t *testing.T) {
expectPeers(t, s1, overlay2)
expectPeersEventually(t, s2, overlay1)
waitAddrSet(t, &n2connectedAddr, &mtx, overlay1)
waitAddrSet(t, &n2connectedPeer.Address, &mtx, overlay1)
// s1 disconnects from s2 so s2 disconnect notifiee should be called
if err := s1.Disconnect(bzzAddr2.Overlay); err != nil {
......@@ -408,96 +449,16 @@ func TestTopologyNotifier(t *testing.T) {
}
expectPeers(t, s1)
expectPeersEventually(t, s2)
waitAddrSet(t, &n2disconnectedAddr, &mtx, overlay1)
waitAddrSet(t, &n2disconnectedPeer.Address, &mtx, overlay1)
}
func TestTopologySupportMultipleNotifiers(t *testing.T) {
var (
mtx sync.Mutex
n21connectedAddr swarm.Address
n22connectedAddr swarm.Address
n21c = func(_ context.Context, a swarm.Address) error {
mtx.Lock()
defer mtx.Unlock()
n21connectedAddr = a
return nil
}
n21d = func(a swarm.Address) {
}
n22c = func(_ context.Context, a swarm.Address) error {
mtx.Lock()
defer mtx.Unlock()
n22connectedAddr = a
return nil
}
n22d = func(a swarm.Address) {
func expectZeroAddress(t *testing.T, addrs ...swarm.Address) {
t.Helper()
for i, a := range addrs {
if !a.Equal(swarm.ZeroAddress) {
t.Fatalf("address did not equal zero address. index %d", i)
}
)
s1, overlay1 := newService(t, 1, libp2pServiceOpts{})
s1.AddNotifier(mockNotifier(n21c, n21d))
s1.AddNotifier(mockNotifier(n22c, n22d))
s2, overlay2 := newService(t, 1, libp2pServiceOpts{})
addr := serviceUnderlayAddress(t, s1)
// s2 connects to s1, thus the notifier on s1 should be called on Connect
_, err := s2.Connect(context.Background(), addr)
if err != nil {
t.Fatal(err)
}
expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2)
// expect that n1 notifee called with s2 overlay
waitAddrSet(t, &n21connectedAddr, &mtx, overlay2)
waitAddrSet(t, &n22connectedAddr, &mtx, overlay2)
}
func TestBlocklisting(t *testing.T) {
s1, overlay1 := newService(t, 1, libp2pServiceOpts{})
s2, overlay2 := newService(t, 1, libp2pServiceOpts{})
addr1 := serviceUnderlayAddress(t, s1)
addr2 := serviceUnderlayAddress(t, s2)
// s2 connects to s1, thus the notifier on s1 should be called on Connect
_, err := s2.Connect(context.Background(), addr1)
if err != nil {
t.Fatal(err)
}
expectPeers(t, s2, overlay1)
expectPeersEventually(t, s1, overlay2)
if err := s2.Blocklist(overlay1, 0); err != nil {
t.Fatal(err)
}
expectPeers(t, s2)
expectPeersEventually(t, s1)
// s2 connects to s1, thus the notifier on s1 should be called on Connect
_, err = s2.Connect(context.Background(), addr1)
if err == nil {
t.Fatal("expected error during connection, got nil")
}
expectPeers(t, s2)
expectPeersEventually(t, s1)
// s2 connects to s1, thus the notifier on s1 should be called on Connect
_, err = s1.Connect(context.Background(), addr2)
if err == nil {
t.Fatal("expected error during connection, got nil")
}
expectPeers(t, s1)
expectPeersEventually(t, s2)
}
func waitAddrSet(t *testing.T, addr *swarm.Address, mtx *sync.Mutex, exp swarm.Address) {
......@@ -530,26 +491,21 @@ func checkAddressbook(t *testing.T, ab addressbook.Getter, overlay swarm.Address
}
type notifiee struct {
connected func(context.Context, swarm.Address) error
disconnected func(swarm.Address)
connected func(context.Context, p2p.Peer) error
disconnected func(p2p.Peer)
}
func (n *notifiee) Connected(c context.Context, a swarm.Address) error {
return n.connected(c, a)
func (n *notifiee) Connected(c context.Context, p p2p.Peer) error {
return n.connected(c, p)
}
func (n *notifiee) Disconnected(a swarm.Address) {
n.disconnected(a)
func (n *notifiee) Disconnected(p p2p.Peer) {
n.disconnected(p)
}
func mockNotifier(c cFunc, d dFunc) topology.Notifier {
func mockNotifier(c cFunc, d dFunc) p2p.Notifier {
return &notifiee{connected: c, disconnected: d}
}
type cFunc func(context.Context, swarm.Address) error
type dFunc func(swarm.Address)
var noopNotifier = mockNotifier(
func(_ context.Context, _ swarm.Address) error { return nil },
func(_ swarm.Address) {},
)
type cFunc func(context.Context, p2p.Peer) error
type dFunc func(p2p.Peer)
......@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/ethersphere/bee/pkg/addressbook"
......@@ -22,7 +23,6 @@ import (
handshake "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/tracing"
"github.com/libp2p/go-libp2p"
autonat "github.com/libp2p/go-libp2p-autonat-svc"
......@@ -56,11 +56,14 @@ type Service struct {
handshakeService *handshake.Service
addressbook addressbook.Putter
peers *peerRegistry
topologyNotifiers []topology.Notifier
connectionBreaker breaker.Interface
blocklist *blocklist.Blocklist
protocols []p2p.ProtocolSpec
notifier p2p.Notifier
logger logging.Logger
tracer *tracing.Tracer
protocolsmu sync.RWMutex
}
type Options struct {
......@@ -208,6 +211,9 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
tracer: tracer,
connectionBreaker: breaker.NewBreaker(breaker.Options{}), // use default options
}
peerRegistry.setDisconnecter(s)
// Construct protocols.
id := protocol.ID(p2p.NewSwarmStreamName(handshake.ProtocolName, handshake.ProtocolVersion, handshake.StreamName))
matcher, err := s.protocolSemverMatcher(id)
......@@ -224,7 +230,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
s.logger.Debugf("handshake: handle %s: %v", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID)
_ = handshakeStream.Reset()
_ = s.disconnect(peerID)
_ = s.host.Network().ClosePeer(peerID)
return
}
......@@ -232,13 +238,13 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
if err != nil {
s.logger.Debugf("blocklisting: exists %s: %v", peerID, err)
s.logger.Errorf("internal error while connecting with peer %s", peerID)
_ = s.disconnect(peerID)
_ = s.host.Network().ClosePeer(peerID)
return
}
if blocked {
s.logger.Errorf("blocked connection from blocklisted peer %s", peerID)
_ = s.disconnect(peerID)
_ = s.host.Network().ClosePeer(peerID)
return
}
......@@ -246,7 +252,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
if err = handshakeStream.FullClose(); err != nil {
s.logger.Debugf("handshake: could not close stream %s: %v", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID)
_ = s.disconnect(peerID)
_ = s.Disconnect(i.BzzAddress.Overlay)
}
return
}
......@@ -254,7 +260,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
if err = handshakeStream.FullClose(); err != nil {
s.logger.Debugf("handshake: could not close stream %s: %v", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID)
_ = s.disconnect(peerID)
_ = s.Disconnect(i.BzzAddress.Overlay)
return
}
......@@ -262,18 +268,29 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
if err != nil {
s.logger.Debugf("handshake: addressbook put error %s: %v", peerID, err)
s.logger.Errorf("unable to persist peer %v", peerID)
_ = s.disconnect(peerID)
_ = s.Disconnect(i.BzzAddress.Overlay)
return
}
if len(s.topologyNotifiers) > 0 {
for _, tn := range s.topologyNotifiers {
if err := tn.Connected(ctx, i.BzzAddress.Overlay); err != nil {
s.logger.Debugf("topology notifier: %s: %v", peerID, err)
peer := p2p.Peer{Address: i.BzzAddress.Overlay}
s.protocolsmu.RLock()
for _, tn := range s.protocols {
if tn.ConnectIn != nil {
if err := tn.ConnectIn(ctx, peer); err != nil {
s.logger.Debugf("connectIn: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, i.BzzAddress.Overlay, err)
}
}
}
s.protocolsmu.RUnlock()
if s.notifier != nil {
if err := s.notifier.Connected(ctx, peer); err != nil {
s.logger.Debugf("notifier.Connected: peer: %s: %v", i.BzzAddress.Overlay, err)
}
}
s.metrics.HandledStreamCount.Inc()
s.logger.Infof("successfully connected to peer (inbound) %s", i.BzzAddress.ShortString())
})
......@@ -287,6 +304,10 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
return s, nil
}
func (s *Service) SetNotifier(n p2p.Notifier) {
s.notifier = n
}
func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
for _, ss := range p.StreamSpecs {
ss := ss
......@@ -300,7 +321,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
peerID := streamlibp2p.Conn().RemotePeer()
overlay, found := s.peers.overlay(peerID)
if !found {
_ = s.disconnect(peerID)
_ = s.Disconnect(overlay)
s.logger.Debugf("overlay address for peer %q not found", peerID)
return
}
......@@ -353,6 +374,10 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
}
})
}
s.protocolsmu.Lock()
s.protocols = append(s.protocols, p)
s.protocolsmu.Unlock()
return nil
}
......@@ -414,7 +439,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
stream, err := s.newStreamForPeerID(ctx, info.ID, handshake.ProtocolName, handshake.ProtocolVersion, handshake.StreamName)
if err != nil {
_ = s.disconnect(info.ID)
_ = s.host.Network().ClosePeer(info.ID)
return nil, fmt.Errorf("connect new stream: %w", err)
}
......@@ -422,7 +447,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
i, err := s.handshakeService.Handshake(handshakeStream, stream.Conn().RemoteMultiaddr(), stream.Conn().RemotePeer())
if err != nil {
_ = handshakeStream.Reset()
_ = s.disconnect(info.ID)
_ = s.host.Network().ClosePeer(info.ID)
return nil, fmt.Errorf("handshake: %w", err)
}
......@@ -430,19 +455,19 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
if err != nil {
s.logger.Debugf("blocklisting: exists %s: %v", info.ID, err)
s.logger.Errorf("internal error while connecting with peer %s", info.ID)
_ = s.disconnect(info.ID)
_ = s.host.Network().ClosePeer(info.ID)
return nil, fmt.Errorf("peer blocklisted")
}
if blocked {
s.logger.Errorf("blocked connection from blocklisted peer %s", info.ID)
_ = s.disconnect(info.ID)
_ = s.host.Network().ClosePeer(info.ID)
return nil, fmt.Errorf("peer blocklisted")
}
if exists := s.peers.addIfNotExists(stream.Conn(), i.BzzAddress.Overlay); exists {
if err := handshakeStream.FullClose(); err != nil {
_ = s.disconnect(info.ID)
_ = s.Disconnect(i.BzzAddress.Overlay)
return nil, fmt.Errorf("peer exists, full close: %w", err)
}
......@@ -450,45 +475,80 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
}
if err := handshakeStream.FullClose(); err != nil {
_ = s.disconnect(info.ID)
_ = s.Disconnect(i.BzzAddress.Overlay)
return nil, fmt.Errorf("connect full close %w", err)
}
err = s.addressbook.Put(i.BzzAddress.Overlay, *i.BzzAddress)
if err != nil {
_ = s.disconnect(info.ID)
_ = s.Disconnect(i.BzzAddress.Overlay)
return nil, fmt.Errorf("storing bzz address: %w", err)
}
s.protocolsmu.RLock()
for _, tn := range s.protocols {
if tn.ConnectOut != nil {
if err := tn.ConnectOut(ctx, p2p.Peer{Address: i.BzzAddress.Overlay}); err != nil {
s.logger.Debugf("connectOut: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, i.BzzAddress.Overlay, err)
}
}
}
s.protocolsmu.RUnlock()
s.metrics.CreatedConnectionCount.Inc()
s.logger.Infof("successfully connected to peer (outbound) %s", i.BzzAddress.ShortString())
return i.BzzAddress, nil
}
func (s *Service) Disconnect(overlay swarm.Address) error {
peerID, found := s.peers.peerID(overlay)
found, peerID := s.peers.remove(overlay)
if !found {
return p2p.ErrPeerNotFound
}
return s.disconnect(peerID)
}
_ = s.host.Network().ClosePeer(peerID)
func (s *Service) disconnect(peerID libp2ppeer.ID) error {
if err := s.host.Network().ClosePeer(peerID); err != nil {
return err
peer := p2p.Peer{Address: overlay}
s.protocolsmu.RLock()
for _, tn := range s.protocols {
if tn.DisconnectOut != nil {
if err := tn.DisconnectOut(peer); err != nil {
s.logger.Debugf("disconnectOut: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, overlay, err)
}
}
}
s.peers.remove(peerID)
s.protocolsmu.RUnlock()
if s.notifier != nil {
s.notifier.Disconnected(peer)
}
return nil
}
func (s *Service) Peers() []p2p.Peer {
return s.peers.peers()
// disconnected is a registered peer registry event
func (s *Service) disconnected(address swarm.Address) {
peer := p2p.Peer{Address: address}
s.protocolsmu.RLock()
for _, tn := range s.protocols {
if tn.DisconnectIn != nil {
if err := tn.DisconnectIn(peer); err != nil {
s.logger.Debugf("disconnectIn: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, address.String(), err)
}
}
}
s.protocolsmu.RUnlock()
if s.notifier != nil {
s.notifier.Disconnected(peer)
}
}
func (s *Service) AddNotifier(n topology.Notifier) {
s.topologyNotifiers = append(s.topologyNotifiers, n)
s.peers.addDisconnecter(n)
func (s *Service) Peers() []p2p.Peer {
return s.peers.peers()
}
func (s *Service) NewStream(ctx context.Context, overlay swarm.Address, headers p2p.Headers, protocolName, protocolVersion, streamName string) (p2p.Stream, error) {
......
......@@ -70,8 +70,6 @@ func newService(t *testing.T, networkID uint64, o libp2pServiceOpts) (s *libp2p.
t.Fatal(err)
}
s.AddNotifier(noopNotifier)
t.Cleanup(func() {
cancel()
s.Close()
......@@ -138,15 +136,6 @@ func expectPeersEventually(t *testing.T, s *libp2p.Service, addrs ...swarm.Addre
}
}
func expectZeroAddress(t *testing.T, addrs ...swarm.Address) {
t.Helper()
for i, a := range addrs {
if !a.Equal(swarm.ZeroAddress) {
t.Fatalf("address did not equal zero address. index %d", i)
}
}
}
func serviceUnderlayAddress(t *testing.T, s *libp2p.Service) multiaddr.Multiaddr {
t.Helper()
......
......@@ -12,7 +12,6 @@ import (
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
"github.com/libp2p/go-libp2p-core/network"
libp2ppeer "github.com/libp2p/go-libp2p-core/peer"
)
......@@ -25,8 +24,12 @@ type peerRegistry struct {
mu sync.RWMutex
//nolint:misspell
disconnecters []topology.Disconnecter // peerRegistry notifies topology on peer disconnection
network.Notifiee // peerRegistry can be the receiver for network.Notify
disconnecter disconnecter // peerRegistry notifies libp2p on peer disconnection
network.Notifiee // peerRegistry can be the receiver for network.Notify
}
type disconnecter interface {
disconnected(swarm.Address)
}
func newPeerRegistry() *peerRegistry {
......@@ -53,7 +56,7 @@ func (r *peerRegistry) Disconnected(_ network.Network, c network.Conn) {
r.mu.Lock()
// remove only the related connection,
// not eventually newly created one for the same peer
// not eventusally newly created one for the same peer
if _, ok := r.connections[peerID][c]; !ok {
r.mu.Unlock()
return
......@@ -74,14 +77,9 @@ func (r *peerRegistry) Disconnected(_ network.Network, c network.Conn) {
cancel()
}
delete(r.streams, peerID)
r.mu.Unlock()
r.disconnecter.disconnected(overlay)
if len(r.disconnecters) > 0 {
for _, d := range r.disconnecters {
d.Disconnected(overlay)
}
}
}
func (r *peerRegistry) addStream(peerID libp2ppeer.ID, stream network.Stream, cancel context.CancelFunc) {
......@@ -165,9 +163,9 @@ func (r *peerRegistry) overlay(peerID libp2ppeer.ID) (swarm.Address, bool) {
return overlay, found
}
func (r *peerRegistry) remove(peerID libp2ppeer.ID) {
func (r *peerRegistry) remove(overlay swarm.Address) (bool, libp2ppeer.ID) {
r.mu.Lock()
overlay, found := r.overlays[peerID]
peerID, found := r.underlays[overlay.ByteString()]
delete(r.overlays, peerID)
delete(r.underlays, overlay.ByteString())
delete(r.connections, peerID)
......@@ -177,14 +175,9 @@ func (r *peerRegistry) remove(peerID libp2ppeer.ID) {
delete(r.streams, peerID)
r.mu.Unlock()
// if overlay was not found disconnect handler should not be signaled.
if len(r.disconnecters) > 0 && found {
for _, d := range r.disconnecters {
d.Disconnected(overlay)
}
}
return found, peerID
}
func (r *peerRegistry) addDisconnecter(d topology.Disconnecter) {
r.disconnecters = append(r.disconnecters, d)
func (r *peerRegistry) setDisconnecter(d disconnecter) {
r.disconnecter = d
}
......@@ -7,8 +7,10 @@ package libp2p_test
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/multiformats/go-multistream"
......@@ -213,6 +215,81 @@ func TestDisconnectError(t *testing.T) {
expectPeersEventually(t, s1)
}
func TestConnectDisconnectEvents(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s1, overlay1 := newService(t, 1, libp2pServiceOpts{})
s2, _ := newService(t, 1, libp2pServiceOpts{})
testProtocol := newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
return nil
})
cinCount, coutCount, dinCount, doutCount := 0, 0, 0, 0
var countMU sync.Mutex
testProtocol.ConnectIn = func(c context.Context, p p2p.Peer) error {
countMU.Lock()
cinCount++
countMU.Unlock()
return nil
}
testProtocol.ConnectOut = func(c context.Context, p p2p.Peer) error {
countMU.Lock()
coutCount++
countMU.Unlock()
return nil
}
testProtocol.DisconnectIn = func(p p2p.Peer) error {
countMU.Lock()
dinCount++
countMU.Unlock()
return nil
}
testProtocol.DisconnectOut = func(p p2p.Peer) error {
countMU.Lock()
doutCount++
countMU.Unlock()
return nil
}
if err := s1.AddProtocol(testProtocol); err != nil {
t.Fatal(err)
}
if err := s2.AddProtocol(testProtocol); err != nil {
t.Fatal(err)
}
addr := serviceUnderlayAddress(t, s1)
if _, err := s2.Connect(ctx, addr); err != nil {
t.Fatal(err)
}
expectCounter(t, &cinCount, 1, &countMU)
expectCounter(t, &coutCount, 1, &countMU)
expectCounter(t, &dinCount, 0, &countMU)
expectCounter(t, &doutCount, 0, &countMU)
if err := s2.Disconnect(overlay1); err != nil {
t.Fatal(err)
}
cinCount = 0
coutCount = 0
expectCounter(t, &cinCount, 0, &countMU)
expectCounter(t, &coutCount, 0, &countMU)
expectCounter(t, &dinCount, 1, &countMU)
expectCounter(t, &doutCount, 1, &countMU)
}
const (
testProtocolName = "testing"
testProtocolVersion = "2.3.4"
......@@ -259,3 +336,18 @@ func expectErrNotSupported(t *testing.T, err error) {
t.Fatalf("got error %v, want %v", err, multistream.ErrNotSupported)
}
}
func expectCounter(t *testing.T, c *int, expected int, mtx *sync.Mutex) {
for i := 0; i < 20; i++ {
mtx.Lock()
if *c == expected {
mtx.Unlock()
return
}
mtx.Unlock()
time.Sleep(10 * time.Millisecond)
}
t.Fatal("timed out waiting for counter to be set")
}
......@@ -12,7 +12,6 @@ import (
"github.com/ethersphere/bee/pkg/bzz"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
ma "github.com/multiformats/go-multiaddr"
)
......@@ -22,8 +21,8 @@ type Service struct {
connectFunc func(ctx context.Context, addr ma.Multiaddr) (address *bzz.Address, err error)
disconnectFunc func(overlay swarm.Address) error
peersFunc func() []p2p.Peer
addNotifierFunc func(topology.Notifier)
addressesFunc func() ([]ma.Multiaddr, error)
setNotifierFunc func(p2p.Notifier)
setWelcomeMessageFunc func(string) error
getWelcomeMessageFunc func() string
blocklistFunc func(swarm.Address, time.Duration) error
......@@ -37,6 +36,13 @@ func WithAddProtocolFunc(f func(p2p.ProtocolSpec) error) Option {
})
}
// WithSetNotifierFunc sets the mock implementation of the SetNotifier function
func WithSetNotifierFunc(f func(p2p.Notifier)) Option {
return optionFunc(func(s *Service) {
s.setNotifierFunc = f
})
}
// WithConnectFunc sets the mock implementation of the Connect function
func WithConnectFunc(f func(ctx context.Context, addr ma.Multiaddr) (address *bzz.Address, err error)) Option {
return optionFunc(func(s *Service) {
......@@ -58,13 +64,6 @@ func WithPeersFunc(f func() []p2p.Peer) Option {
})
}
// WithAddNotifierFunc sets the mock implementation of the AddNotifier function
func WithAddNotifierFunc(f func(topology.Notifier)) Option {
return optionFunc(func(s *Service) {
s.addNotifierFunc = f
})
}
// WithAddressesFunc sets the mock implementation of the Adresses function
func WithAddressesFunc(f func() ([]ma.Multiaddr, error)) Option {
return optionFunc(func(s *Service) {
......@@ -122,14 +121,6 @@ func (s *Service) Disconnect(overlay swarm.Address) error {
return s.disconnectFunc(overlay)
}
func (s *Service) AddNotifier(f topology.Notifier) {
if s.addNotifierFunc == nil {
return
}
s.addNotifierFunc(f)
}
func (s *Service) Addresses() ([]ma.Multiaddr, error) {
if s.addressesFunc == nil {
return nil, errors.New("function Addresses not configured")
......@@ -166,6 +157,14 @@ func (s *Service) Blocklist(overlay swarm.Address, duration time.Duration) error
return s.blocklistFunc(overlay, duration)
}
func (s *Service) SetNotifier(f p2p.Notifier) {
if s.setNotifierFunc == nil {
return
}
s.setNotifierFunc(f)
}
type Option interface {
apply(*Service)
}
......
......@@ -11,7 +11,6 @@ import (
"github.com/ethersphere/bee/pkg/bzz"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
ma "github.com/multiformats/go-multiaddr"
)
......@@ -25,8 +24,13 @@ type Service interface {
// duration 0 is treated as an infinite duration
Blocklist(overlay swarm.Address, duration time.Duration) error
Peers() []Peer
AddNotifier(topology.Notifier)
Addresses() ([]ma.Multiaddr, error)
SetNotifier(Notifier)
}
type Notifier interface {
Connected(context.Context, Peer) error
Disconnected(Peer)
}
// DebugService extends the Service with method used for debugging.
......@@ -52,9 +56,13 @@ type Stream interface {
// ProtocolSpec defines a collection of Stream specifications with handlers.
type ProtocolSpec struct {
Name string
Version string
StreamSpecs []StreamSpec
Name string
Version string
StreamSpecs []StreamSpec
ConnectIn func(context.Context, Peer) error
ConnectOut func(context.Context, Peer) error
DisconnectIn func(Peer) error
DisconnectOut func(Peer) error
}
// StreamSpec defines a Stream handling within the protocol.
......
......@@ -21,36 +21,16 @@ type Driver interface {
PeerAdder
ClosestPeerer
EachPeerer
Notifier
NeighborhoodDepth() uint8
SubscribePeersChange() (c <-chan struct{}, unsubscribe func())
io.Closer
}
type Notifier interface {
Connecter
Disconnecter
}
type PeerAdder interface {
// AddPeers is called when peers are added to the topology backlog
AddPeers(ctx context.Context, addr ...swarm.Address) error
}
type Connecter interface {
// Connected is called when a peer dials in, or in case explicit
// notification to kademlia on dial out is requested.
Connected(context.Context, swarm.Address) error
}
type Disconnecter interface {
// Disconnected is called when a peer disconnects.
// The disconnect event can be initiated on the local
// node or on the remote node, this handle does not make
// any distinctions between either of them.
Disconnected(swarm.Address)
}
type ClosestPeerer interface {
ClosestPeer(addr swarm.Address) (peerAddr swarm.Address, err error)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment