Commit e3a26858 authored by Janos Guljas's avatar Janos Guljas

add message timeouts for handshake protocol

parent f2d976f6
...@@ -7,6 +7,7 @@ package handshake ...@@ -7,6 +7,7 @@ package handshake
import ( import (
"errors" "errors"
"fmt" "fmt"
"time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
...@@ -23,12 +24,15 @@ const ( ...@@ -23,12 +24,15 @@ const (
// ErrNetworkIDIncompatible should be returned by handshake handlers if // ErrNetworkIDIncompatible should be returned by handshake handlers if
// response from the other peer does not have valid networkID. // response from the other peer does not have valid networkID.
var ErrNetworkIDIncompatible = errors.New("incompatible networkID") var ErrNetworkIDIncompatible = errors.New("incompatible network ID")
// ErrHandshakeDuplicate should be returned by handshake handlers if // ErrHandshakeDuplicate should be returned by handshake handlers if
// the handshake response has been received by an already processed peer. // the handshake response has been received by an already processed peer.
var ErrHandshakeDuplicate = errors.New("duplicate handshake") var ErrHandshakeDuplicate = errors.New("duplicate handshake")
// messageTimeout is the maximal allowed time for a message to be read or written.
var messageTimeout = 5 * time.Second
// PeerFinder has the information if the peer already exists in swarm. // PeerFinder has the information if the peer already exists in swarm.
type PeerFinder interface { type PeerFinder interface {
Exists(overlay swarm.Address) (found bool) Exists(overlay swarm.Address) (found bool)
...@@ -52,16 +56,17 @@ func New(peerFinder PeerFinder, overlay swarm.Address, networkID int32, logger l ...@@ -52,16 +56,17 @@ func New(peerFinder PeerFinder, overlay swarm.Address, networkID int32, logger l
func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
var resp pb.SynAck
if err := w.WriteMsg(&pb.Syn{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{
Address: s.overlay.Bytes(), Address: s.overlay.Bytes(),
NetworkID: s.networkID, NetworkID: s.networkID,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write message: %w", err) return nil, fmt.Errorf("write syn message: %w", err)
} }
if err := r.ReadMsg(&resp); err != nil { var resp pb.SynAck
return nil, fmt.Errorf("read message: %w", err) if err := r.ReadMsgWithTimeout(messageTimeout, &resp); err != nil {
return nil, fmt.Errorf("read synack message: %w", err)
} }
address := swarm.NewAddress(resp.Syn.Address) address := swarm.NewAddress(resp.Syn.Address)
...@@ -73,8 +78,10 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { ...@@ -73,8 +78,10 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, ErrNetworkIDIncompatible return nil, ErrNetworkIDIncompatible
} }
if err := w.WriteMsg(&pb.Ack{Address: resp.Syn.Address}); err != nil { if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{
return nil, fmt.Errorf("ack: write message: %w", err) Address: resp.Syn.Address,
}); err != nil {
return nil, fmt.Errorf("write ack message: %w", err)
} }
s.logger.Tracef("handshake finished for peer %s", address) s.logger.Tracef("handshake finished for peer %s", address)
...@@ -91,8 +98,8 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) { ...@@ -91,8 +98,8 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
defer stream.Close() defer stream.Close()
var req pb.Syn var req pb.Syn
if err := r.ReadMsg(&req); err != nil { if err := r.ReadMsgWithTimeout(messageTimeout, &req); err != nil {
return nil, fmt.Errorf("read message: %w", err) return nil, fmt.Errorf("read syn message: %w", err)
} }
address := swarm.NewAddress(req.Address) address := swarm.NewAddress(req.Address)
...@@ -104,19 +111,19 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) { ...@@ -104,19 +111,19 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
return nil, ErrNetworkIDIncompatible return nil, ErrNetworkIDIncompatible
} }
if err := w.WriteMsg(&pb.SynAck{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{
Syn: &pb.Syn{ Syn: &pb.Syn{
Address: s.overlay.Bytes(), Address: s.overlay.Bytes(),
NetworkID: s.networkID, NetworkID: s.networkID,
}, },
Ack: &pb.Ack{Address: req.Address}, Ack: &pb.Ack{Address: req.Address},
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write message: %w", err) return nil, fmt.Errorf("write synack message: %w", err)
} }
var ack pb.Ack var ack pb.Ack
if err := r.ReadMsg(&ack); err != nil { if err := r.ReadMsgWithTimeout(messageTimeout, &ack); err != nil {
return nil, fmt.Errorf("ack: read message: %w", err) return nil, fmt.Errorf("read ack message: %w", err)
} }
s.logger.Tracef("handshake finished for peer %s", address) s.logger.Tracef("handshake finished for peer %s", address)
......
// Copyright 2020 The Swarm Authors. All rights reserved. // Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package handshake package handshake
import ( import (
...@@ -67,7 +68,7 @@ func TestHandshake(t *testing.T) { ...@@ -67,7 +68,7 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - Syn write error ", func(t *testing.T) { t.Run("ERROR - Syn write error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write message: %w", testErr) expectedErr := fmt.Errorf("write syn message: %w", testErr)
stream := &mock.StreamMock{} stream := &mock.StreamMock{}
stream.SetWriteErr(testErr, 0) stream.SetWriteErr(testErr, 0)
res, err := handshakeService.Handshake(stream) res, err := handshakeService.Handshake(stream)
...@@ -82,7 +83,7 @@ func TestHandshake(t *testing.T) { ...@@ -82,7 +83,7 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - Syn read error ", func(t *testing.T) { t.Run("ERROR - Syn read error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("read message: %w", testErr) expectedErr := fmt.Errorf("read synack message: %w", testErr)
stream := mock.NewStream(nil, &bytes.Buffer{}) stream := mock.NewStream(nil, &bytes.Buffer{})
stream.SetReadErr(testErr, 0) stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handshake(stream) res, err := handshakeService.Handshake(stream)
...@@ -97,7 +98,7 @@ func TestHandshake(t *testing.T) { ...@@ -97,7 +98,7 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - ack write error ", func(t *testing.T) { t.Run("ERROR - ack write error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("ack: write message: %w", testErr) expectedErr := fmt.Errorf("write ack message: %w", testErr)
expectedInfo := Info{ expectedInfo := Info{
Address: node2Addr, Address: node2Addr,
NetworkID: 0, NetworkID: 0,
...@@ -262,7 +263,7 @@ func TestHandle(t *testing.T) { ...@@ -262,7 +263,7 @@ func TestHandle(t *testing.T) {
t.Run("ERROR - read error ", func(t *testing.T) { t.Run("ERROR - read error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("read message: %w", testErr) expectedErr := fmt.Errorf("read syn message: %w", testErr)
stream := &mock.StreamMock{} stream := &mock.StreamMock{}
stream.SetReadErr(testErr, 0) stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handle(stream) res, err := handshakeService.Handle(stream)
...@@ -277,7 +278,7 @@ func TestHandle(t *testing.T) { ...@@ -277,7 +278,7 @@ func TestHandle(t *testing.T) {
t.Run("ERROR - write error ", func(t *testing.T) { t.Run("ERROR - write error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write message: %w", testErr) expectedErr := fmt.Errorf("write synack message: %w", testErr)
var buffer bytes.Buffer var buffer bytes.Buffer
stream := mock.NewStream(&buffer, &buffer) stream := mock.NewStream(&buffer, &buffer)
stream.SetWriteErr(testErr, 1) stream.SetWriteErr(testErr, 1)
...@@ -302,7 +303,7 @@ func TestHandle(t *testing.T) { ...@@ -302,7 +303,7 @@ func TestHandle(t *testing.T) {
t.Run("ERROR - ack read error ", func(t *testing.T) { t.Run("ERROR - ack read error ", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("ack: read message: %w", testErr) expectedErr := fmt.Errorf("read ack message: %w", testErr)
node2Info := Info{ node2Info := Info{
Address: node2Addr, Address: node2Addr,
NetworkID: 0, NetworkID: 0,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment