Commit 44adb6ec authored by Janos Guljas's avatar Janos Guljas

add protobuf ReadMessages function

parent 03a459a2
...@@ -2,12 +2,15 @@ package protobuf ...@@ -2,12 +2,15 @@ package protobuf
import ( import (
ggio "github.com/gogo/protobuf/io" ggio "github.com/gogo/protobuf/io"
"github.com/gogo/protobuf/proto"
"github.com/janos/bee/pkg/p2p" "github.com/janos/bee/pkg/p2p"
"io" "io"
) )
const delimitedReaderMaxSize = 128 * 1024 // max message size const delimitedReaderMaxSize = 128 * 1024 // max message size
type Message = proto.Message
func NewWriterAndReader(s p2p.Stream) (w ggio.Writer, r ggio.Reader) { func NewWriterAndReader(s p2p.Stream) (w ggio.Writer, r ggio.Reader) {
r = ggio.NewDelimitedReader(s, delimitedReaderMaxSize) r = ggio.NewDelimitedReader(s, delimitedReaderMaxSize)
w = ggio.NewDelimitedWriter(s) w = ggio.NewDelimitedWriter(s)
...@@ -21,3 +24,18 @@ func NewReader(r io.Reader) ggio.Reader { ...@@ -21,3 +24,18 @@ func NewReader(r io.Reader) ggio.Reader {
func NewWriter(w io.Writer) ggio.Writer { func NewWriter(w io.Writer) ggio.Writer {
return ggio.NewDelimitedWriter(w) return ggio.NewDelimitedWriter(w)
} }
func ReadMessages(r io.Reader, newMessage func() Message) (m []Message, err error) {
pr := NewReader(r)
for {
msg := newMessage()
if err := pr.ReadMsg(msg); err != nil {
if err == io.EOF {
break
}
return nil, err
}
m = append(m, msg)
}
return m, nil
}
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io"
"testing" "testing"
"github.com/janos/bee/pkg/p2p/mock" "github.com/janos/bee/pkg/p2p/mock"
...@@ -34,39 +33,38 @@ func TestPing(t *testing.T) { ...@@ -34,39 +33,38 @@ func TestPing(t *testing.T) {
t.Errorf("invalid RTT value %v", rtt) t.Errorf("invalid RTT value %v", rtt)
} }
// validate received ping greetings // validate received ping greetings from the client
r := protobuf.NewReader(bytes.NewReader(streamer.In.Bytes())) wantGreetings := greetings
messages, err := protobuf.ReadMessages(
bytes.NewReader(streamer.In.Bytes()),
func() protobuf.Message { return new(pingpong.Ping) },
)
if err != nil {
t.Fatal(err)
}
var gotGreetings []string var gotGreetings []string
for { for _, m := range messages {
var ping pingpong.Ping gotGreetings = append(gotGreetings, m.(*pingpong.Ping).Greeting)
if err := r.ReadMsg(&ping); err != nil {
if err == io.EOF {
break
}
t.Fatal(err)
}
gotGreetings = append(gotGreetings, ping.Greeting)
} }
if fmt.Sprint(gotGreetings) != fmt.Sprint(greetings) { if fmt.Sprint(gotGreetings) != fmt.Sprint(wantGreetings) {
t.Errorf("got greetings %v, want %v", gotGreetings, greetings) t.Errorf("got greetings %v, want %v", gotGreetings, wantGreetings)
} }
// validate send pong responses by handler // validate sent pong responses by handler
r = protobuf.NewReader(bytes.NewReader(streamer.Out.Bytes()))
var wantResponses []string var wantResponses []string
for _, g := range greetings { for _, g := range greetings {
wantResponses = append(wantResponses, "{"+g+"}") wantResponses = append(wantResponses, "{"+g+"}")
} }
messages, err = protobuf.ReadMessages(
bytes.NewReader(streamer.Out.Bytes()),
func() protobuf.Message { return new(pingpong.Pong) },
)
if err != nil {
t.Fatal(err)
}
var gotResponses []string var gotResponses []string
for { for _, m := range messages {
var pong pingpong.Pong gotResponses = append(gotResponses, m.(*pingpong.Pong).Response)
if err := r.ReadMsg(&pong); err != nil {
if err == io.EOF {
break
}
t.Fatal(err)
}
gotResponses = append(gotResponses, pong.Response)
} }
if fmt.Sprint(gotResponses) != fmt.Sprint(wantResponses) { if fmt.Sprint(gotResponses) != fmt.Sprint(wantResponses) {
t.Errorf("got responses %v, want %v", gotResponses, wantResponses) t.Errorf("got responses %v, want %v", gotResponses, wantResponses)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment