Commit e3a26858 authored by Janos Guljas's avatar Janos Guljas

add message timeouts for handshake protocol

parent f2d976f6
......@@ -7,6 +7,7 @@ package handshake
import (
"errors"
"fmt"
"time"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
......@@ -23,12 +24,15 @@ const (
// ErrNetworkIDIncompatible should be returned by handshake handlers if
// response from the other peer does not have valid networkID.
var ErrNetworkIDIncompatible = errors.New("incompatible networkID")
var ErrNetworkIDIncompatible = errors.New("incompatible network ID")
// ErrHandshakeDuplicate should be returned by handshake handlers if
// the handshake response has been received by an already processed peer.
var ErrHandshakeDuplicate = errors.New("duplicate handshake")
// messageTimeout is the maximal allowed time for a message to be read or written.
var messageTimeout = 5 * time.Second
// PeerFinder has the information if the peer already exists in swarm.
type PeerFinder interface {
Exists(overlay swarm.Address) (found bool)
......@@ -52,16 +56,17 @@ func New(peerFinder PeerFinder, overlay swarm.Address, networkID int32, logger l
func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
w, r := protobuf.NewWriterAndReader(stream)
var resp pb.SynAck
if err := w.WriteMsg(&pb.Syn{
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{
Address: s.overlay.Bytes(),
NetworkID: s.networkID,
}); err != nil {
return nil, fmt.Errorf("write message: %w", err)
return nil, fmt.Errorf("write syn message: %w", err)
}
if err := r.ReadMsg(&resp); err != nil {
return nil, fmt.Errorf("read message: %w", err)
var resp pb.SynAck
if err := r.ReadMsgWithTimeout(messageTimeout, &resp); err != nil {
return nil, fmt.Errorf("read synack message: %w", err)
}
address := swarm.NewAddress(resp.Syn.Address)
......@@ -73,8 +78,10 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, ErrNetworkIDIncompatible
}
if err := w.WriteMsg(&pb.Ack{Address: resp.Syn.Address}); err != nil {
return nil, fmt.Errorf("ack: write message: %w", err)
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{
Address: resp.Syn.Address,
}); err != nil {
return nil, fmt.Errorf("write ack message: %w", err)
}
s.logger.Tracef("handshake finished for peer %s", address)
......@@ -91,8 +98,8 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
defer stream.Close()
var req pb.Syn
if err := r.ReadMsg(&req); err != nil {
return nil, fmt.Errorf("read message: %w", err)
if err := r.ReadMsgWithTimeout(messageTimeout, &req); err != nil {
return nil, fmt.Errorf("read syn message: %w", err)
}
address := swarm.NewAddress(req.Address)
......@@ -104,19 +111,19 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
return nil, ErrNetworkIDIncompatible
}
if err := w.WriteMsg(&pb.SynAck{
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{
Syn: &pb.Syn{
Address: s.overlay.Bytes(),
NetworkID: s.networkID,
},
Ack: &pb.Ack{Address: req.Address},
}); err != nil {
return nil, fmt.Errorf("write message: %w", err)
return nil, fmt.Errorf("write synack message: %w", err)
}
var ack pb.Ack
if err := r.ReadMsg(&ack); err != nil {
return nil, fmt.Errorf("ack: read message: %w", err)
if err := r.ReadMsgWithTimeout(messageTimeout, &ack); err != nil {
return nil, fmt.Errorf("read ack message: %w", err)
}
s.logger.Tracef("handshake finished for peer %s", address)
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package handshake
import (
......@@ -67,7 +68,7 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - Syn write error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("write message: %w", testErr)
expectedErr := fmt.Errorf("write syn message: %w", testErr)
stream := &mock.StreamMock{}
stream.SetWriteErr(testErr, 0)
res, err := handshakeService.Handshake(stream)
......@@ -82,7 +83,7 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - Syn read error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("read message: %w", testErr)
expectedErr := fmt.Errorf("read synack message: %w", testErr)
stream := mock.NewStream(nil, &bytes.Buffer{})
stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handshake(stream)
......@@ -97,7 +98,7 @@ func TestHandshake(t *testing.T) {
t.Run("ERROR - ack write error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("ack: write message: %w", testErr)
expectedErr := fmt.Errorf("write ack message: %w", testErr)
expectedInfo := Info{
Address: node2Addr,
NetworkID: 0,
......@@ -262,7 +263,7 @@ func TestHandle(t *testing.T) {
t.Run("ERROR - read error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("read message: %w", testErr)
expectedErr := fmt.Errorf("read syn message: %w", testErr)
stream := &mock.StreamMock{}
stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handle(stream)
......@@ -277,7 +278,7 @@ func TestHandle(t *testing.T) {
t.Run("ERROR - write error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("write message: %w", testErr)
expectedErr := fmt.Errorf("write synack message: %w", testErr)
var buffer bytes.Buffer
stream := mock.NewStream(&buffer, &buffer)
stream.SetWriteErr(testErr, 1)
......@@ -302,7 +303,7 @@ func TestHandle(t *testing.T) {
t.Run("ERROR - ack read error ", func(t *testing.T) {
testErr := errors.New("test error")
expectedErr := fmt.Errorf("ack: read message: %w", testErr)
expectedErr := fmt.Errorf("read ack message: %w", testErr)
node2Info := Info{
Address: node2Addr,
NetworkID: 0,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment