Commit 469d00f1 authored by Petar Radovic's avatar Petar Radovic

disconnect + disconnect error

parent 712f8b2c
package libp2p
// This error is handled specially by libp2p
// If returned by specific protocol handler it causes peer disconnect
type disconnectError struct {
err error
}
// Disconnect wraps error and creates a special error that is treated specially by libp2p
// It causes peer disconnect
func Disconnect(err error) error {
return &disconnectError{
err: err,
}
}
// Unwrap returns an underlying error
func (e *disconnectError) Unwrap() error { return e.err }
// Error implements function of the standard go error interface
func (w *disconnectError) Error() string {
return w.err.Error()
}
......@@ -83,7 +83,7 @@ func TestHandshake(t *testing.T) {
expectedErr := fmt.Errorf("handshake write message: %w", testErr)
stream := &StreamMock{writeError: testErr}
res, err := handshakeService.Handshake(stream)
if err.Error() != expectedErr.Error() {
if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err)
}
......@@ -97,7 +97,7 @@ func TestHandshake(t *testing.T) {
expectedErr := fmt.Errorf("handshake read message: %w", testErr)
stream := &StreamMock{writeBuffer: &bytes.Buffer{}, readError: testErr}
res, err := handshakeService.Handshake(stream)
if err.Error() != expectedErr.Error() {
if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err)
}
......@@ -163,7 +163,7 @@ func TestHandle(t *testing.T) {
expectedErr := fmt.Errorf("handshake handler read message: %w", testErr)
stream := &StreamMock{readError: testErr}
res, err := handshakeService.Handle(stream)
if err.Error() != expectedErr.Error() {
if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err)
}
......@@ -188,9 +188,8 @@ func TestHandle(t *testing.T) {
}
stream.writeError = testErr
res, err := handshakeService.Handle(stream)
if err.Error() != expectedErr.Error() {
if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err)
}
......
......@@ -7,6 +7,7 @@ package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
......@@ -218,13 +219,13 @@ func New(ctx context.Context, o Options) (*Service, error) {
if err != nil {
s.logger.Errorf("handshake with x %s: %w", peerID, err)
// todo: test connection close and refactor
stream.Conn().Close()
_ = stream.Conn().Close()
return
}
if i.NetworkID != s.networkID {
s.logger.Errorf("handshake with peer %s: invalid network id %v", peerID, i.NetworkID)
// todo: test connection close and refactor
stream.Conn().Close()
_ = stream.Conn().Close()
return
}
s.peers.add(peerID, i.Address)
......@@ -264,13 +265,22 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
peerID := stream.Conn().RemotePeer()
overlay, found := s.peers.overlay(peerID)
if !found {
// todo: handle better
// todo: this should never happen, should we disconnect in this case?
// todo: test connection close and refactor
_ = stream.Conn().Close()
s.logger.Errorf("overlay address for peer %q not found", peerID)
return
}
s.metrics.HandledStreamCount.Inc()
if err := ss.Handler(p2p.Peer{Address: overlay}, stream); err != nil {
var e *disconnectError
if errors.Is(err, e) {
// todo: test connection close and refactor
s.peers.remove(peerID)
_ = stream.Conn().Close()
}
s.logger.Errorf("%s: %s/%s: %w", p.Name, ss.Name, ss.Version, err)
}
})
......
......@@ -43,3 +43,11 @@ func (r *peerRegistry) overlay(peerID libp2ppeer.ID) (overlay string, found bool
r.mu.RUnlock()
return overlay, found
}
func (r *peerRegistry) remove(peerID libp2ppeer.ID) {
r.mu.Lock()
overlay := r.overlays[peerID]
delete(r.overlays, peerID)
delete(r.peers, overlay)
r.mu.Unlock()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment