Commit 8b7bab1a authored by Petar Radovic's avatar Petar Radovic Committed by GitHub

Reset in streamtest (#449)

* reset in streamtest
parent 0a40f19f
...@@ -19,6 +19,7 @@ import ( ...@@ -19,6 +19,7 @@ import (
var ( var (
ErrRecordsNotFound = errors.New("records not found") ErrRecordsNotFound = errors.New("records not found")
ErrStreamNotSupported = errors.New("stream not supported") ErrStreamNotSupported = errors.New("stream not supported")
ErrStreamClosed = errors.New("stream closed")
ErrStreamFullcloseTimeout = errors.New("fullclose timeout") ErrStreamFullcloseTimeout = errors.New("fullclose timeout")
fullCloseTimeout = fullCloseTimeoutDefault // timeout of fullclose fullCloseTimeout = fullCloseTimeoutDefault // timeout of fullclose
fullCloseTimeoutDefault = 5 * time.Second // default timeout used for helper function to reset timeout when changed fullCloseTimeoutDefault = 5 * time.Second // default timeout used for helper function to reset timeout when changed
...@@ -67,10 +68,8 @@ func (r *Recorder) SetProtocols(protocols ...p2p.ProtocolSpec) { ...@@ -67,10 +68,8 @@ func (r *Recorder) SetProtocols(protocols ...p2p.ProtocolSpec) {
func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Headers, protocolName, protocolVersion, streamName string) (p2p.Stream, error) { func (r *Recorder) NewStream(ctx context.Context, addr swarm.Address, h p2p.Headers, protocolName, protocolVersion, streamName string) (p2p.Stream, error) {
recordIn := newRecord() recordIn := newRecord()
recordOut := newRecord() recordOut := newRecord()
closedIn := make(chan struct{}) streamOut := newStream(recordIn, recordOut)
closedOut := make(chan struct{}) streamIn := newStream(recordOut, recordIn)
streamOut := newStream(recordIn, recordOut, closedIn, closedOut)
streamIn := newStream(recordOut, recordIn, closedOut, closedIn)
var handler p2p.HandlerFunc var handler p2p.HandlerFunc
var headler p2p.HeadlerFunc var headler p2p.HeadlerFunc
...@@ -179,16 +178,13 @@ func (r *Record) setErr(err error) { ...@@ -179,16 +178,13 @@ func (r *Record) setErr(err error) {
} }
type stream struct { type stream struct {
in io.WriteCloser in *record
out io.ReadCloser out *record
headers p2p.Headers headers p2p.Headers
cin chan struct{}
cout chan struct{}
closeOnce sync.Once
} }
func newStream(in io.WriteCloser, out io.ReadCloser, cin, cout chan struct{}) *stream { func newStream(in, out *record) *stream {
return &stream{in: in, out: out, cin: cin, cout: cout} return &stream{in: in, out: out}
} }
func (s *stream) Read(p []byte) (int, error) { func (s *stream) Read(p []byte) (int, error) {
...@@ -204,46 +200,44 @@ func (s *stream) Headers() p2p.Headers { ...@@ -204,46 +200,44 @@ func (s *stream) Headers() p2p.Headers {
} }
func (s *stream) Close() error { func (s *stream) Close() error {
var e error return s.in.Close()
s.closeOnce.Do(func() {
if err := s.in.Close(); err != nil {
e = err
return
}
if err := s.out.Close(); err != nil {
e = err
return
}
close(s.cin)
})
return e
} }
func (s *stream) FullClose() error { func (s *stream) FullClose() error {
if err := s.Close(); err != nil { if err := s.Close(); err != nil {
_ = s.Reset()
return err return err
} }
select { waitStart := time.Now()
case <-s.cout:
case <-time.After(fullCloseTimeout): for {
if s.out.Closed() {
return nil
}
if time.Since(waitStart) >= fullCloseTimeout {
return ErrStreamFullcloseTimeout return ErrStreamFullcloseTimeout
} }
return nil time.Sleep(10 * time.Millisecond)
}
} }
func (s *stream) Reset() error { func (s *stream) Reset() (err error) {
//todo: :implement appropriately after all protocols are migrated and tested if err := s.in.Close(); err != nil {
return s.Close() _ = s.out.Close()
return err
}
return s.out.Close()
} }
type record struct { type record struct {
b []byte b []byte
c int c int
closed bool closed bool
closeMu sync.RWMutex
cond *sync.Cond cond *sync.Cond
} }
...@@ -257,7 +251,7 @@ func (r *record) Read(p []byte) (n int, err error) { ...@@ -257,7 +251,7 @@ func (r *record) Read(p []byte) (n int, err error) {
r.cond.L.Lock() r.cond.L.Lock()
defer r.cond.L.Unlock() defer r.cond.L.Unlock()
for r.c == len(r.b) && !r.closed { for r.c == len(r.b) && !r.Closed() {
r.cond.Wait() r.cond.Wait()
} }
end := r.c + len(p) end := r.c + len(p)
...@@ -266,15 +260,19 @@ func (r *record) Read(p []byte) (n int, err error) { ...@@ -266,15 +260,19 @@ func (r *record) Read(p []byte) (n int, err error) {
} }
n = copy(p, r.b[r.c:end]) n = copy(p, r.b[r.c:end])
r.c += n r.c += n
if r.closed { if r.Closed() {
err = io.EOF err = io.EOF
} }
return n, err return n, err
} }
func (r *record) Write(p []byte) (int, error) { func (r *record) Write(p []byte) (int, error) {
r.cond.L.Lock() r.cond.L.Lock()
defer r.cond.L.Unlock() defer r.cond.L.Unlock()
if r.Closed() {
return 0, ErrStreamClosed
}
defer r.cond.Signal() defer r.cond.Signal()
...@@ -288,10 +286,19 @@ func (r *record) Close() error { ...@@ -288,10 +286,19 @@ func (r *record) Close() error {
defer r.cond.Broadcast() defer r.cond.Broadcast()
r.closeMu.Lock()
r.closed = true r.closed = true
r.closeMu.Unlock()
return nil return nil
} }
func (r *record) Closed() bool {
r.closeMu.RLock()
defer r.closeMu.RUnlock()
return r.closed
}
func (r *record) bytes() []byte { func (r *record) bytes() []byte {
r.cond.L.Lock() r.cond.L.Lock()
defer r.cond.L.Unlock() defer r.cond.L.Unlock()
......
...@@ -314,6 +314,72 @@ func TestRecorder_closeAfterPartialWrite(t *testing.T) { ...@@ -314,6 +314,72 @@ func TestRecorder_closeAfterPartialWrite(t *testing.T) {
return err return err
} }
// stream should be closed and write should return err
if _, err := rw.WriteString("expect err message"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err == nil {
return fmt.Errorf("expected err")
}
return nil
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"unterminated message",
"",
},
}, nil)
}
func TestRecorder_resetAfterPartialWrite(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(_ context.Context, peer p2p.Peer, stream p2p.Stream) error {
// just try to read the message that it terminated with
// a new line character
_, err := bufio.NewReader(stream).ReadString('\n')
return err
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, nil, testProtocolName, testProtocolVersion, testStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
// write a message, but do not write a new line character for handler to
// know that it is complete
if _, err := rw.WriteString("unterminated message"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
// deliberately reset the stream before the new line character is
// written to the stream
if err := stream.Reset(); err != nil {
return err
}
// stream should be closed and read should return EOF // stream should be closed and read should return EOF
if _, err := rw.ReadString('\n'); err != io.EOF { if _, err := rw.ReadString('\n'); err != io.EOF {
return fmt.Errorf("got error %v, want %v", err, io.EOF) return fmt.Errorf("got error %v, want %v", err, io.EOF)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment