Commit 0a053f9e authored by Janos Guljas's avatar Janos Guljas

fix streamtest tests

parent 226b12d1
......@@ -73,7 +73,9 @@ func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName
record := &Record{in: recordIn, out: recordOut}
go func() {
err := handler(p2p.Peer{Address: addr}, streamIn)
if err != nil && err != io.EOF {
record.setErr(err)
}
}()
id := addr.String() + p2p.NewSwarmStreamName(protocolName, streamName, version)
......
......@@ -108,7 +108,7 @@ func TestRecorder(t *testing.T) {
"What is your name?\nWhat is your quest?\nWhat is your favorite color?\n",
"Sir Lancelot of Camelot\nTo seek the Holy Grail.\nBlue.\n",
},
})
}, nil)
}
func TestRecorder_errStreamNotSupported(t *testing.T) {
......@@ -179,7 +179,7 @@ func TestRecorder_closeAfterPartialWrite(t *testing.T) {
"unterminated message",
"",
},
})
}, nil)
}
func TestRecorder_withMiddlewares(t *testing.T) {
......@@ -199,7 +199,7 @@ func TestRecorder_withMiddlewares(t *testing.T) {
return err
}
return stream.Close()
return nil
}),
),
streamtest.WithMiddlewares(
......@@ -247,6 +247,16 @@ func TestRecorder_withMiddlewares(t *testing.T) {
return nil
}
},
func(h p2p.HandlerFunc) p2p.HandlerFunc {
return func(peer p2p.Peer, stream p2p.Stream) error {
if err := h(peer, stream); err != nil {
return err
}
// close stream after all previous middlewares wrote to it
// so that the receiving peer can get all the post messages
return stream.Close()
}
},
),
)
......@@ -284,7 +294,65 @@ func TestRecorder_withMiddlewares(t *testing.T) {
"test\n",
"pre 1, pre 2, pre 3, handler, post 3, post 2, post 1, ",
},
})
}, nil)
}
func TestRecorder_recordErr(t *testing.T) {
testErr := errors.New("test error")
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
defer stream.Close()
if _, err := rw.ReadString('\n'); err != nil {
return err
}
if _, err := rw.WriteString("resp\n"); err != nil {
return err
}
if err := rw.Flush(); err != nil {
return err
}
return testErr
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
if _, err = stream.Write([]byte("req\n")); err != nil {
return err
}
_, err = ioutil.ReadAll(stream)
return err
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"req\n",
"resp\n",
},
}, testErr)
}
const (
......@@ -306,7 +374,9 @@ func newTestProtocol(h p2p.HandlerFunc) p2p.ProtocolSpec {
}
}
func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string) {
func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string, wantErr error) {
t.Helper()
lr := len(records)
lw := len(want)
if lr != lw {
......@@ -316,8 +386,8 @@ func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string) {
for i := 0; i < lr; i++ {
record := records[i]
if err := record.Err(); err != nil {
t.Fatalf("got error from record %v, want nil", err)
if err := record.Err(); err != wantErr {
t.Fatalf("got error from record %v, want %v", err, wantErr)
}
w := want[i]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment