Commit 8b7bab1a authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

Reset in streamtest (#449)

* reset in streamtest
parent 0a40f19f
......@@ -19,6 +19,7 @@ import (
var (
ErrRecordsNotFound = errors.New("records not found")
ErrStreamNotSupported = errors.New("stream not supported")
ErrStreamClosed = errors.New("stream closed")
ErrStreamFullcloseTimeout = errors.New("fullclose timeout")
fullCloseTimeout = fullCloseTimeoutDefault // timeout of fullclose
fullCloseTimeoutDefault = 5 * time.Second // default timeout used for helper function to reset timeout when changed
......@@ -67,10 +68,8 @@ func (r *Recorder) SetProtocols(protocols ...p2p.ProtocolSpec) {
func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Headers, protocolName, protocolVersion, streamName string) (p2p.Stream, error) {
recordIn := newRecord()
recordOut := newRecord()
closedIn := make(chan struct{})
closedOut := make(chan struct{})
streamOut := newStream(recordIn, recordOut, closedIn, closedOut)
streamIn := newStream(recordOut, recordIn, closedOut, closedIn)
streamOut := newStream(recordIn, recordOut)
streamIn := newStream(recordOut, recordIn)
var handler p2p.HandlerFunc
var headler p2p.HeadlerFunc
......@@ -179,16 +178,13 @@ func (r *Record) setErr(err error) {
}
type stream struct {
in io.WriteCloser
out io.ReadCloser
headers p2p.Headers
cin chan struct{}
cout chan struct{}
closeOnce sync.Once
in *record
out *record
headers p2p.Headers
}
func newStream(in io.WriteCloser, out io.ReadCloser, cin, cout chan struct{}) *stream {
return &stream{in: in, out: out, cin: cin, cout: cout}
func newStream(in, out *record) *stream {
return &stream{in: in, out: out}
}
func (s *stream) Read(p []byte) (int, error) {
......@@ -204,47 +200,45 @@ func (s *stream) Headers() p2p.Headers {
}
func (s *stream) Close() error {
var e error
s.closeOnce.Do(func() {
if err := s.in.Close(); err != nil {
e = err
return
}
if err := s.out.Close(); err != nil {
e = err
return
}
close(s.cin)
})
return e
return s.in.Close()
}
func (s *stream) FullClose() error {
if err := s.Close(); err != nil {
_ = s.Reset()
return err
}
select {
case <-s.cout:
case <-time.After(fullCloseTimeout):
return ErrStreamFullcloseTimeout
}
waitStart := time.Now()
return nil
for {
if s.out.Closed() {
return nil
}
if time.Since(waitStart) >= fullCloseTimeout {
return ErrStreamFullcloseTimeout
}
time.Sleep(10 * time.Millisecond)
}
}
func (s *stream) Reset() error {
//todo: :implement appropriately after all protocols are migrated and tested
return s.Close()
func (s *stream) Reset() (err error) {
if err := s.in.Close(); err != nil {
_ = s.out.Close()
return err
}
return s.out.Close()
}
type record struct {
b []byte
c int
closed bool
cond *sync.Cond
b []byte
c int
closed bool
closeMu sync.RWMutex
cond *sync.Cond
}
func newRecord() *record {
......@@ -257,7 +251,7 @@ func (r *record) Read(p []byte) (n int, err error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
for r.c == len(r.b) && !r.closed {
for r.c == len(r.b) && !r.Closed() {
r.cond.Wait()
}
end := r.c + len(p)
......@@ -266,15 +260,19 @@ func (r *record) Read(p []byte) (n int, err error) {
}
n = copy(p, r.b[r.c:end])
r.c += n
if r.closed {
if r.Closed() {
err = io.EOF
}
return n, err
}
func (r *record) Write(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
if r.Closed() {
return 0, ErrStreamClosed
}
defer r.cond.Signal()
......@@ -288,10 +286,19 @@ func (r *record) Close() error {
defer r.cond.Broadcast()
r.closeMu.Lock()
r.closed = true
r.closeMu.Unlock()
return nil
}
func (r *record) Closed() bool {
r.closeMu.RLock()
defer r.closeMu.RUnlock()
return r.closed
}
func (r *record) bytes() []byte {
r.cond.L.Lock()
defer r.cond.L.Unlock()
......
......@@ -314,6 +314,72 @@ func TestRecorder_closeAfterPartialWrite(t *testing.T) {
return err
}
// stream should be closed and write should return err
if _, err := rw.WriteString("expect err message"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err == nil {
return fmt.Errorf("expected err")
}
return nil
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"unterminated message",
"",
},
}, nil)
}
func TestRecorder_resetAfterPartialWrite(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(_ context.Context, peer p2p.Peer, stream p2p.Stream) error {
// just try to read the message that it terminated with
// a new line character
_, err := bufio.NewReader(stream).ReadString('\n')
return err
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, nil, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
// write a message, but do not write a new line character for handler to
// know that it is complete
if _, err := rw.WriteString("unterminated message"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
// deliberately reset the stream before the new line character is
// written to the stream
if err := stream.Reset(); err != nil {
return err
}
// stream should be closed and read should return EOF
if _, err := rw.ReadString('\n'); err != io.EOF {
return fmt.Errorf("got error %v, want %v", err, io.EOF)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment