Commit 0a053f9e authored by Janos Guljas's avatar Janos Guljas

fix streamtest tests

parent 226b12d1
...@@ -73,7 +73,9 @@ func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName ...@@ -73,7 +73,9 @@ func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName
record := &Record{in: recordIn, out: recordOut} record := &Record{in: recordIn, out: recordOut}
go func() { go func() {
err := handler(p2p.Peer{Address: addr}, streamIn) err := handler(p2p.Peer{Address: addr}, streamIn)
record.setErr(err) if err != nil && err != io.EOF {
record.setErr(err)
}
}() }()
id := addr.String() + p2p.NewSwarmStreamName(protocolName, streamName, version) id := addr.String() + p2p.NewSwarmStreamName(protocolName, streamName, version)
......
...@@ -108,7 +108,7 @@ func TestRecorder(t *testing.T) { ...@@ -108,7 +108,7 @@ func TestRecorder(t *testing.T) {
"What is your name?\nWhat is your quest?\nWhat is your favorite color?\n", "What is your name?\nWhat is your quest?\nWhat is your favorite color?\n",
"Sir Lancelot of Camelot\nTo seek the Holy Grail.\nBlue.\n", "Sir Lancelot of Camelot\nTo seek the Holy Grail.\nBlue.\n",
}, },
}) }, nil)
} }
func TestRecorder_errStreamNotSupported(t *testing.T) { func TestRecorder_errStreamNotSupported(t *testing.T) {
...@@ -179,7 +179,7 @@ func TestRecorder_closeAfterPartialWrite(t *testing.T) { ...@@ -179,7 +179,7 @@ func TestRecorder_closeAfterPartialWrite(t *testing.T) {
"unterminated message", "unterminated message",
"", "",
}, },
}) }, nil)
} }
func TestRecorder_withMiddlewares(t *testing.T) { func TestRecorder_withMiddlewares(t *testing.T) {
...@@ -199,7 +199,7 @@ func TestRecorder_withMiddlewares(t *testing.T) { ...@@ -199,7 +199,7 @@ func TestRecorder_withMiddlewares(t *testing.T) {
return err return err
} }
return stream.Close() return nil
}), }),
), ),
streamtest.WithMiddlewares( streamtest.WithMiddlewares(
...@@ -247,6 +247,16 @@ func TestRecorder_withMiddlewares(t *testing.T) { ...@@ -247,6 +247,16 @@ func TestRecorder_withMiddlewares(t *testing.T) {
return nil return nil
} }
}, },
func(h p2p.HandlerFunc) p2p.HandlerFunc {
return func(peer p2p.Peer, stream p2p.Stream) error {
if err := h(peer, stream); err != nil {
return err
}
// close stream after all previous middlewares wrote to it
// so that the receiving peer can get all the post messages
return stream.Close()
}
},
), ),
) )
...@@ -284,7 +294,65 @@ func TestRecorder_withMiddlewares(t *testing.T) { ...@@ -284,7 +294,65 @@ func TestRecorder_withMiddlewares(t *testing.T) {
"test\n", "test\n",
"pre 1, pre 2, pre 3, handler, post 3, post 2, post 1, ", "pre 1, pre 2, pre 3, handler, post 3, post 2, post 1, ",
}, },
}) }, nil)
}
func TestRecorder_recordErr(t *testing.T) {
testErr := errors.New("test error")
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
defer stream.Close()
if _, err := rw.ReadString('\n'); err != nil {
return err
}
if _, err := rw.WriteString("resp\n"); err != nil {
return err
}
if err := rw.Flush(); err != nil {
return err
}
return testErr
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
if _, err = stream.Write([]byte("req\n")); err != nil {
return err
}
_, err = ioutil.ReadAll(stream)
return err
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"req\n",
"resp\n",
},
}, testErr)
} }
const ( const (
...@@ -306,7 +374,9 @@ func newTestProtocol(h p2p.HandlerFunc) p2p.ProtocolSpec { ...@@ -306,7 +374,9 @@ func newTestProtocol(h p2p.HandlerFunc) p2p.ProtocolSpec {
} }
} }
func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string) { func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string, wantErr error) {
t.Helper()
lr := len(records) lr := len(records)
lw := len(want) lw := len(want)
if lr != lw { if lr != lw {
...@@ -316,8 +386,8 @@ func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string) { ...@@ -316,8 +386,8 @@ func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string) {
for i := 0; i < lr; i++ { for i := 0; i < lr; i++ {
record := records[i] record := records[i]
if err := record.Err(); err != nil { if err := record.Err(); err != wantErr {
t.Fatalf("got error from record %v, want nil", err) t.Fatalf("got error from record %v, want %v", err, wantErr)
} }
w := want[i] w := want[i]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment