Commit 334b9c9f authored by Janos Guljas's avatar Janos Guljas

add streamtest test for middlewares

parent 86a27534
......@@ -67,8 +67,8 @@ func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName
if handler == nil {
return nil, ErrStreamNotSupported
}
for _, m := range r.middlewares {
handler = m(handler)
for i := len(r.middlewares) - 1; i >= 0; i-- {
handler = r.middlewares[i](handler)
}
record := &Record{in: recordIn, out: recordOut}
go func() {
......
......@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"strings"
"testing"
......@@ -102,27 +103,12 @@ func TestRecorder(t *testing.T) {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want 1", l)
}
record := records[0]
if err := record.Err(); err != nil {
t.Fatalf("got error from record %v, want nil", err)
}
wantIn := "What is your name?\nWhat is your quest?\nWhat is your favorite color?\n"
gotIn := string(record.In())
if gotIn != wantIn {
t.Errorf("got stream in %q, want %q", gotIn, wantIn)
}
wantOut := "Sir Lancelot of Camelot\nTo seek the Holy Grail.\nBlue.\n"
gotOut := string(record.Out())
if gotOut != wantOut {
t.Errorf("got stream out %q, want %q", gotOut, wantOut)
}
testRecords(t, records, [][2]string{
{
"What is your name?\nWhat is your quest?\nWhat is your favorite color?\n",
"Sir Lancelot of Camelot\nTo seek the Holy Grail.\nBlue.\n",
},
})
}
func TestRecorder_errStreamNotSupported(t *testing.T) {
......@@ -188,27 +174,117 @@ func TestRecorder_closeAfterPartialWrite(t *testing.T) {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want 1", l)
}
testRecords(t, records, [][2]string{
{
"unterminated message",
"",
},
})
}
func TestRecorder_withMiddlewares(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.ReadString('\n'); err != nil {
return err
}
record := records[0]
if _, err := rw.WriteString("handler, "); err != nil {
return err
}
if err := rw.Flush(); err != nil {
return err
}
if err := record.Err(); err != nil {
t.Fatalf("got error from record %v, want nil", err)
return stream.Close()
}),
),
streamtest.WithMiddlewares(
func(h p2p.HandlerFunc) p2p.HandlerFunc {
return func(peer p2p.Peer, stream p2p.Stream) error {
if _, err := stream.Write([]byte("pre 1, ")); err != nil {
return err
}
if err := h(peer, stream); err != nil {
return err
}
if _, err := stream.Write([]byte("post 1, ")); err != nil {
return err
}
return nil
}
},
func(h p2p.HandlerFunc) p2p.HandlerFunc {
return func(peer p2p.Peer, stream p2p.Stream) error {
if _, err := stream.Write([]byte("pre 2, ")); err != nil {
return err
}
if err := h(peer, stream); err != nil {
return err
}
if _, err := stream.Write([]byte("post 2, ")); err != nil {
return err
}
return nil
}
},
),
streamtest.WithMiddlewares(
func(h p2p.HandlerFunc) p2p.HandlerFunc {
return func(peer p2p.Peer, stream p2p.Stream) error {
if _, err := stream.Write([]byte("pre 3, ")); err != nil {
return err
}
if err := h(peer, stream); err != nil {
return err
}
if _, err := stream.Write([]byte("post 3, ")); err != nil {
return err
}
return nil
}
},
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) error {
stream, err := s.NewStream(ctx, address, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.WriteString("test\n"); err != nil {
return err
}
if err := rw.Flush(); err != nil {
return err
}
_, err = ioutil.ReadAll(rw)
return err
}
wantIn := "unterminated message"
gotIn := string(record.In())
if gotIn != wantIn {
t.Errorf("got stream in %q, want %q", gotIn, wantIn)
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
wantOut := ""
gotOut := string(record.Out())
if gotOut != wantOut {
t.Errorf("got stream out %q, want %q", gotOut, wantOut)
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"test\n",
"pre 1, pre 2, pre 3, handler, post 3, post 2, post 1, ",
},
})
}
const (
......@@ -229,3 +305,31 @@ func newTestProtocol(h p2p.HandlerFunc) p2p.ProtocolSpec {
},
}
}
func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string) {
lr := len(records)
lw := len(want)
if lr != lw {
t.Fatalf("got %v records, want %v", lr, lw)
}
for i := 0; i < lr; i++ {
record := records[i]
if err := record.Err(); err != nil {
t.Fatalf("got error from record %v, want nil", err)
}
w := want[i]
gotIn := string(record.In())
if gotIn != w[0] {
t.Errorf("got stream in %q, want %q", gotIn, w[0])
}
gotOut := string(record.Out())
if gotOut != w[1] {
t.Errorf("got stream out %q, want %q", gotOut, w[1])
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment