Commit 598df843 authored by acud's avatar acud Committed by GitHub

feat: limit number of light nodes (#1898)

parent b3aeab29
...@@ -18,6 +18,8 @@ import ( ...@@ -18,6 +18,8 @@ import (
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake" "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake"
"github.com/ethersphere/bee/pkg/statestore/mock" "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/swarm/test"
"github.com/ethersphere/bee/pkg/topology/lightnode"
"github.com/libp2p/go-libp2p-core/mux" "github.com/libp2p/go-libp2p-core/mux"
libp2ppeer "github.com/libp2p/go-libp2p-core/peer" libp2ppeer "github.com/libp2p/go-libp2p-core/peer"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
...@@ -82,6 +84,46 @@ func TestConnectToLightPeer(t *testing.T) { ...@@ -82,6 +84,46 @@ func TestConnectToLightPeer(t *testing.T) {
expectPeersEventually(t, s1) expectPeersEventually(t, s1)
} }
func TestLightPeerLimit(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var (
limit = 3
container = lightnode.NewContainer(test.RandomAddress())
sf, _ = newService(t, 1, libp2pServiceOpts{lightNodes: container,
libp2pOpts: libp2p.Options{
LightNodeLimit: limit,
FullNode: true,
}})
notifier = mockNotifier(noopCf, noopDf, true)
)
sf.SetPickyNotifier(notifier)
addr := serviceUnderlayAddress(t, sf)
for i := 0; i < 5; i++ {
sl, _ := newService(t, 1, libp2pServiceOpts{
libp2pOpts: libp2p.Options{
FullNode: false,
}})
_, err := sl.Connect(ctx, addr)
if err != nil {
t.Fatal(err)
}
}
for i := 0; i < 20; i++ {
if cnt := container.Count(); cnt == limit {
return
}
time.Sleep(50 * time.Millisecond)
}
t.Fatal("timed out waiting for correct number of lightnodes")
}
func TestDoubleConnect(t *testing.T) { func TestDoubleConnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
...@@ -754,3 +796,9 @@ func mockNotifier(c cFunc, d dFunc, pick bool) p2p.PickyNotifier { ...@@ -754,3 +796,9 @@ func mockNotifier(c cFunc, d dFunc, pick bool) p2p.PickyNotifier {
type cFunc func(context.Context, p2p.Peer) error type cFunc func(context.Context, p2p.Peer) error
type dFunc func(p2p.Peer) type dFunc func(p2p.Peer)
var noopCf = func(_ context.Context, _ p2p.Peer) error {
return nil
}
var noopDf = func(p p2p.Peer) {}
...@@ -48,6 +48,8 @@ var ( ...@@ -48,6 +48,8 @@ var (
_ p2p.DebugService = (*Service)(nil) _ p2p.DebugService = (*Service)(nil)
) )
const defaultLightNodeLimit = 100
type Service struct { type Service struct {
ctx context.Context ctx context.Context
host host.Host host host.Host
...@@ -68,12 +70,15 @@ type Service struct { ...@@ -68,12 +70,15 @@ type Service struct {
tracer *tracing.Tracer tracer *tracing.Tracer
ready chan struct{} ready chan struct{}
lightNodes lightnodes lightNodes lightnodes
lightNodeLimit int
protocolsmu sync.RWMutex protocolsmu sync.RWMutex
} }
type lightnodes interface { type lightnodes interface {
Connected(context.Context, p2p.Peer) Connected(context.Context, p2p.Peer)
Disconnected(p2p.Peer) Disconnected(p2p.Peer)
Count() int
RandomPeer(swarm.Address) (swarm.Address, error)
} }
type Options struct { type Options struct {
...@@ -83,6 +88,7 @@ type Options struct { ...@@ -83,6 +88,7 @@ type Options struct {
EnableQUIC bool EnableQUIC bool
Standalone bool Standalone bool
FullNode bool FullNode bool
LightNodeLimit int
WelcomeMessage string WelcomeMessage string
Transaction []byte Transaction []byte
} }
...@@ -238,6 +244,11 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -238,6 +244,11 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
peerRegistry.setDisconnecter(s) peerRegistry.setDisconnecter(s)
s.lightNodeLimit = defaultLightNodeLimit
if o.LightNodeLimit > 0 {
s.lightNodeLimit = o.LightNodeLimit
}
// Construct protocols. // Construct protocols.
id := protocol.ID(p2p.NewSwarmStreamName(handshake.ProtocolName, handshake.ProtocolVersion, handshake.StreamName)) id := protocol.ID(p2p.NewSwarmStreamName(handshake.ProtocolName, handshake.ProtocolVersion, handshake.StreamName))
matcher, err := s.protocolSemverMatcher(id) matcher, err := s.protocolSemverMatcher(id)
...@@ -245,8 +256,18 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -245,8 +256,18 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
return nil, fmt.Errorf("protocol version match %s: %w", id, err) return nil, fmt.Errorf("protocol version match %s: %w", id, err)
} }
// handshake s.host.SetStreamHandlerMatch(id, matcher, s.handleIncoming)
s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) {
h.Network().SetConnHandler(func(_ network.Conn) {
s.metrics.HandledConnectionCount.Inc()
})
h.Network().Notify(peerRegistry) // update peer registry on network events
h.Network().Notify(s.handshakeService) // update handshake service on network events
return s, nil
}
func (s *Service) handleIncoming(stream network.Stream) {
select { select {
case <-s.ready: case <-s.ready:
case <-s.ctx.Done(): case <-s.ctx.Done():
...@@ -254,7 +275,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -254,7 +275,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
} }
peerID := stream.Conn().RemotePeer() peerID := stream.Conn().RemotePeer()
handshakeStream := NewStream(stream) handshakeStream := NewStream(stream)
i, err := s.handshakeService.Handle(ctx, handshakeStream, stream.Conn().RemoteMultiaddr(), peerID) i, err := s.handshakeService.Handle(s.ctx, handshakeStream, stream.Conn().RemoteMultiaddr(), peerID)
if err != nil { if err != nil {
s.logger.Debugf("stream handler: handshake: handle %s: %v", peerID, err) s.logger.Debugf("stream handler: handshake: handle %s: %v", peerID, err)
s.logger.Errorf("stream handler: handshake: unable to handshake with peer id %v", peerID) s.logger.Errorf("stream handler: handshake: unable to handshake with peer id %v", peerID)
...@@ -322,7 +343,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -322,7 +343,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
s.protocolsmu.RLock() s.protocolsmu.RLock()
for _, tn := range s.protocols { for _, tn := range s.protocols {
if tn.ConnectIn != nil { if tn.ConnectIn != nil {
if err := tn.ConnectIn(ctx, peer); err != nil { if err := tn.ConnectIn(s.ctx, peer); err != nil {
s.logger.Debugf("stream handler: connectIn: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, overlay, err) s.logger.Debugf("stream handler: connectIn: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, overlay, err)
_ = s.Disconnect(overlay) _ = s.Disconnect(overlay)
s.protocolsmu.RUnlock() s.protocolsmu.RUnlock()
...@@ -334,12 +355,27 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -334,12 +355,27 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
if s.notifier != nil { if s.notifier != nil {
if !i.FullNode { if !i.FullNode {
s.lightNodes.Connected(ctx, peer) s.lightNodes.Connected(s.ctx, peer)
//light node announces explicitly //light node announces explicitly
if err := s.notifier.Announce(ctx, peer.Address, i.FullNode); err != nil { if err := s.notifier.Announce(s.ctx, peer.Address, i.FullNode); err != nil {
s.logger.Debugf("stream handler: notifier.Announce: %s: %v", peer.Address.String(), err) s.logger.Debugf("stream handler: notifier.Announce: %s: %v", peer.Address.String(), err)
} }
} else if err := s.notifier.Connected(ctx, peer); err != nil { // full node announces implicitly
if s.lightNodes.Count() > s.lightNodeLimit {
// kick another node to fit this one in
p, err := s.lightNodes.RandomPeer(peer.Address)
if err != nil {
s.logger.Debugf("stream handler: cant find a peer slot for light node: %v", err)
_ = s.Disconnect(peer.Address)
return
} else {
s.logger.Tracef("stream handler: kicking away light node %s to make room for %s", p.String(), peer.Address.String())
_ = s.Disconnect(p)
return
}
}
} else if err := s.notifier.Connected(s.ctx, peer); err != nil {
// full node announces implicitly
s.logger.Debugf("stream handler: notifier.Connected: peer disconnected: %s: %v", i.BzzAddress.Overlay, err) s.logger.Debugf("stream handler: notifier.Connected: peer disconnected: %s: %v", i.BzzAddress.Overlay, err)
// note: this cannot be unit tested since the node // note: this cannot be unit tested since the node
// waiting on handshakeStream.FullClose() on the other side // waiting on handshakeStream.FullClose() on the other side
...@@ -364,15 +400,6 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -364,15 +400,6 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
s.logger.Debugf("stream handler: successfully connected to peer %s%s (inbound)", i.BzzAddress.ShortString(), i.LightString()) s.logger.Debugf("stream handler: successfully connected to peer %s%s (inbound)", i.BzzAddress.ShortString(), i.LightString())
s.logger.Infof("stream handler: successfully connected to peer %s%s (inbound)", i.BzzAddress.Overlay, i.LightString()) s.logger.Infof("stream handler: successfully connected to peer %s%s (inbound)", i.BzzAddress.Overlay, i.LightString())
})
h.Network().SetConnHandler(func(_ network.Conn) {
s.metrics.HandledConnectionCount.Inc()
})
h.Network().Notify(peerRegistry) // update peer registry on network events
h.Network().Notify(s.handshakeService) // update handshake service on network events
return s, nil
} }
func (s *Service) SetPickyNotifier(n p2p.PickyNotifier) { func (s *Service) SetPickyNotifier(n p2p.PickyNotifier) {
......
...@@ -31,6 +31,7 @@ type libp2pServiceOpts struct { ...@@ -31,6 +31,7 @@ type libp2pServiceOpts struct {
PrivateKey *ecdsa.PrivateKey PrivateKey *ecdsa.PrivateKey
MockPeerKey *ecdsa.PrivateKey MockPeerKey *ecdsa.PrivateKey
libp2pOpts libp2p.Options libp2pOpts libp2p.Options
lightNodes *lightnode.Container
} }
// newService constructs a new libp2p service. // newService constructs a new libp2p service.
...@@ -69,14 +70,15 @@ func newService(t *testing.T, networkID uint64, o libp2pServiceOpts) (s *libp2p. ...@@ -69,14 +70,15 @@ func newService(t *testing.T, networkID uint64, o libp2pServiceOpts) (s *libp2p.
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
lightnodes := lightnode.NewContainer(overlay) if o.lightNodes == nil {
o.lightNodes = lightnode.NewContainer(overlay)
}
opts := o.libp2pOpts opts := o.libp2pOpts
opts.Transaction = []byte(hexutil.EncodeUint64(o.PrivateKey.Y.Uint64())) opts.Transaction = []byte(hexutil.EncodeUint64(o.PrivateKey.Y.Uint64()))
senderMatcher := &MockSenderMatcher{} senderMatcher := &MockSenderMatcher{}
s, err = libp2p.New(ctx, crypto.NewDefaultSigner(swarmKey), networkID, overlay, addr, o.Addressbook, statestore, lightnodes, senderMatcher, o.Logger, nil, opts) s, err = libp2p.New(ctx, crypto.NewDefaultSigner(swarmKey), networkID, overlay, addr, o.Addressbook, statestore, o.lightNodes, senderMatcher, o.Logger, nil, opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -6,6 +6,8 @@ package lightnode ...@@ -6,6 +6,8 @@ package lightnode
import ( import (
"context" "context"
"crypto/rand"
"math/big"
"sync" "sync"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
...@@ -49,6 +51,43 @@ func (c *Container) Disconnected(peer p2p.Peer) { ...@@ -49,6 +51,43 @@ func (c *Container) Disconnected(peer p2p.Peer) {
} }
} }
func (c *Container) Count() int {
return c.connectedPeers.Length()
}
func (c *Container) RandomPeer(not swarm.Address) (swarm.Address, error) {
c.peerMu.Lock()
defer c.peerMu.Unlock()
var (
cnt = big.NewInt(int64(c.Count()))
addr = swarm.ZeroAddress
count = int64(0)
)
PICKPEER:
i, e := rand.Int(rand.Reader, cnt)
if e != nil {
return swarm.ZeroAddress, e
}
i64 := i.Int64()
count = 0
_ = c.connectedPeers.EachBinRev(func(peer swarm.Address, _ uint8) (bool, bool, error) {
if count == i64 {
addr = peer
return true, false, nil
}
count++
return false, false, nil
})
if addr.Equal(not) {
goto PICKPEER
}
return addr, nil
}
func (c *Container) PeerInfo() topology.BinInfo { func (c *Container) PeerInfo() topology.BinInfo {
return topology.BinInfo{ return topology.BinInfo{
BinPopulation: uint(c.connectedPeers.Length()), BinPopulation: uint(c.connectedPeers.Length()),
......
...@@ -33,14 +33,35 @@ func TestContainer(t *testing.T) { ...@@ -33,14 +33,35 @@ func TestContainer(t *testing.T) {
t.Run("can add peers to container", func(t *testing.T) { t.Run("can add peers to container", func(t *testing.T) {
c := lightnode.NewContainer(base) c := lightnode.NewContainer(base)
c.Connected(context.Background(), p2p.Peer{Address: swarm.NewAddress([]byte("123"))}) p1 := swarm.NewAddress([]byte("123"))
c.Connected(context.Background(), p2p.Peer{Address: swarm.NewAddress([]byte("456"))}) p2 := swarm.NewAddress([]byte("456"))
c.Connected(context.Background(), p2p.Peer{Address: p1})
c.Connected(context.Background(), p2p.Peer{Address: p2})
peerCount := len(c.PeerInfo().ConnectedPeers) peerCount := len(c.PeerInfo().ConnectedPeers)
if peerCount != 2 { if peerCount != 2 {
t.Errorf("expected %d connected peer, got %d", 2, peerCount) t.Errorf("expected %d connected peer, got %d", 2, peerCount)
} }
if cc := c.Count(); cc != 2 {
t.Errorf("expected count 2 got %d", cc)
}
p, err := c.RandomPeer(p1)
if err != nil {
t.Fatal(err)
}
if !p.Equal(p2) {
t.Fatalf("expected p1 but got %s", p.String())
}
p, err = c.RandomPeer(p2)
if err != nil {
t.Fatal(err)
}
if !p.Equal(p1) {
t.Fatalf("expected p2 but got %s", p.String())
}
}) })
t.Run("empty container after peer disconnect", func(t *testing.T) { t.Run("empty container after peer disconnect", func(t *testing.T) {
c := lightnode.NewContainer(base) c := lightnode.NewContainer(base)
...@@ -59,5 +80,8 @@ func TestContainer(t *testing.T) { ...@@ -59,5 +80,8 @@ func TestContainer(t *testing.T) {
if connPeerCount != 0 { if connPeerCount != 0 {
t.Errorf("expected %d connected peer, got %d", 0, connPeerCount) t.Errorf("expected %d connected peer, got %d", 0, connPeerCount)
} }
if cc := c.Count(); cc != 0 {
t.Errorf("expected count 0 got %d", cc)
}
}) })
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment