Commit 9bf7e6dd authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

wait for remote stream to close in handshake (#13)

parent b46a390c
...@@ -11,6 +11,8 @@ import ( ...@@ -11,6 +11,8 @@ import (
"fmt" "fmt"
"net" "net"
"github.com/libp2p/go-libp2p-core/helpers"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
handshake "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake" handshake "github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake"
...@@ -262,9 +264,8 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -262,9 +264,8 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
stream, err := s.newStreamForPeerID(ctx, info.ID, handshake.ProtocolName, handshake.ProtocolVersion, handshake.StreamName) stream, err := s.newStreamForPeerID(ctx, info.ID, handshake.ProtocolName, handshake.ProtocolVersion, handshake.StreamName)
if err != nil { if err != nil {
_ = s.host.Network().ClosePeer(info.ID) _ = s.host.Network().ClosePeer(info.ID)
return swarm.Address{}, fmt.Errorf("new stream: %w", err) return swarm.Address{}, err
} }
defer stream.Close()
i, err := s.handshakeService.Handshake(stream) i, err := s.handshakeService.Handshake(stream)
if err != nil { if err != nil {
...@@ -272,6 +273,11 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -272,6 +273,11 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, err return swarm.Address{}, err
} }
if err := helpers.FullClose(stream); err != nil {
_ = stream.Reset()
return swarm.Address{}, err
}
s.peers.add(info.ID, i.Address) s.peers.add(info.ID, i.Address)
s.metrics.CreatedConnectionCount.Inc() s.metrics.CreatedConnectionCount.Inc()
s.logger.Infof("peer %s connected", i.Address) s.logger.Infof("peer %s connected", i.Address)
...@@ -299,7 +305,7 @@ func (s *Service) NewStream(ctx context.Context, overlay swarm.Address, protocol ...@@ -299,7 +305,7 @@ func (s *Service) NewStream(ctx context.Context, overlay swarm.Address, protocol
return s.newStreamForPeerID(ctx, peerID, protocolName, protocolVersion, streamName) return s.newStreamForPeerID(ctx, peerID, protocolName, protocolVersion, streamName)
} }
func (s *Service) newStreamForPeerID(ctx context.Context, peerID libp2ppeer.ID, protocolName, protocolVersion, streamName string) (p2p.Stream, error) { func (s *Service) newStreamForPeerID(ctx context.Context, peerID libp2ppeer.ID, protocolName, protocolVersion, streamName string) (network.Stream, error) {
swarmStreamName := p2p.NewSwarmStreamName(protocolName, protocolVersion, streamName) swarmStreamName := p2p.NewSwarmStreamName(protocolName, protocolVersion, streamName)
st, err := s.host.NewStream(ctx, peerID, protocol.ID(swarmStreamName)) st, err := s.host.NewStream(ctx, peerID, protocol.ID(swarmStreamName))
if err != nil { if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment