Commit b65f50ef authored by Janoš Guljaš's avatar Janoš Guljaš Committed by GitHub

propagate context to message read and write in p2p protocols (#840)

parent 19fbbf77
...@@ -107,7 +107,7 @@ func (s *Service) sendPeers(ctx context.Context, peer swarm.Address, peers []swa ...@@ -107,7 +107,7 @@ func (s *Service) sendPeers(ctx context.Context, peer swarm.Address, peers []swa
}) })
} }
if err := w.WriteMsg(&peersRequest); err != nil { if err := w.WriteMsgWithContext(ctx, &peersRequest); err != nil {
return fmt.Errorf("write Peers message: %w", err) return fmt.Errorf("write Peers message: %w", err)
} }
...@@ -116,8 +116,10 @@ func (s *Service) sendPeers(ctx context.Context, peer swarm.Address, peers []swa ...@@ -116,8 +116,10 @@ func (s *Service) sendPeers(ctx context.Context, peer swarm.Address, peers []swa
func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error { func (s *Service) peersHandler(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
_, r := protobuf.NewWriterAndReader(stream) _, r := protobuf.NewWriterAndReader(stream)
ctx, cancel := context.WithTimeout(ctx, messageTimeout)
defer cancel()
var peersReq pb.Peers var peersReq pb.Peers
if err := r.ReadMsgWithTimeout(messageTimeout, &peersReq); err != nil { if err := r.ReadMsgWithContext(ctx, &peersReq); err != nil {
_ = stream.Reset() _ = stream.Reset()
return fmt.Errorf("read requestPeers message: %w", err) return fmt.Errorf("read requestPeers message: %w", err)
} }
......
...@@ -297,7 +297,7 @@ func TestConnectRepeatHandshake(t *testing.T) { ...@@ -297,7 +297,7 @@ func TestConnectRepeatHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if _, err := s2.HandshakeService().Handshake(libp2p.NewStream(stream), info.Addrs[0], info.ID); err == nil { if _, err := s2.HandshakeService().Handshake(ctx, libp2p.NewStream(stream), info.Addrs[0], info.ID); err == nil {
t.Fatalf("expected stream error") t.Fatalf("expected stream error")
} }
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package handshake package handshake
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
...@@ -33,7 +34,7 @@ const ( ...@@ -33,7 +34,7 @@ const (
StreamName = "handshake" StreamName = "handshake"
// MaxWelcomeMessageLength is maximum number of characters allowed in the welcome message. // MaxWelcomeMessageLength is maximum number of characters allowed in the welcome message.
MaxWelcomeMessageLength = 140 MaxWelcomeMessageLength = 140
messageTimeout = 5 * time.Second handshakeTimeout = 15 * time.Second
) )
var ( var (
...@@ -101,7 +102,10 @@ func New(signer crypto.Signer, advertisableAddresser AdvertisableAddressResolver ...@@ -101,7 +102,10 @@ func New(signer crypto.Signer, advertisableAddresser AdvertisableAddressResolver
} }
// Handshake initiates a handshake with a peer. // Handshake initiates a handshake with a peer.
func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerID libp2ppeer.ID) (i *Info, err error) { func (s *Service) Handshake(ctx context.Context, stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerID libp2ppeer.ID) (i *Info, err error) {
ctx, cancel := context.WithTimeout(ctx, handshakeTimeout)
defer cancel()
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
fullRemoteMA, err := buildFullMA(peerMultiaddr, peerID) fullRemoteMA, err := buildFullMA(peerMultiaddr, peerID)
if err != nil { if err != nil {
...@@ -113,14 +117,14 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI ...@@ -113,14 +117,14 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI
return nil, err return nil, err
} }
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Syn{ if err := w.WriteMsgWithContext(ctx, &pb.Syn{
ObservedUnderlay: fullRemoteMABytes, ObservedUnderlay: fullRemoteMABytes,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write syn message: %w", err) return nil, fmt.Errorf("write syn message: %w", err)
} }
var resp pb.SynAck var resp pb.SynAck
if err := r.ReadMsgWithTimeout(messageTimeout, &resp); err != nil { if err := r.ReadMsgWithContext(ctx, &resp); err != nil {
return nil, fmt.Errorf("read synack message: %w", err) return nil, fmt.Errorf("read synack message: %w", err)
} }
...@@ -151,7 +155,7 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI ...@@ -151,7 +155,7 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI
// Synced read: // Synced read:
welcomeMessage := s.GetWelcomeMessage() welcomeMessage := s.GetWelcomeMessage()
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{ if err := w.WriteMsgWithContext(ctx, &pb.Ack{
Address: &pb.BzzAddress{ Address: &pb.BzzAddress{
Underlay: advertisableUnderlayBytes, Underlay: advertisableUnderlayBytes,
Overlay: bzzAddress.Overlay.Bytes(), Overlay: bzzAddress.Overlay.Bytes(),
...@@ -176,7 +180,10 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI ...@@ -176,7 +180,10 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI
} }
// Handle handles an incoming handshake from a peer. // Handle handles an incoming handshake from a peer.
func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remotePeerID libp2ppeer.ID) (i *Info, err error) { func (s *Service) Handle(ctx context.Context, stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remotePeerID libp2ppeer.ID) (i *Info, err error) {
ctx, cancel := context.WithTimeout(ctx, handshakeTimeout)
defer cancel()
s.receivedHandshakesMu.Lock() s.receivedHandshakesMu.Lock()
if _, exists := s.receivedHandshakes[remotePeerID]; exists { if _, exists := s.receivedHandshakes[remotePeerID]; exists {
s.receivedHandshakesMu.Unlock() s.receivedHandshakesMu.Unlock()
...@@ -197,7 +204,7 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote ...@@ -197,7 +204,7 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote
} }
var syn pb.Syn var syn pb.Syn
if err := r.ReadMsgWithTimeout(messageTimeout, &syn); err != nil { if err := r.ReadMsgWithContext(ctx, &syn); err != nil {
return nil, fmt.Errorf("read syn message: %w", err) return nil, fmt.Errorf("read syn message: %w", err)
} }
...@@ -223,7 +230,7 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote ...@@ -223,7 +230,7 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote
welcomeMessage := s.GetWelcomeMessage() welcomeMessage := s.GetWelcomeMessage()
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{ if err := w.WriteMsgWithContext(ctx, &pb.SynAck{
Syn: &pb.Syn{ Syn: &pb.Syn{
ObservedUnderlay: fullRemoteMABytes, ObservedUnderlay: fullRemoteMABytes,
}, },
...@@ -242,7 +249,7 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote ...@@ -242,7 +249,7 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote
} }
var ack pb.Ack var ack pb.Ack
if err := r.ReadMsgWithTimeout(messageTimeout, &ack); err != nil { if err := r.ReadMsgWithContext(ctx, &ack); err != nil {
return nil, fmt.Errorf("read ack message: %w", err) return nil, fmt.Errorf("read ack message: %w", err)
} }
......
...@@ -6,6 +6,7 @@ package handshake_test ...@@ -6,6 +6,7 @@ package handshake_test
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
...@@ -121,7 +122,7 @@ func TestHandshake(t *testing.T) { ...@@ -121,7 +122,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handshake(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -193,7 +194,7 @@ func TestHandshake(t *testing.T) { ...@@ -193,7 +194,7 @@ func TestHandshake(t *testing.T) {
expectedErr := fmt.Errorf("write syn message: %w", testErr) expectedErr := fmt.Errorf("write syn message: %w", testErr)
stream := &mock.Stream{} stream := &mock.Stream{}
stream.SetWriteErr(testErr, 0) stream.SetWriteErr(testErr, 0)
res, err := handshakeService.Handshake(stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handshake(context.Background(), stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -208,7 +209,7 @@ func TestHandshake(t *testing.T) { ...@@ -208,7 +209,7 @@ func TestHandshake(t *testing.T) {
expectedErr := fmt.Errorf("read synack message: %w", testErr) expectedErr := fmt.Errorf("read synack message: %w", testErr)
stream := mock.NewStream(nil, &bytes.Buffer{}) stream := mock.NewStream(nil, &bytes.Buffer{})
stream.SetReadErr(testErr, 0) stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handshake(stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handshake(context.Background(), stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -246,7 +247,7 @@ func TestHandshake(t *testing.T) { ...@@ -246,7 +247,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handshake(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -280,7 +281,7 @@ func TestHandshake(t *testing.T) { ...@@ -280,7 +281,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handshake(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if res != nil { if res != nil {
t.Fatal("res should be nil") t.Fatal("res should be nil")
} }
...@@ -314,7 +315,7 @@ func TestHandshake(t *testing.T) { ...@@ -314,7 +315,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handshake(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if res != nil { if res != nil {
t.Fatal("res should be nil") t.Fatal("res should be nil")
} }
...@@ -354,7 +355,7 @@ func TestHandshake(t *testing.T) { ...@@ -354,7 +355,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handshake(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handshake(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err != testError { if err != testError {
t.Fatalf("expected error %v got %v", testError, err) t.Fatalf("expected error %v got %v", testError, err)
...@@ -395,7 +396,7 @@ func TestHandshake(t *testing.T) { ...@@ -395,7 +396,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -432,7 +433,7 @@ func TestHandshake(t *testing.T) { ...@@ -432,7 +433,7 @@ func TestHandshake(t *testing.T) {
expectedErr := fmt.Errorf("read syn message: %w", testErr) expectedErr := fmt.Errorf("read syn message: %w", testErr)
stream := &mock.Stream{} stream := &mock.Stream{}
stream.SetReadErr(testErr, 0) stream.SetReadErr(testErr, 0)
res, err := handshakeService.Handle(stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handle(context.Background(), stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -459,7 +460,7 @@ func TestHandshake(t *testing.T) { ...@@ -459,7 +460,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handle(context.Background(), stream, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -488,7 +489,7 @@ func TestHandshake(t *testing.T) { ...@@ -488,7 +489,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err == nil || err.Error() != expectedErr.Error() { if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err) t.Fatal("expected:", expectedErr, "got:", err)
} }
...@@ -527,7 +528,7 @@ func TestHandshake(t *testing.T) { ...@@ -527,7 +528,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if res != nil { if res != nil {
t.Fatal("res should be nil") t.Fatal("res should be nil")
} }
...@@ -566,7 +567,7 @@ func TestHandshake(t *testing.T) { ...@@ -566,7 +567,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -593,7 +594,7 @@ func TestHandshake(t *testing.T) { ...@@ -593,7 +594,7 @@ func TestHandshake(t *testing.T) {
Light: got.Ack.Light, Light: got.Ack.Light,
}) })
_, err = handshakeService.Handle(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) _, err = handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err != handshake.ErrHandshakeDuplicate { if err != handshake.ErrHandshakeDuplicate {
t.Fatalf("expected %s, got %s", handshake.ErrHandshakeDuplicate, err) t.Fatalf("expected %s, got %s", handshake.ErrHandshakeDuplicate, err)
} }
...@@ -628,7 +629,7 @@ func TestHandshake(t *testing.T) { ...@@ -628,7 +629,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
_, err = handshakeService.Handle(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) _, err = handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err != handshake.ErrInvalidAck { if err != handshake.ErrInvalidAck {
t.Fatalf("expected %s, got %v", handshake.ErrInvalidAck, err) t.Fatalf("expected %s, got %v", handshake.ErrInvalidAck, err)
} }
...@@ -657,7 +658,7 @@ func TestHandshake(t *testing.T) { ...@@ -657,7 +658,7 @@ func TestHandshake(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
res, err := handshakeService.Handle(stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID) res, err := handshakeService.Handle(context.Background(), stream1, node2AddrInfo.Addrs[0], node2AddrInfo.ID)
if err != testError { if err != testError {
t.Fatal("expected error") t.Fatal("expected error")
} }
......
...@@ -229,7 +229,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay ...@@ -229,7 +229,7 @@ func New(ctx context.Context, signer beecrypto.Signer, networkID uint64, overlay
s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) { s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) {
peerID := stream.Conn().RemotePeer() peerID := stream.Conn().RemotePeer()
handshakeStream := NewStream(stream) handshakeStream := NewStream(stream)
i, err := s.handshakeService.Handle(handshakeStream, stream.Conn().RemoteMultiaddr(), peerID) i, err := s.handshakeService.Handle(ctx, handshakeStream, stream.Conn().RemoteMultiaddr(), peerID)
if err != nil { if err != nil {
s.logger.Debugf("handshake: handle %s: %v", peerID, err) s.logger.Debugf("handshake: handle %s: %v", peerID, err)
s.logger.Errorf("unable to handshake with peer %v", peerID) s.logger.Errorf("unable to handshake with peer %v", peerID)
...@@ -455,7 +455,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz. ...@@ -455,7 +455,7 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.
} }
handshakeStream := NewStream(stream) handshakeStream := NewStream(stream)
i, err := s.handshakeService.Handshake(handshakeStream, stream.Conn().RemoteMultiaddr(), stream.Conn().RemotePeer()) i, err := s.handshakeService.Handshake(ctx, handshakeStream, stream.Conn().RemoteMultiaddr(), stream.Conn().RemotePeer())
if err != nil { if err != nil {
_ = handshakeStream.Reset() _ = handshakeStream.Reset()
_ = s.host.Network().ClosePeer(info.ID) _ = s.host.Network().ClosePeer(info.ID)
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"context" "context"
"errors" "errors"
"io" "io"
"time"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
ggio "github.com/gogo/protobuf/io" ggio "github.com/gogo/protobuf/io"
...@@ -70,23 +69,6 @@ func (r Reader) ReadMsgWithContext(ctx context.Context, msg proto.Message) error ...@@ -70,23 +69,6 @@ func (r Reader) ReadMsgWithContext(ctx context.Context, msg proto.Message) error
} }
} }
func (r Reader) ReadMsgWithTimeout(d time.Duration, msg proto.Message) error {
errChan := make(chan error, 1)
go func() {
errChan <- r.ReadMsg(msg)
}()
timer := time.NewTimer(d)
defer timer.Stop()
select {
case err := <-errChan:
return err
case <-timer.C:
return ErrTimeout
}
}
type Writer struct { type Writer struct {
ggio.Writer ggio.Writer
} }
...@@ -108,20 +90,3 @@ func (w Writer) WriteMsgWithContext(ctx context.Context, msg proto.Message) erro ...@@ -108,20 +90,3 @@ func (w Writer) WriteMsgWithContext(ctx context.Context, msg proto.Message) erro
return ctx.Err() return ctx.Err()
} }
} }
func (w Writer) WriteMsgWithTimeout(d time.Duration, msg proto.Message) error {
errChan := make(chan error, 1)
go func() {
errChan <- w.WriteMsg(msg)
}()
timer := time.NewTimer(d)
defer timer.Stop()
select {
case err := <-errChan:
return err
case <-timer.C:
return ErrTimeout
}
}
...@@ -121,34 +121,6 @@ func TestReader_timeout(t *testing.T) { ...@@ -121,34 +121,6 @@ func TestReader_timeout(t *testing.T) {
} }
} }
}) })
t.Run("WithTimeout", func(t *testing.T) {
r := tc.readerFunc()
var msg pb.Message
for i := 0; i < len(messages); i++ {
var timeout time.Duration
if i == 0 {
timeout = 600 * time.Millisecond
} else {
timeout = 10 * time.Millisecond
}
err := r.ReadMsgWithTimeout(timeout, &msg)
if i == 0 {
if err != nil {
t.Fatal(err)
}
} else {
if err != protobuf.ErrTimeout {
t.Fatalf("got error %v, want %v", err, protobuf.ErrTimeout)
}
break
}
want := messages[i]
got := msg.Text
if got != want {
t.Errorf("got message %q, want %q", got, want)
}
}
})
}) })
} }
} }
...@@ -248,34 +220,6 @@ func TestWriter_timeout(t *testing.T) { ...@@ -248,34 +220,6 @@ func TestWriter_timeout(t *testing.T) {
} }
} }
}) })
t.Run("WithTimeout", func(t *testing.T) {
w, msgs := tc.writerFunc()
for i, m := range messages {
var timeout time.Duration
if i == 0 {
timeout = 600 * time.Millisecond
} else {
timeout = 10 * time.Millisecond
}
err := w.WriteMsgWithTimeout(timeout, &pb.Message{
Text: m,
})
if i == 0 {
if err != nil {
t.Fatal(err)
}
} else {
if err != protobuf.ErrTimeout {
t.Fatalf("got error %v, want %v", err, protobuf.ErrTimeout)
}
break
}
if got := <-msgs; got != m {
t.Fatalf("got message %q, want %q", got, m)
}
}
})
}) })
} }
} }
......
...@@ -96,7 +96,9 @@ func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Head ...@@ -96,7 +96,9 @@ func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Head
go func() { go func() {
defer close(record.done) defer close(record.done)
err := handler(ctx, p2p.Peer{Address: addr}, streamIn) // pass a new context to handler,
// do not cancel it with the client stream context
err := handler(context.Background(), p2p.Peer{Address: addr}, streamIn)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
record.setErr(err) record.setErr(err)
} }
......
...@@ -74,14 +74,14 @@ func (s *Service) Ping(ctx context.Context, address swarm.Address, msgs ...strin ...@@ -74,14 +74,14 @@ func (s *Service) Ping(ctx context.Context, address swarm.Address, msgs ...strin
var pong pb.Pong var pong pb.Pong
for _, msg := range msgs { for _, msg := range msgs {
if err := w.WriteMsg(&pb.Ping{ if err := w.WriteMsgWithContext(ctx, &pb.Ping{
Greeting: msg, Greeting: msg,
}); err != nil { }); err != nil {
return 0, fmt.Errorf("write message: %w", err) return 0, fmt.Errorf("write message: %w", err)
} }
s.metrics.PingSentCount.Inc() s.metrics.PingSentCount.Inc()
if err := r.ReadMsg(&pong); err != nil { if err := r.ReadMsgWithContext(ctx, &pong); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
...@@ -103,7 +103,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) er ...@@ -103,7 +103,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) er
var ping pb.Ping var ping pb.Ping
for { for {
if err := r.ReadMsg(&ping); err != nil { if err := r.ReadMsgWithContext(ctx, &ping); err != nil {
if err == io.EOF { if err == io.EOF {
break break
} }
......
...@@ -75,7 +75,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -75,7 +75,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
}() }()
var req pb.AnnouncePaymentThreshold var req pb.AnnouncePaymentThreshold
if err := r.ReadMsg(&req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
s.logger.Debugf("error receiving payment threshold announcement from peer %v", p.Address) s.logger.Debugf("error receiving payment threshold announcement from peer %v", p.Address)
return fmt.Errorf("read request from peer %v: %w", p.Address, err) return fmt.Errorf("read request from peer %v: %w", p.Address, err)
} }
......
...@@ -358,7 +358,7 @@ func (p *Puller) histSyncWorker(ctx context.Context, peer swarm.Address, bin uin ...@@ -358,7 +358,7 @@ func (p *Puller) histSyncWorker(ctx context.Context, peer swarm.Address, bin uin
p.metrics.HistWorkerErrCounter.Inc() p.metrics.HistWorkerErrCounter.Inc()
return return
} }
if err := p.syncer.CancelRuid(peer, ruid); err != nil && logMore { if err := p.syncer.CancelRuid(ctx, peer, ruid); err != nil && logMore {
p.logger.Debugf("histSyncWorker cancel ruid: %v", err) p.logger.Debugf("histSyncWorker cancel ruid: %v", err)
} }
return return
...@@ -402,7 +402,7 @@ func (p *Puller) liveSyncWorker(ctx context.Context, peer swarm.Address, bin uin ...@@ -402,7 +402,7 @@ func (p *Puller) liveSyncWorker(ctx context.Context, peer swarm.Address, bin uin
p.metrics.LiveWorkerErrCounter.Inc() p.metrics.LiveWorkerErrCounter.Inc()
return return
} }
if err := p.syncer.CancelRuid(peer, ruid); err != nil && logMore { if err := p.syncer.CancelRuid(ctx, peer, ruid); err != nil && logMore {
p.logger.Debugf("histSyncWorker cancel ruid: %v", err) p.logger.Debugf("histSyncWorker cancel ruid: %v", err)
} }
return return
......
...@@ -209,7 +209,7 @@ func (p *PullSyncMock) SyncCalls(peer swarm.Address) (res []SyncCall) { ...@@ -209,7 +209,7 @@ func (p *PullSyncMock) SyncCalls(peer swarm.Address) (res []SyncCall) {
return res return res
} }
func (p *PullSyncMock) CancelRuid(peer swarm.Address, ruid uint32) error { func (p *PullSyncMock) CancelRuid(ctx context.Context, peer swarm.Address, ruid uint32) error {
return nil return nil
} }
......
...@@ -45,7 +45,7 @@ var maxPage = 50 ...@@ -45,7 +45,7 @@ var maxPage = 50
type Interface interface { type Interface interface {
SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, ruid uint32, err error) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, ruid uint32, err error)
GetCursors(ctx context.Context, peer swarm.Address) ([]uint64, error) GetCursors(ctx context.Context, peer swarm.Address) ([]uint64, error)
CancelRuid(peer swarm.Address, ruid uint32) error CancelRuid(ctx context.Context, peer swarm.Address, ruid uint32) error
} }
type Syncer struct { type Syncer struct {
...@@ -404,8 +404,8 @@ func (s *Syncer) cursorHandler(ctx context.Context, p p2p.Peer, stream p2p.Strea ...@@ -404,8 +404,8 @@ func (s *Syncer) cursorHandler(ctx context.Context, p p2p.Peer, stream p2p.Strea
return nil return nil
} }
func (s *Syncer) CancelRuid(peer swarm.Address, ruid uint32) (err error) { func (s *Syncer) CancelRuid(ctx context.Context, peer swarm.Address, ruid uint32) (err error) {
stream, err := s.streamer.NewStream(context.Background(), peer, nil, protocolName, protocolVersion, cancelStreamName) stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, cancelStreamName)
if err != nil { if err != nil {
return fmt.Errorf("new stream: %w", err) return fmt.Errorf("new stream: %w", err)
} }
...@@ -419,9 +419,12 @@ func (s *Syncer) CancelRuid(peer swarm.Address, ruid uint32) (err error) { ...@@ -419,9 +419,12 @@ func (s *Syncer) CancelRuid(peer swarm.Address, ruid uint32) (err error) {
} }
}() }()
ctx, cancel := context.WithTimeout(ctx, cancellationTimeout)
defer cancel()
var c pb.Cancel var c pb.Cancel
c.Ruid = ruid c.Ruid = ruid
if err := w.WriteMsgWithTimeout(cancellationTimeout, &c); err != nil { if err := w.WriteMsgWithContext(ctx, &c); err != nil {
return fmt.Errorf("send cancellation: %w", err) return fmt.Errorf("send cancellation: %w", err)
} }
return nil return nil
......
...@@ -85,7 +85,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -85,7 +85,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
} }
}() }()
var req pb.Payment var req pb.Payment
if err := r.ReadMsg(&req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
return fmt.Errorf("read request from peer %v: %w", p.Address, err) return fmt.Errorf("read request from peer %v: %w", p.Address, err)
} }
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"context" "context"
"io/ioutil" "io/ioutil"
"testing" "testing"
"time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
...@@ -20,13 +21,19 @@ import ( ...@@ -20,13 +21,19 @@ import (
) )
type testObserver struct { type testObserver struct {
called bool called chan struct{}
peer swarm.Address peer swarm.Address
amount uint64 amount uint64
} }
func newTestObserver() *testObserver {
return &testObserver{
called: make(chan struct{}),
}
}
func (t *testObserver) NotifyPayment(peer swarm.Address, amount uint64) error { func (t *testObserver) NotifyPayment(peer swarm.Address, amount uint64) error {
t.called = true close(t.called)
t.peer = peer t.peer = peer
t.amount = amount t.amount = amount
return nil return nil
...@@ -38,7 +45,7 @@ func TestPayment(t *testing.T) { ...@@ -38,7 +45,7 @@ func TestPayment(t *testing.T) {
storeRecipient := mock.NewStateStore() storeRecipient := mock.NewStateStore()
defer storeRecipient.Close() defer storeRecipient.Close()
observer := &testObserver{} observer := newTestObserver()
recipient := pseudosettle.New(nil, logger, storeRecipient) recipient := pseudosettle.New(nil, logger, storeRecipient)
recipient.SetPaymentObserver(observer) recipient.SetPaymentObserver(observer)
...@@ -70,6 +77,10 @@ func TestPayment(t *testing.T) { ...@@ -70,6 +77,10 @@ func TestPayment(t *testing.T) {
record := records[0] record := records[0]
if err := record.Err(); err != nil {
t.Fatalf("record error: %v", err)
}
messages, err := protobuf.ReadMessages( messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()), bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Payment) }, func() protobuf.Message { return new(pb.Payment) },
...@@ -87,7 +98,9 @@ func TestPayment(t *testing.T) { ...@@ -87,7 +98,9 @@ func TestPayment(t *testing.T) {
t.Fatalf("got message with amount %v, want %v", sentAmount, amount) t.Fatalf("got message with amount %v, want %v", sentAmount, amount)
} }
if !observer.called { select {
case <-observer.called:
case <-time.After(time.Second):
t.Fatal("expected observer to be called") t.Fatal("expected observer to be called")
} }
......
...@@ -91,7 +91,7 @@ func (s *Service) initHandler(ctx context.Context, p p2p.Peer, stream p2p.Stream ...@@ -91,7 +91,7 @@ func (s *Service) initHandler(ctx context.Context, p p2p.Peer, stream p2p.Stream
} }
}() }()
var req pb.Handshake var req pb.Handshake
if err := r.ReadMsg(&req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
return fmt.Errorf("read request from peer %v: %w", p.Address, err) return fmt.Errorf("read request from peer %v: %w", p.Address, err)
} }
...@@ -136,7 +136,7 @@ func (s *Service) init(ctx context.Context, p p2p.Peer) error { ...@@ -136,7 +136,7 @@ func (s *Service) init(ctx context.Context, p p2p.Peer) error {
} }
var req pb.Handshake var req pb.Handshake
if err := r.ReadMsg(&req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
return fmt.Errorf("read request from peer %v: %w", p.Address, err) return fmt.Errorf("read request from peer %v: %w", p.Address, err)
} }
...@@ -160,7 +160,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -160,7 +160,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
} }
}() }()
var req pb.EmitCheque var req pb.EmitCheque
if err := r.ReadMsg(&req); err != nil { if err := r.ReadMsgWithContext(ctx, &req); err != nil {
return fmt.Errorf("read request from peer %v: %w", p.Address, err) return fmt.Errorf("read request from peer %v: %w", p.Address, err)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment