Commit 4f17e337 authored by Janos Guljas's avatar Janos Guljas

support multiple protocols and records in p2p mock Recorder

parent f34a13b3
...@@ -2,6 +2,7 @@ package mock ...@@ -2,6 +2,7 @@ package mock
import ( import (
"context" "context"
"fmt"
"io" "io"
"sync" "sync"
...@@ -10,33 +11,75 @@ import ( ...@@ -10,33 +11,75 @@ import (
) )
type Recorder struct { type Recorder struct {
in, out *record records map[string][]Record
handler func(p2p.Peer) recordsMu sync.Mutex
protocols []p2p.ProtocolSpec
} }
func NewRecorder(handler func(p2p.Peer)) *Recorder { func NewRecorder(protocols ...p2p.ProtocolSpec) *Recorder {
return &Recorder{ return &Recorder{
in: newRecord(), records: make(map[string][]Record),
out: newRecord(), protocols: protocols,
handler: handler,
} }
} }
func (r *Recorder) NewStream(ctx context.Context, peerID, protocolName, streamName, version string) (p2p.Stream, error) { func (r *Recorder) NewStream(_ context.Context, peerID, protocolName, streamName, version string) (p2p.Stream, error) {
out := newStream(r.in, r.out) recordIn := newRecord()
in := newStream(r.out, r.in) recordOut := newRecord()
go r.handler(p2p.Peer{ streamOut := newStream(recordIn, recordOut)
streamIn := newStream(recordOut, recordIn)
var handler func(p2p.Peer)
for _, p := range r.protocols {
if p.Name == protocolName {
for _, s := range p.StreamSpecs {
if s.Name == streamName && s.Version == version {
handler = s.Handler
}
}
}
}
if handler == nil {
return nil, fmt.Errorf("unsupported protocol stream %q %q %q", protocolName, streamName, version)
}
go handler(p2p.Peer{
Addr: ma.StringCast(peerID), Addr: ma.StringCast(peerID),
Stream: in, Stream: streamIn,
}) })
return out, nil
id := peerID + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
r.records[id] = append(r.records[id], Record{in: recordIn, out: recordOut})
return streamOut, nil
} }
func (r *Recorder) In() []byte { func (r *Recorder) Records(peerID, protocolName, streamName, version string) ([]Record, error) {
id := peerID + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
records, ok := r.records[id]
if !ok {
return nil, fmt.Errorf("records not found for %q %q %q %q", peerID, protocolName, streamName, version)
}
return records, nil
}
type Record struct {
in *record
out *record
}
func (r *Record) In() []byte {
return r.in.bytes() return r.in.bytes()
} }
func (r *Recorder) Out() []byte { func (r *Record) Out() []byte {
return r.out.bytes() return r.out.bytes()
} }
...@@ -120,5 +163,8 @@ func (r *record) Close() error { ...@@ -120,5 +163,8 @@ func (r *record) Close() error {
} }
func (r *record) bytes() []byte { func (r *record) bytes() []byte {
r.cond.L.Lock()
defer r.cond.L.Unlock()
return r.b return r.b
} }
...@@ -16,14 +16,15 @@ func TestPing(t *testing.T) { ...@@ -16,14 +16,15 @@ func TestPing(t *testing.T) {
server := pingpong.New(nil) server := pingpong.New(nil)
// setup the stream recorder to record stream data // setup the stream recorder to record stream data
recorder := mock.NewRecorder(server.Handler) recorder := mock.NewRecorder(server.Protocol())
// create a pingpong client that will do pinging // create a pingpong client that will do pinging
client := pingpong.New(recorder) client := pingpong.New(recorder)
// ping // ping
peerID := "/p2p/QmZt98UimwpW9ptJumKTq7B7t3FzNfyoWVNGcd8PFCd7XS"
greetings := []string{"hey", "there", "fella"} greetings := []string{"hey", "there", "fella"}
rtt, err := client.Ping(context.Background(), "/p2p/QmZt98UimwpW9ptJumKTq7B7t3FzNfyoWVNGcd8PFCd7XS", greetings...) rtt, err := client.Ping(context.Background(), peerID, greetings...)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -33,10 +34,20 @@ func TestPing(t *testing.T) { ...@@ -33,10 +34,20 @@ func TestPing(t *testing.T) {
t.Errorf("invalid RTT value %v", rtt) t.Errorf("invalid RTT value %v", rtt)
} }
// get a record for this stream
records, err := recorder.Records(peerID, "pingpong", "pingpong", "1.0.0")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
// validate received ping greetings from the client // validate received ping greetings from the client
wantGreetings := greetings wantGreetings := greetings
messages, err := protobuf.ReadMessages( messages, err := protobuf.ReadMessages(
bytes.NewReader(recorder.In()), bytes.NewReader(record.In()),
func() protobuf.Message { return new(pingpong.Ping) }, func() protobuf.Message { return new(pingpong.Ping) },
) )
if err != nil { if err != nil {
...@@ -56,7 +67,7 @@ func TestPing(t *testing.T) { ...@@ -56,7 +67,7 @@ func TestPing(t *testing.T) {
wantResponses = append(wantResponses, "{"+g+"}") wantResponses = append(wantResponses, "{"+g+"}")
} }
messages, err = protobuf.ReadMessages( messages, err = protobuf.ReadMessages(
bytes.NewReader(recorder.Out()), bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pingpong.Pong) }, func() protobuf.Message { return new(pingpong.Pong) },
) )
if err != nil { if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment