Commit 6a95f067 authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

p2p.Stream FullClose (#20)

Stream fullclose functionality + tests.
parent 97199ea1
...@@ -51,3 +51,7 @@ func (s *Stream) Write(p []byte) (n int, err error) { ...@@ -51,3 +51,7 @@ func (s *Stream) Write(p []byte) (n int, err error) {
func (s *Stream) Close() error { func (s *Stream) Close() error {
return nil return nil
} }
func (s *StreamMock) FullClose() error {
return nil
}
...@@ -173,7 +173,7 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -173,7 +173,7 @@ func New(ctx context.Context, o Options) (*Service, error) {
s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) { s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) {
peerID := stream.Conn().RemotePeer() peerID := stream.Conn().RemotePeer()
i, err := s.handshakeService.Handle(stream) i, err := s.handshakeService.Handle(newStream(stream))
if err != nil { if err != nil {
if err == handshake.ErrNetworkIDIncompatible { if err == handshake.ErrNetworkIDIncompatible {
s.logger.Warningf("peer %s has a different network id.", peerID) s.logger.Warningf("peer %s has a different network id.", peerID)
...@@ -237,7 +237,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) { ...@@ -237,7 +237,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
} }
s.metrics.HandledStreamCount.Inc() s.metrics.HandledStreamCount.Inc()
if err := ss.Handler(p2p.Peer{Address: overlay}, stream); err != nil { if err := ss.Handler(p2p.Peer{Address: overlay}, newStream(stream)); err != nil {
var e *p2p.DisconnectError var e *p2p.DisconnectError
if errors.Is(err, e) { if errors.Is(err, e) {
// todo: test connection close and refactor // todo: test connection close and refactor
...@@ -283,14 +283,13 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm ...@@ -283,14 +283,13 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, err return swarm.Address{}, err
} }
i, err := s.handshakeService.Handshake(stream) i, err := s.handshakeService.Handshake(newStream(stream))
if err != nil { if err != nil {
_ = s.disconnect(info.ID) _ = s.disconnect(info.ID)
return swarm.Address{}, fmt.Errorf("handshake: %w", err) return swarm.Address{}, fmt.Errorf("handshake: %w", err)
} }
if err := helpers.FullClose(stream); err != nil { if err := helpers.FullClose(stream); err != nil {
_ = stream.Reset()
return swarm.Address{}, err return swarm.Address{}, err
} }
...@@ -333,7 +332,12 @@ func (s *Service) NewStream(ctx context.Context, overlay swarm.Address, protocol ...@@ -333,7 +332,12 @@ func (s *Service) NewStream(ctx context.Context, overlay swarm.Address, protocol
return nil, p2p.ErrPeerNotFound return nil, p2p.ErrPeerNotFound
} }
return s.newStreamForPeerID(ctx, peerID, protocolName, protocolVersion, streamName) stream, err := s.newStreamForPeerID(ctx, peerID, protocolName, protocolVersion, streamName)
if err != nil {
return nil, err
}
return newStream(stream), nil
} }
func (s *Service) newStreamForPeerID(ctx context.Context, peerID libp2ppeer.ID, protocolName, protocolVersion, streamName string) (network.Stream, error) { func (s *Service) newStreamForPeerID(ctx context.Context, peerID libp2ppeer.ID, protocolName, protocolVersion, streamName string) (network.Stream, error) {
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package libp2p
import (
"github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/network"
)
type stream struct {
network.Stream
}
func (s *stream) FullClose() error {
return helpers.FullClose(s)
}
func newStream(s network.Stream) *stream {
return &stream{Stream: s}
}
...@@ -26,6 +26,7 @@ type Streamer interface { ...@@ -26,6 +26,7 @@ type Streamer interface {
type Stream interface { type Stream interface {
io.ReadWriter io.ReadWriter
io.Closer io.Closer
FullClose() error
} }
type ProtocolSpec struct { type ProtocolSpec struct {
......
...@@ -363,6 +363,10 @@ func (noopWriteCloser) Close() error { ...@@ -363,6 +363,10 @@ func (noopWriteCloser) Close() error {
return nil return nil
} }
func (noopWriteCloser) FullClose() error {
return nil
}
type noopReadCloser struct { type noopReadCloser struct {
io.Writer io.Writer
} }
...@@ -378,3 +382,7 @@ func (noopReadCloser) Read(p []byte) (n int, err error) { ...@@ -378,3 +382,7 @@ func (noopReadCloser) Read(p []byte) (n int, err error) {
func (noopReadCloser) Close() error { func (noopReadCloser) Close() error {
return nil return nil
} }
func (noopReadCloser) FullClose() error {
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package streamtest
import "time"
func SetFullCloseTimeout(t time.Duration) {
fullCloseTimeout = t
}
func ResetFullCloseTimeout() {
fullCloseTimeout = fullCloseTimeoutDefault
}
...@@ -9,14 +9,18 @@ import ( ...@@ -9,14 +9,18 @@ import (
"errors" "errors"
"io" "io"
"sync" "sync"
"time"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
var ( var (
ErrRecordsNotFound = errors.New("records not found") ErrRecordsNotFound = errors.New("records not found")
ErrStreamNotSupported = errors.New("stream not supported") ErrStreamNotSupported = errors.New("stream not supported")
ErrStreamFullcloseTimeout = errors.New("fullclose timeout")
fullCloseTimeout = fullCloseTimeoutDefault // timeout of fullclose
fullCloseTimeoutDefault = 5 * time.Second // default timeout used for helper function to reset timeout when changed
) )
type Recorder struct { type Recorder struct {
...@@ -51,8 +55,10 @@ func New(opts ...Option) *Recorder { ...@@ -51,8 +55,10 @@ func New(opts ...Option) *Recorder {
func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName, protocolVersion, streamName string) (p2p.Stream, error) { func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName, protocolVersion, streamName string) (p2p.Stream, error) {
recordIn := newRecord() recordIn := newRecord()
recordOut := newRecord() recordOut := newRecord()
streamOut := newStream(recordIn, recordOut) closedIn := make(chan struct{})
streamIn := newStream(recordOut, recordIn) closedOut := make(chan struct{})
streamOut := newStream(recordIn, recordOut, closedIn, closedOut)
streamIn := newStream(recordOut, recordIn, closedOut, closedIn)
var handler p2p.HandlerFunc var handler p2p.HandlerFunc
for _, p := range r.protocols { for _, p := range r.protocols {
...@@ -130,12 +136,15 @@ func (r *Record) setErr(err error) { ...@@ -130,12 +136,15 @@ func (r *Record) setErr(err error) {
} }
type stream struct { type stream struct {
in io.WriteCloser in io.WriteCloser
out io.ReadCloser out io.ReadCloser
cin chan struct{}
cout chan struct{}
closeOnce sync.Once
} }
func newStream(in io.WriteCloser, out io.ReadCloser) *stream { func newStream(in io.WriteCloser, out io.ReadCloser, cin, cout chan struct{}) *stream {
return &stream{in: in, out: out} return &stream{in: in, out: out, cin: cin, cout: cout}
} }
func (s *stream) Read(p []byte) (int, error) { func (s *stream) Read(p []byte) (int, error) {
...@@ -147,12 +156,34 @@ func (s *stream) Write(p []byte) (int, error) { ...@@ -147,12 +156,34 @@ func (s *stream) Write(p []byte) (int, error) {
} }
func (s *stream) Close() error { func (s *stream) Close() error {
if err := s.in.Close(); err != nil { var e error
s.closeOnce.Do(func() {
if err := s.in.Close(); err != nil {
e = err
return
}
if err := s.out.Close(); err != nil {
e = err
return
}
close(s.cin)
})
return e
}
func (s *stream) FullClose() error {
if err := s.Close(); err != nil {
return err return err
} }
if err := s.out.Close(); err != nil {
return err select {
case <-s.cout:
case <-time.After(fullCloseTimeout):
return ErrStreamFullcloseTimeout
} }
return nil return nil
} }
......
...@@ -13,14 +13,15 @@ import ( ...@@ -13,14 +13,15 @@ import (
"io/ioutil" "io/ioutil"
"strings" "strings"
"testing" "testing"
"time"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"golang.org/x/sync/errgroup"
) )
func TestRecorder(t *testing.T) { func TestRecorder(t *testing.T) {
var answers = map[string]string{ var answers = map[string]string{
"What is your name?": "Sir Lancelot of Camelot", "What is your name?": "Sir Lancelot of Camelot",
"What is your quest?": "To seek the Holy Grail.", "What is your quest?": "To seek the Holy Grail.",
...@@ -120,6 +121,163 @@ func TestRecorder_errStreamNotSupported(t *testing.T) { ...@@ -120,6 +121,163 @@ func TestRecorder_errStreamNotSupported(t *testing.T) {
} }
} }
func TestRecorder_fullcloseWithRemoteClose(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
defer stream.Close()
_, err := bufio.NewReader(stream).ReadString('\n')
return err
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.WriteString("message\n"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return stream.FullClose()
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"message\n",
},
}, nil)
}
func TestRecorder_fullcloseWithoutRemoteClose(t *testing.T) {
streamtest.SetFullCloseTimeout(500 * time.Millisecond)
defer streamtest.ResetFullCloseTimeout()
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
// don't close the stream here to initiate timeout
// just try to read the message that it terminated with
// a new line character
_, err := bufio.NewReader(stream).ReadString('\n')
return err
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.WriteString("message\n"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return stream.FullClose()
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != streamtest.ErrStreamFullcloseTimeout {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"message\n",
},
}, nil)
}
func TestRecorder_multipleParallelFullCloseAndClose(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
if _, err := bufio.NewReader(stream).ReadString('\n'); err != nil {
return err
}
var g errgroup.Group
g.Go(stream.Close)
g.Go(stream.FullClose)
if err := g.Wait(); err != nil {
return err
}
return stream.FullClose()
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.WriteString("message\n"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
var g errgroup.Group
g.Go(stream.Close)
g.Go(stream.FullClose)
if err := g.Wait(); err != nil {
return err
}
return nil
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"message\n",
},
}, nil)
}
func TestRecorder_closeAfterPartialWrite(t *testing.T) { func TestRecorder_closeAfterPartialWrite(t *testing.T) {
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols( streamtest.WithProtocols(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment