Commit 6a95f067 authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

p2p.Stream FullClose (#20)

Stream fullclose functionality + tests.
parent 97199ea1
......@@ -51,3 +51,7 @@ func (s *Stream) Write(p []byte) (n int, err error) {
func (s *Stream) Close() error {
return nil
}
func (s *StreamMock) FullClose() error {
return nil
}
......@@ -173,7 +173,7 @@ func New(ctx context.Context, o Options) (*Service, error) {
s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) {
peerID := stream.Conn().RemotePeer()
i, err := s.handshakeService.Handle(stream)
i, err := s.handshakeService.Handle(newStream(stream))
if err != nil {
if err == handshake.ErrNetworkIDIncompatible {
s.logger.Warningf("peer %s has a different network id.", peerID)
......@@ -237,7 +237,7 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
}
s.metrics.HandledStreamCount.Inc()
if err := ss.Handler(p2p.Peer{Address: overlay}, stream); err != nil {
if err := ss.Handler(p2p.Peer{Address: overlay}, newStream(stream)); err != nil {
var e *p2p.DisconnectError
if errors.Is(err, e) {
// todo: test connection close and refactor
......@@ -283,14 +283,13 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay swarm
return swarm.Address{}, err
}
i, err := s.handshakeService.Handshake(stream)
i, err := s.handshakeService.Handshake(newStream(stream))
if err != nil {
_ = s.disconnect(info.ID)
return swarm.Address{}, fmt.Errorf("handshake: %w", err)
}
if err := helpers.FullClose(stream); err != nil {
_ = stream.Reset()
return swarm.Address{}, err
}
......@@ -333,7 +332,12 @@ func (s *Service) NewStream(ctx context.Context, overlay swarm.Address, protocol
return nil, p2p.ErrPeerNotFound
}
return s.newStreamForPeerID(ctx, peerID, protocolName, protocolVersion, streamName)
stream, err := s.newStreamForPeerID(ctx, peerID, protocolName, protocolVersion, streamName)
if err != nil {
return nil, err
}
return newStream(stream), nil
}
func (s *Service) newStreamForPeerID(ctx context.Context, peerID libp2ppeer.ID, protocolName, protocolVersion, streamName string) (network.Stream, error) {
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package libp2p
import (
"github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/network"
)
type stream struct {
network.Stream
}
func (s *stream) FullClose() error {
return helpers.FullClose(s)
}
func newStream(s network.Stream) *stream {
return &stream{Stream: s}
}
......@@ -26,6 +26,7 @@ type Streamer interface {
type Stream interface {
io.ReadWriter
io.Closer
FullClose() error
}
type ProtocolSpec struct {
......
......@@ -363,6 +363,10 @@ func (noopWriteCloser) Close() error {
return nil
}
func (noopWriteCloser) FullClose() error {
return nil
}
type noopReadCloser struct {
io.Writer
}
......@@ -378,3 +382,7 @@ func (noopReadCloser) Read(p []byte) (n int, err error) {
func (noopReadCloser) Close() error {
return nil
}
func (noopReadCloser) FullClose() error {
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package streamtest
import "time"
func SetFullCloseTimeout(t time.Duration) {
fullCloseTimeout = t
}
func ResetFullCloseTimeout() {
fullCloseTimeout = fullCloseTimeoutDefault
}
......@@ -9,14 +9,18 @@ import (
"errors"
"io"
"sync"
"time"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
)
var (
ErrRecordsNotFound = errors.New("records not found")
ErrStreamNotSupported = errors.New("stream not supported")
ErrRecordsNotFound = errors.New("records not found")
ErrStreamNotSupported = errors.New("stream not supported")
ErrStreamFullcloseTimeout = errors.New("fullclose timeout")
fullCloseTimeout = fullCloseTimeoutDefault // timeout of fullclose
fullCloseTimeoutDefault = 5 * time.Second // default timeout used for helper function to reset timeout when changed
)
type Recorder struct {
......@@ -51,8 +55,10 @@ func New(opts ...Option) *Recorder {
func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName, protocolVersion, streamName string) (p2p.Stream, error) {
recordIn := newRecord()
recordOut := newRecord()
streamOut := newStream(recordIn, recordOut)
streamIn := newStream(recordOut, recordIn)
closedIn := make(chan struct{})
closedOut := make(chan struct{})
streamOut := newStream(recordIn, recordOut, closedIn, closedOut)
streamIn := newStream(recordOut, recordIn, closedOut, closedIn)
var handler p2p.HandlerFunc
for _, p := range r.protocols {
......@@ -130,12 +136,15 @@ func (r *Record) setErr(err error) {
}
type stream struct {
in io.WriteCloser
out io.ReadCloser
in io.WriteCloser
out io.ReadCloser
cin chan struct{}
cout chan struct{}
closeOnce sync.Once
}
func newStream(in io.WriteCloser, out io.ReadCloser) *stream {
return &stream{in: in, out: out}
func newStream(in io.WriteCloser, out io.ReadCloser, cin, cout chan struct{}) *stream {
return &stream{in: in, out: out, cin: cin, cout: cout}
}
func (s *stream) Read(p []byte) (int, error) {
......@@ -147,12 +156,34 @@ func (s *stream) Write(p []byte) (int, error) {
}
func (s *stream) Close() error {
if err := s.in.Close(); err != nil {
var e error
s.closeOnce.Do(func() {
if err := s.in.Close(); err != nil {
e = err
return
}
if err := s.out.Close(); err != nil {
e = err
return
}
close(s.cin)
})
return e
}
func (s *stream) FullClose() error {
if err := s.Close(); err != nil {
return err
}
if err := s.out.Close(); err != nil {
return err
select {
case <-s.cout:
case <-time.After(fullCloseTimeout):
return ErrStreamFullcloseTimeout
}
return nil
}
......
......@@ -13,14 +13,15 @@ import (
"io/ioutil"
"strings"
"testing"
"time"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/swarm"
"golang.org/x/sync/errgroup"
)
func TestRecorder(t *testing.T) {
var answers = map[string]string{
"What is your name?": "Sir Lancelot of Camelot",
"What is your quest?": "To seek the Holy Grail.",
......@@ -120,6 +121,163 @@ func TestRecorder_errStreamNotSupported(t *testing.T) {
}
}
func TestRecorder_fullcloseWithRemoteClose(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
defer stream.Close()
_, err := bufio.NewReader(stream).ReadString('\n')
return err
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.WriteString("message\n"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return stream.FullClose()
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"message\n",
},
}, nil)
}
func TestRecorder_fullcloseWithoutRemoteClose(t *testing.T) {
streamtest.SetFullCloseTimeout(500 * time.Millisecond)
defer streamtest.ResetFullCloseTimeout()
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
// don't close the stream here to initiate timeout
// just try to read the message that it terminated with
// a new line character
_, err := bufio.NewReader(stream).ReadString('\n')
return err
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.WriteString("message\n"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return stream.FullClose()
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != streamtest.ErrStreamFullcloseTimeout {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"message\n",
},
}, nil)
}
func TestRecorder_multipleParallelFullCloseAndClose(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
if _, err := bufio.NewReader(stream).ReadString('\n'); err != nil {
return err
}
var g errgroup.Group
g.Go(stream.Close)
g.Go(stream.FullClose)
if err := g.Wait(); err != nil {
return err
}
return stream.FullClose()
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.WriteString("message\n"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
var g errgroup.Group
g.Go(stream.Close)
g.Go(stream.FullClose)
if err := g.Wait(); err != nil {
return err
}
return nil
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"message\n",
},
}, nil)
}
func TestRecorder_closeAfterPartialWrite(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment