Commit 0a1ad9a7 authored by acud's avatar acud Committed by GitHub

fix: add full node info to peer (#1741)

parent 18682dcf
...@@ -399,6 +399,7 @@ func TestTopologyNotifier(t *testing.T) { ...@@ -399,6 +399,7 @@ func TestTopologyNotifier(t *testing.T) {
mtx.Lock() mtx.Lock()
defer mtx.Unlock() defer mtx.Unlock()
expectZeroAddress(t, n1connectedPeer.Address) // fail if set more than once expectZeroAddress(t, n1connectedPeer.Address) // fail if set more than once
expectFullNode(t, p)
n1connectedPeer = p n1connectedPeer = p
return nil return nil
} }
...@@ -413,6 +414,7 @@ func TestTopologyNotifier(t *testing.T) { ...@@ -413,6 +414,7 @@ func TestTopologyNotifier(t *testing.T) {
defer mtx.Unlock() defer mtx.Unlock()
expectZeroAddress(t, n2connectedPeer.Address) // fail if set more than once expectZeroAddress(t, n2connectedPeer.Address) // fail if set more than once
n2connectedPeer = p n2connectedPeer = p
expectFullNode(t, p)
return nil return nil
} }
n2d = func(p p2p.Peer) { n2d = func(p p2p.Peer) {
...@@ -679,6 +681,12 @@ func expectStreamReset(t *testing.T, s io.ReadCloser, err error) { ...@@ -679,6 +681,12 @@ func expectStreamReset(t *testing.T, s io.ReadCloser, err error) {
} }
} }
} }
func expectFullNode(t *testing.T, p p2p.Peer) {
t.Helper()
if !p.FullNode {
t.Fatal("expected peer to be a full node")
}
}
func expectZeroAddress(t *testing.T, addrs ...swarm.Address) { func expectZeroAddress(t *testing.T, addrs ...swarm.Address) {
t.Helper() t.Helper()
......
...@@ -282,7 +282,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -282,7 +282,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
} }
if s.notifier != nil { if s.notifier != nil {
if !s.notifier.Pick(p2p.Peer{Address: overlay}) { if !s.notifier.Pick(p2p.Peer{Address: overlay, FullNode: i.FullNode}) {
s.logger.Warningf("stream handler: don't want incoming peer %s. disconnecting", overlay) s.logger.Warningf("stream handler: don't want incoming peer %s. disconnecting", overlay)
_ = handshakeStream.Reset() _ = handshakeStream.Reset()
_ = s.host.Network().ClosePeer(peerID) _ = s.host.Network().ClosePeer(peerID)
...@@ -290,7 +290,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -290,7 +290,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
} }
} }
if exists := s.peers.addIfNotExists(stream.Conn(), overlay); exists { if exists := s.peers.addIfNotExists(stream.Conn(), overlay, i.FullNode); exists {
s.logger.Debugf("stream handler: peer %s already exists", overlay) s.logger.Debugf("stream handler: peer %s already exists", overlay)
if err = handshakeStream.FullClose(); err != nil { if err = handshakeStream.FullClose(); err != nil {
s.logger.Debugf("stream handler: could not close stream %s: %v", overlay, err) s.logger.Debugf("stream handler: could not close stream %s: %v", overlay, err)
...@@ -317,7 +317,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -317,7 +317,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
} }
} }
peer := p2p.Peer{Address: overlay} peer := p2p.Peer{Address: overlay, FullNode: i.FullNode}
s.protocolsmu.RLock() s.protocolsmu.RLock()
for _, tn := range s.protocols { for _, tn := range s.protocols {
...@@ -396,6 +396,12 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) { ...@@ -396,6 +396,12 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
s.logger.Debugf("overlay address for peer %q not found", peerID) s.logger.Debugf("overlay address for peer %q not found", peerID)
return return
} }
full, found := s.peers.fullnode(peerID)
if !found {
_ = streamlibp2p.Reset()
s.logger.Debugf("fullnode info for peer %q not found", peerID)
return
}
stream := newStream(streamlibp2p) stream := newStream(streamlibp2p)
...@@ -423,7 +429,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) { ...@@ -423,7 +429,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
logger := tracing.NewLoggerWithTraceID(ctx, s.logger) logger := tracing.NewLoggerWithTraceID(ctx, s.logger)
s.metrics.HandledStreamCount.Inc() s.metrics.HandledStreamCount.Inc()
if err := ss.Handler(ctx, p2p.Peer{Address: overlay}, stream); err != nil { if err := ss.Handler(ctx, p2p.Peer{Address: overlay, FullNode: full}, stream); err != nil {
var de *p2p.DisconnectError var de *p2p.DisconnectError
if errors.As(err, &de) { if errors.As(err, &de) {
_ = stream.Reset() _ = stream.Reset()
...@@ -573,7 +579,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz. ...@@ -573,7 +579,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
return nil, fmt.Errorf("peer blocklisted") return nil, fmt.Errorf("peer blocklisted")
} }
if exists := s.peers.addIfNotExists(stream.Conn(), overlay); exists { if exists := s.peers.addIfNotExists(stream.Conn(), overlay, i.FullNode); exists {
if err := handshakeStream.FullClose(); err != nil { if err := handshakeStream.FullClose(); err != nil {
_ = s.Disconnect(overlay) _ = s.Disconnect(overlay)
return nil, fmt.Errorf("peer exists, full close: %w", err) return nil, fmt.Errorf("peer exists, full close: %w", err)
...@@ -598,7 +604,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz. ...@@ -598,7 +604,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
s.protocolsmu.RLock() s.protocolsmu.RLock()
for _, tn := range s.protocols { for _, tn := range s.protocols {
if tn.ConnectOut != nil { if tn.ConnectOut != nil {
if err := tn.ConnectOut(ctx, p2p.Peer{Address: overlay}); err != nil { if err := tn.ConnectOut(ctx, p2p.Peer{Address: overlay, FullNode: i.FullNode}); err != nil {
s.logger.Debugf("connectOut: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, overlay, err) s.logger.Debugf("connectOut: protocol: %s, version:%s, peer: %s: %v", tn.Name, tn.Version, overlay, err)
_ = s.Disconnect(overlay) _ = s.Disconnect(overlay)
s.protocolsmu.RUnlock() s.protocolsmu.RUnlock()
...@@ -625,11 +631,12 @@ func (s *Service) Disconnect(overlay swarm.Address) error { ...@@ -625,11 +631,12 @@ func (s *Service) Disconnect(overlay swarm.Address) error {
s.logger.Debugf("libp2p disconnect: disconnecting peer %s", overlay) s.logger.Debugf("libp2p disconnect: disconnecting peer %s", overlay)
found, peerID := s.peers.remove(overlay) // found is checked at the bottom of the function
found, full, peerID := s.peers.remove(overlay)
_ = s.host.Network().ClosePeer(peerID) _ = s.host.Network().ClosePeer(peerID)
peer := p2p.Peer{Address: overlay} peer := p2p.Peer{Address: overlay, FullNode: full}
s.protocolsmu.RLock() s.protocolsmu.RLock()
for _, tn := range s.protocols { for _, tn := range s.protocols {
...@@ -660,6 +667,15 @@ func (s *Service) Disconnect(overlay swarm.Address) error { ...@@ -660,6 +667,15 @@ func (s *Service) Disconnect(overlay swarm.Address) error {
func (s *Service) disconnected(address swarm.Address) { func (s *Service) disconnected(address swarm.Address) {
peer := p2p.Peer{Address: address} peer := p2p.Peer{Address: address}
peerID, found := s.peers.peerID(address)
if !found {
s.logger.Debugf("libp2p disconnected: cannot find peerID for overlay: %s", address.String())
} else {
full, found := s.peers.fullnode(peerID)
if found {
peer.FullNode = full
}
}
s.protocolsmu.RLock() s.protocolsmu.RLock()
for _, tn := range s.protocols { for _, tn := range s.protocols {
if tn.DisconnectIn != nil { if tn.DisconnectIn != nil {
......
...@@ -20,6 +20,7 @@ import ( ...@@ -20,6 +20,7 @@ import (
type peerRegistry struct { type peerRegistry struct {
underlays map[string]libp2ppeer.ID // map overlay address to underlay peer id underlays map[string]libp2ppeer.ID // map overlay address to underlay peer id
overlays map[libp2ppeer.ID]swarm.Address // map underlay peer id to overlay address overlays map[libp2ppeer.ID]swarm.Address // map underlay peer id to overlay address
full map[libp2ppeer.ID]bool // map to track whether a node is full or light node (true=full)
connections map[libp2ppeer.ID]map[network.Conn]struct{} // list of connections for safe removal on Disconnect notification connections map[libp2ppeer.ID]map[network.Conn]struct{} // list of connections for safe removal on Disconnect notification
streams map[libp2ppeer.ID]map[network.Stream]context.CancelFunc streams map[libp2ppeer.ID]map[network.Stream]context.CancelFunc
mu sync.RWMutex mu sync.RWMutex
...@@ -37,6 +38,7 @@ func newPeerRegistry() *peerRegistry { ...@@ -37,6 +38,7 @@ func newPeerRegistry() *peerRegistry {
return &peerRegistry{ return &peerRegistry{
underlays: make(map[string]libp2ppeer.ID), underlays: make(map[string]libp2ppeer.ID),
overlays: make(map[libp2ppeer.ID]swarm.Address), overlays: make(map[libp2ppeer.ID]swarm.Address),
full: make(map[libp2ppeer.ID]bool),
connections: make(map[libp2ppeer.ID]map[network.Conn]struct{}), connections: make(map[libp2ppeer.ID]map[network.Conn]struct{}),
streams: make(map[libp2ppeer.ID]map[network.Stream]context.CancelFunc), streams: make(map[libp2ppeer.ID]map[network.Stream]context.CancelFunc),
...@@ -78,6 +80,7 @@ func (r *peerRegistry) Disconnected(_ network.Network, c network.Conn) { ...@@ -78,6 +80,7 @@ func (r *peerRegistry) Disconnected(_ network.Network, c network.Conn) {
cancel() cancel()
} }
delete(r.streams, peerID) delete(r.streams, peerID)
delete(r.full, peerID)
r.mu.Unlock() r.mu.Unlock()
r.disconnecter.disconnected(overlay) r.disconnecter.disconnected(overlay)
...@@ -115,9 +118,10 @@ func (r *peerRegistry) removeStream(peerID libp2ppeer.ID, stream network.Stream) ...@@ -115,9 +118,10 @@ func (r *peerRegistry) removeStream(peerID libp2ppeer.ID, stream network.Stream)
func (r *peerRegistry) peers() []p2p.Peer { func (r *peerRegistry) peers() []p2p.Peer {
r.mu.RLock() r.mu.RLock()
peers := make([]p2p.Peer, 0, len(r.overlays)) peers := make([]p2p.Peer, 0, len(r.overlays))
for _, a := range r.overlays { for p, a := range r.overlays {
peers = append(peers, p2p.Peer{ peers = append(peers, p2p.Peer{
Address: a, Address: a,
FullNode: r.full[p],
}) })
} }
r.mu.RUnlock() r.mu.RUnlock()
...@@ -127,7 +131,7 @@ func (r *peerRegistry) peers() []p2p.Peer { ...@@ -127,7 +131,7 @@ func (r *peerRegistry) peers() []p2p.Peer {
return peers return peers
} }
func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address) (exists bool) { func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address, full bool) (exists bool) {
peerID := c.RemotePeer() peerID := c.RemotePeer()
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
...@@ -146,6 +150,7 @@ func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address) (ex ...@@ -146,6 +150,7 @@ func (r *peerRegistry) addIfNotExists(c network.Conn, overlay swarm.Address) (ex
r.streams[peerID] = make(map[network.Stream]context.CancelFunc) r.streams[peerID] = make(map[network.Stream]context.CancelFunc)
r.underlays[overlay.ByteString()] = peerID r.underlays[overlay.ByteString()] = peerID
r.overlays[peerID] = overlay r.overlays[peerID] = overlay
r.full[peerID] = full
return false return false
} }
...@@ -164,6 +169,13 @@ func (r *peerRegistry) overlay(peerID libp2ppeer.ID) (swarm.Address, bool) { ...@@ -164,6 +169,13 @@ func (r *peerRegistry) overlay(peerID libp2ppeer.ID) (swarm.Address, bool) {
return overlay, found return overlay, found
} }
func (r *peerRegistry) fullnode(peerID libp2ppeer.ID) (bool, bool) {
r.mu.RLock()
full, found := r.full[peerID]
r.mu.RUnlock()
return full, found
}
func (r *peerRegistry) isConnected(peerID libp2ppeer.ID, remoteAddr ma.Multiaddr) (swarm.Address, bool) { func (r *peerRegistry) isConnected(peerID libp2ppeer.ID, remoteAddr ma.Multiaddr) (swarm.Address, bool) {
if remoteAddr == nil { if remoteAddr == nil {
return swarm.ZeroAddress, false return swarm.ZeroAddress, false
...@@ -193,9 +205,9 @@ func (r *peerRegistry) isConnected(peerID libp2ppeer.ID, remoteAddr ma.Multiaddr ...@@ -193,9 +205,9 @@ func (r *peerRegistry) isConnected(peerID libp2ppeer.ID, remoteAddr ma.Multiaddr
return swarm.ZeroAddress, false return swarm.ZeroAddress, false
} }
func (r *peerRegistry) remove(overlay swarm.Address) (bool, libp2ppeer.ID) { func (r *peerRegistry) remove(overlay swarm.Address) (found, full bool, peerID libp2ppeer.ID) {
r.mu.Lock() r.mu.Lock()
peerID, found := r.underlays[overlay.ByteString()] peerID, found = r.underlays[overlay.ByteString()]
delete(r.overlays, peerID) delete(r.overlays, peerID)
delete(r.underlays, overlay.ByteString()) delete(r.underlays, overlay.ByteString())
delete(r.connections, peerID) delete(r.connections, peerID)
...@@ -203,9 +215,11 @@ func (r *peerRegistry) remove(overlay swarm.Address) (bool, libp2ppeer.ID) { ...@@ -203,9 +215,11 @@ func (r *peerRegistry) remove(overlay swarm.Address) (bool, libp2ppeer.ID) {
cancel() cancel()
} }
delete(r.streams, peerID) delete(r.streams, peerID)
full = r.full[peerID]
delete(r.full, peerID)
r.mu.Unlock() r.mu.Unlock()
return found, peerID return found, full, peerID
} }
func (r *peerRegistry) setDisconnecter(d disconnecter) { func (r *peerRegistry) setDisconnecter(d disconnecter) {
......
...@@ -27,7 +27,81 @@ func TestNewStream(t *testing.T) { ...@@ -27,7 +27,81 @@ func TestNewStream(t *testing.T) {
s2, _ := newService(t, 1, libp2pServiceOpts{}) s2, _ := newService(t, 1, libp2pServiceOpts{})
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error { if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
return nil
})); err != nil {
t.Fatal(err)
}
addr := serviceUnderlayAddress(t, s1)
if _, err := s2.Connect(ctx, addr); err != nil {
t.Fatal(err)
}
stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
if err := stream.Close(); err != nil {
t.Fatal(err)
}
}
// TestNewStream_OnlyFull tests that the handler gets the full
// node information communicated correctly.
func TestNewStream_OnlyFull(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
FullNode: true,
}})
s2, _ := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
FullNode: true,
}})
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
if !p.FullNode {
t.Error("expected full node")
}
return nil
})); err != nil {
t.Fatal(err)
}
addr := serviceUnderlayAddress(t, s1)
if _, err := s2.Connect(ctx, addr); err != nil {
t.Fatal(err)
}
stream, err := s2.NewStream(ctx, overlay1, nil, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
if err := stream.Close(); err != nil {
t.Fatal(err)
}
}
// TestNewStream_Mixed tests that the handler gets the full
// node information communicated correctly for light node
func TestNewStream_Mixed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s1, overlay1 := newService(t, 1, libp2pServiceOpts{libp2pOpts: libp2p.Options{
FullNode: true,
}})
s2, _ := newService(t, 1, libp2pServiceOpts{})
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, p p2p.Peer, _ p2p.Stream) error {
if p.FullNode {
t.Error("expected light node")
}
return nil return nil
})); err != nil { })); err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -61,12 +135,12 @@ func TestNewStreamMulti(t *testing.T) { ...@@ -61,12 +135,12 @@ func TestNewStreamMulti(t *testing.T) {
var ( var (
h1calls, h2calls int32 h1calls, h2calls int32
h1 = func(_ context.Context, _ p2p.Peer, s p2p.Stream) error { h1 = func(_ context.Context, p p2p.Peer, s p2p.Stream) error {
defer s.Close() defer s.Close()
_ = atomic.AddInt32(&h1calls, 1) _ = atomic.AddInt32(&h1calls, 1)
return nil return nil
} }
h2 = func(_ context.Context, _ p2p.Peer, s p2p.Stream) error { h2 = func(_ context.Context, p p2p.Peer, s p2p.Stream) error {
defer s.Close() defer s.Close()
_ = atomic.AddInt32(&h2calls, 1) _ = atomic.AddInt32(&h2calls, 1)
return nil return nil
......
...@@ -94,7 +94,8 @@ type StreamSpec struct { ...@@ -94,7 +94,8 @@ type StreamSpec struct {
// Peer holds information about a Peer. // Peer holds information about a Peer.
type Peer struct { type Peer struct {
Address swarm.Address `json:"address"` Address swarm.Address `json:"address"`
FullNode bool `json:"fullNode"`
} }
// HandlerFunc handles a received Stream from a Peer. // HandlerFunc handles a received Stream from a Peer.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment