Commit 0a1ad9a7 authored by acud's avatar acud Committed by GitHub

fix: add full node info to peer (#1741)

parent 18682dcf
......@@ -399,6 +399,7 @@ func TestTopologyNotifier(t *testing.T) {
mtx.Lock()
defer mtx.Unlock()
expectZeroAddress(t, n1connectedPeer.Address) // fail if set more than once
expectFullNode(t, p)
n1connectedPeer = p
return nil
}
......@@ -413,6 +414,7 @@ func TestTopologyNotifier(t *testing.T) {
defer mtx.Unlock()
expectZeroAddress(t, n2connectedPeer.Address) // fail if set more than once
n2connectedPeer = p
expectFullNode(t, p)
return nil
}
n2d = func(p p2p.Peer) {
......@@ -679,6 +681,12 @@ func expectStreamReset(t *testing.T, s io.ReadCloser, err error) {
}
}
}
func expectFullNode(t *testing.T, p p2p.Peer) {
t.Helper()
if !p.FullNode {
t.Fatal("expected peer to be a full node")
}
}
func expectZeroAddress(t *testing.T, addrs ...swarm.Address) {
t.Helper()
......
......@@ -282,7 +282,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
}
if s.notifier != nil {
if !s.notifier.Pick(p2p.Peer{Address: overlay}) {
if !s.notifier.Pick(p2p.Peer{Address: overlay, FullNode: i.FullNode}) {
s.logger.Warningf("stream handler: don't want incoming peer %s. disconnecting", overlay)
_ = handshakeStream.Reset()
_ = s.host.Network().ClosePeer(peerID)
......@@ -290,7 +290,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
}
}
if exists := s.peers.addIfNotExists(stream.Conn(), overlay); exists {
if exists := s.peers.addIfNotExists(stream.Conn(), overlay, i.FullNode); exists {
s.logger.Debugf("stream handler: peer %s already exists", overlay)
if err = handshakeStream.FullClose(); err != nil {
s.logger.Debugf("stream handler: could not close stream %s: %v", overlay, err)
......@@ -317,7 +317,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
}
}
peer := p2p.Peer{Address: overlay}
peer := p2p.Peer{Address: overlay, FullNode: i.FullNode}
s.protocolsmu.RLock()
for _, tn := range s.protocols {
......@@ -396,6 +396,12 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
s.logger.Debugf("overlay address for peer %q not found", peerID)
return
}
full, found := s.peers.fullnode(peerID)
if !found {
_ = streamlibp2p.Reset()
s.logger.Debugf("fullnode info for peer %q not found", peerID)
return
}
stream := newStream(streamlibp2p)
......@@ -423,7 +429,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
logger := tracing.NewLoggerWithTraceID(ctx, s.logger)
s.metrics.HandledStreamCount.Inc()
if err := ss.Handler(ctx, p2p.Peer{Address: overlay}, stream); err != nil {
if err := ss.Handler(ctx, p2p.Peer{Address: overlay, FullNode: full}, stream); err != nil {
var de *p2p.DisconnectError
if errors.As(err, &de) {
_ = stream.Reset()
......@@ -573,7 +579,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
return nil, fmt.Errorf("peer blocklisted")
}
if exists := s.peers.addIfNotExists(stream.Conn(), overlay); exists {
if exists := s.peers.addIfNotExists(stream.Conn(), overlay, i.FullNode); exists {
if err := handshakeStream.FullClose(); err != nil {
_ = s.Disconnect(overlay)
return nil, fmt.Errorf("peer exists, full close: %w", err)
......@@ -598,7 +604,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
s.protocolsmu.RLock()
for _, tn := range s.protocols {
if tn.ConnectOut != nil {
if err := tn.ConnectOut(ctx, p2p.Peer{Address: overlay}); err != nil {
if err := tn.ConnectOut(ctx, p2p.Peer{Address: overlay, FullNode: i.FullNode}); err != nil {
s.logger.Debugf("connectOut: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, overlay, err)
_ = s.Disconnect(overlay)
s.protocolsmu.RUnlock()
......@@ -625,11 +631,12 @@ func (s *Service) Disconnect(overlay swarm.Address) error {
s.logger.Debugf("libp2p disconnect: disconnecting peer %s", overlay)
found, peerID := s.peers.remove(overlay)
// found is checked at the bottom of the function
found, full, peerID := s.peers.remove(overlay)
_ = s.host.Network().ClosePeer(peerID)
peer := p2p.Peer{Address: overlay}
peer := p2p.Peer{Address: overlay, FullNode: full}
s.protocolsmu.RLock()
for _, tn := range s.protocols {
......@@ -660,6 +667,15 @@ func (s *Service) Disconnect(overlay swarm.Address) error {
func (s *Service) disconnected(address swarm.Address) {
peer := p2p.Peer{Address: address}
peerID, found := s.peers.peerID(address)
if !found {
s.logger.Debugf("libp2p disconnected: cannot find peerID for overlay: %s", address.String())
} else {
full, found := s.peers.fullnode(peerID)
if found {
peer.FullNode = full
}
}
s.protocolsmu.RLock()
for _, tn := range s.protocols {
if tn.DisconnectIn != nil {
......
......@@ -20,6 +20,7 @@ import (
type peerRegistry struct {
underlays map[string]libp2ppeer.ID // map overlay address to underlay peer id
overlays map[libp2ppeer.ID]swarm.Address // map underlay peer id to overlay address
full map[libp2ppeer.ID]bool // map to track whether a node is full or light node (true=full)
connections map[libp2ppeer.ID]map[network.Conn]struct{} // list of connections for safe removal on Disconnect notification
streams map[libp2ppeer.ID]map[network.Stream]context.CancelFunc
mu sync.RWMutex
......@@ -37,6 +38,7 @@ func newPeerRegistry() *peerRegistry {
return &peerRegistry{
underlays: make(map[string]libp2ppeer.ID),
overlays: make(map[libp2ppeer.ID]swarm.Address),
full: make(map[libp2ppeer.ID]bool),
connections: make(map[libp2ppeer.ID]map[network.Conn]struct{}),
streams: make(map[libp2ppeer.ID]map[network.Stream]context.CancelFunc),
......@@ -78,6 +80,7 @@ func (r *peerRegistry) Disconnected(_ network.Network, c network.Conn) {
cancel()
}
delete(r.streams, peerID)
delete(r.full, peerID)
r.mu.Unlock()
r.disconnecter.disconnected(overlay)
......@@ -115,9 +118,10 @@ func (r *peerRegistry) removeStream(peerID libp2ppeer.ID, stream network.Stream)
func (r *peerRegistry) peers() []p2p.Peer {
r.mu.RLock()
peers := make([]p2p.Peer, 0, len(r.overlays))
for _, a := range r.overlays {
for p, a := range r.overlays {
peers = append(peers, p2p.Peer{
Address: a,
FullNode: r.full[p],
})
}
r.mu.RUnlock()
......@@ -127,7 +131,7 @@ func (r *peerRegistry) peers() []p2p.Peer {
return peers
}
func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address) (exists bool) {
func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address, full bool) (exists bool) {
peerID := c.RemotePeer()
r.mu.Lock()
defer r.mu.Unlock()
......@@ -146,6 +150,7 @@ func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address) (ex
r.streams[peerID] = make(map[network.Stream]context.CancelFunc)
r.underlays[overlay.ByteString()] = peerID
r.overlays[peerID] = overlay
r.full[peerID] = full
return false
}
......@@ -164,6 +169,13 @@ func (r *peerRegistry) overlay(peerID libp2ppeer.ID) (swarm.Address, bool) {
return overlay, found
}
func (r *peerRegistry) fullnode(peerID libp2ppeer.ID) (bool, bool) {
r.mu.RLock()
full, found := r.full[peerID]
r.mu.RUnlock()
return full, found
}
func (r *peerRegistry) isConnected(peerID libp2ppeer.ID, remoteAddr ma.Multiaddr) (swarm.Address, bool) {
if remoteAddr == nil {
return swarm.ZeroAddress, false
......@@ -193,9 +205,9 @@ func (r *peerRegistry) isConnected(peerID libp2ppeer.ID, remoteAddr ma.Multiaddr
return swarm.ZeroAddress, false
}
func (r *peerRegistry) remove(overlay swarm.Address) (bool, libp2ppeer.ID) {
func (r *peerRegistry) remove(overlay swarm.Address) (found, full bool, peerID libp2ppeer.ID) {
r.mu.Lock()
peerID, found := r.underlays[overlay.ByteString()]
peerID, found = r.underlays[overlay.ByteString()]
delete(r.overlays, peerID)
delete(r.underlays, overlay.ByteString())
delete(r.connections, peerID)
......@@ -203,9 +215,11 @@ func (r *peerRegistry) remove(overlay swarm.Address) (bool, libp2ppeer.ID) {
cancel()
}
delete(r.streams, peerID)
full = r.full[peerID]
delete(r.full, peerID)
r.mu.Unlock()
return found, peerID
return found, full, peerID
}
func (r *peerRegistry) setDisconnecter(d disconnecter) {
......
......@@ -27,7 +27,81 @@ func TestNewStream(t *testing.T) {
s2, _ := newService(t, 1, libp2pServiceOpts{})
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
return nil
})); err != nil {
t.Fatal(err)
}
addr := serviceUnderlayAddress(t, s1)
if _, err := s2.Connect(ctx, addr); err != nil {
t.Fatal(err)
}
stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
if err := stream.Close(); err != nil {
t.Fatal(err)
}
}
// TestNewStream_OnlyFull tests that the handler gets the full
// node information communicated correctly.
func TestNewStream_OnlyFull(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
FullNode: true,
}})
s2, _ := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
FullNode: true,
}})
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
if !p.FullNode {
t.Error("expected full node")
}
return nil
})); err != nil {
t.Fatal(err)
}
addr := serviceUnderlayAddress(t, s1)
if _, err := s2.Connect(ctx, addr); err != nil {
t.Fatal(err)
}
stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
if err := stream.Close(); err != nil {
t.Fatal(err)
}
}
// TestNewStream_Mixed tests that the handler gets the full
// node information communicated correctly for light node
func TestNewStream_Mixed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
FullNode: true,
}})
s2, _ := newService(t, 1, libp2pServiceOpts{})
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
if p.FullNode {
t.Error("expected light node")
}
return nil
})); err != nil {
t.Fatal(err)
......@@ -61,12 +135,12 @@ func TestNewStreamMulti(t *testing.T) {
var (
h1calls, h2calls int32
h1 = func(_ context.Context, _ p2p.Peer, s p2p.Stream) error {
h1 = func(_ context.Context, p p2p.Peer, s p2p.Stream) error {
defer s.Close()
_ = atomic.AddInt32(&h1calls, 1)
return nil
}
h2 = func(_ context.Context, _ p2p.Peer, s p2p.Stream) error {
h2 = func(_ context.Context, p p2p.Peer, s p2p.Stream) error {
defer s.Close()
_ = atomic.AddInt32(&h2calls, 1)
return nil
......
......@@ -95,6 +95,7 @@ type StreamSpec struct {
// Peer holds information about a Peer.
type Peer struct {
Address swarm.Address `json:"address"`
FullNode bool `json:"fullNode"`
}
// HandlerFunc handles a received Stream from a Peer.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment