Commit a31ea6e3 authored by Janos Guljas's avatar Janos Guljas

pull out streamtest package from p2p/mock and improve p2p/mock

parent 3839f3fa
...@@ -24,14 +24,12 @@ func TestConnect(t *testing.T) { ...@@ -24,14 +24,12 @@ func TestConnect(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
client, cleanup := newTestServer(t, testServerOptions{ client, cleanup := newTestServer(t, testServerOptions{
P2P: &mock.Service{ P2P: mock.New(mock.WithConnectFunc(func(ctx context.Context, addr ma.Multiaddr) (string, error) {
ConnectFunc: func(ctx context.Context, addr ma.Multiaddr) (string, error) {
if addr.String() == errorUnderlay { if addr.String() == errorUnderlay {
return "", testErr return "", testErr
} }
return overlay, nil return overlay, nil
}, })),
},
}) })
defer cleanup() defer cleanup()
......
...@@ -6,213 +6,58 @@ package mock ...@@ -6,213 +6,58 @@ package mock
import ( import (
"context" "context"
"fmt"
"io"
"sync"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
) )
type Service struct { type Service struct {
AddProtocolFunc func(p2p.ProtocolSpec) error addProtocolFunc func(p2p.ProtocolSpec) error
ConnectFunc func(ctx context.Context, addr ma.Multiaddr) (overlay string, err error) connectFunc func(ctx context.Context, addr ma.Multiaddr) (overlay string, err error)
DisconnectFunc func(overlay string) error disconnectFunc func(overlay string) error
} }
func (s *Service) AddProtocol(spec p2p.ProtocolSpec) error { func WithAddProtocolFunc(f func(p2p.ProtocolSpec) error) Option {
return s.AddProtocolFunc(spec) return optionFunc(func(s *Service) {
} s.addProtocolFunc = f
})
func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay string, err error) {
return s.ConnectFunc(ctx, addr)
}
func (s *Service) Disconnect(overlay string) error {
return s.DisconnectFunc(overlay)
}
type Recorder struct {
records map[string][]Record
recordsMu sync.Mutex
protocols []p2p.ProtocolSpec
middlewares []p2p.HandlerMiddleware
} }
func WithProtocols(protocols ...p2p.ProtocolSpec) Option { func WithConnectFunc(f func(ctx context.Context, addr ma.Multiaddr) (overlay string, err error)) Option {
return optionFunc(func(r *Recorder) { return optionFunc(func(s *Service) {
r.protocols = append(r.protocols, protocols...) s.connectFunc = f
}) })
} }
func WithMiddlewares(middlewares ...p2p.HandlerMiddleware) Option { func WithDisconnectFunc(f func(overlay string) error) Option {
return optionFunc(func(r *Recorder) { return optionFunc(func(s *Service) {
r.middlewares = append(r.middlewares, middlewares...) s.disconnectFunc = f
}) })
} }
func NewRecorder(opts ...Option) *Recorder { func New(opts ...Option) *Service {
r := &Recorder{ s := new(Service)
records: make(map[string][]Record),
}
for _, o := range opts { for _, o := range opts {
o.apply(r) o.apply(s)
}
return r
}
func (r *Recorder) NewStream(_ context.Context, overlay, protocolName, streamName, version string) (p2p.Stream, error) {
recordIn := newRecord()
recordOut := newRecord()
streamOut := newStream(recordIn, recordOut)
streamIn := newStream(recordOut, recordIn)
var handler p2p.HandlerFunc
for _, p := range r.protocols {
if p.Name == protocolName {
for _, s := range p.StreamSpecs {
if s.Name == streamName && s.Version == version {
handler = s.Handler
}
}
}
}
if handler == nil {
return nil, fmt.Errorf("unsupported protocol stream %q %q %q", protocolName, streamName, version)
}
for _, m := range r.middlewares {
handler = m(handler)
}
go func() {
if err := handler(p2p.Peer{Address: overlay}, streamIn); err != nil {
panic(err) // todo: store error and export error records for inspection
}
}()
id := overlay + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
r.records[id] = append(r.records[id], Record{in: recordIn, out: recordOut})
return streamOut, nil
}
func (r *Recorder) Records(peerID, protocolName, streamName, version string) ([]Record, error) {
id := peerID + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
records, ok := r.records[id]
if !ok {
return nil, fmt.Errorf("records not found for %q %q %q %q", peerID, protocolName, streamName, version)
}
return records, nil
}
type Record struct {
in *record
out *record
}
func (r *Record) In() []byte {
return r.in.bytes()
}
func (r *Record) Out() []byte {
return r.out.bytes()
}
type stream struct {
in io.WriteCloser
out io.ReadCloser
}
func newStream(in io.WriteCloser, out io.ReadCloser) *stream {
return &stream{in: in, out: out}
}
func (s *stream) Read(p []byte) (int, error) {
return s.out.Read(p)
}
func (s *stream) Write(p []byte) (int, error) {
return s.in.Write(p)
}
func (s *stream) Close() error {
if err := s.in.Close(); err != nil {
return err
} }
if err := s.out.Close(); err != nil { return s
return err
}
return nil
}
type record struct {
b []byte
c int
closed bool
cond *sync.Cond
} }
func newRecord() *record { func (s *Service) AddProtocol(spec p2p.ProtocolSpec) error {
return &record{ return s.addProtocolFunc(spec)
cond: sync.NewCond(new(sync.Mutex)),
}
}
func (r *record) Read(p []byte) (n int, err error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
for r.c == len(r.b) || r.closed {
r.cond.Wait()
}
end := r.c + len(p)
if end > len(r.b) {
end = len(r.b)
}
n = copy(p, r.b[r.c:end])
r.c += n
if r.closed {
err = io.EOF
}
return n, err
}
func (r *record) Write(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
defer r.cond.Signal()
r.b = append(r.b, p...)
return len(p), nil
} }
func (r *record) Close() error { func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay string, err error) {
r.cond.L.Lock() return s.connectFunc(ctx, addr)
defer r.cond.L.Unlock()
defer r.cond.Broadcast()
r.closed = true
return nil
} }
func (r *record) bytes() []byte { func (s *Service) Disconnect(overlay string) error {
r.cond.L.Lock() return s.disconnectFunc(overlay)
defer r.cond.L.Unlock()
return r.b
} }
type Option interface { type Option interface {
apply(*Recorder) apply(*Service)
} }
type optionFunc func(*Recorder) type optionFunc func(*Service)
func (f optionFunc) apply(r *Recorder) { f(r) } func (f optionFunc) apply(r *Service) { f(r) }
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package streamtest
import (
"context"
"fmt"
"io"
"sync"
"github.com/ethersphere/bee/pkg/p2p"
)
type Recorder struct {
records map[string][]Record
recordsMu sync.Mutex
protocols []p2p.ProtocolSpec
middlewares []p2p.HandlerMiddleware
}
func WithProtocols(protocols ...p2p.ProtocolSpec) Option {
return optionFunc(func(r *Recorder) {
r.protocols = append(r.protocols, protocols...)
})
}
func WithMiddlewares(middlewares ...p2p.HandlerMiddleware) Option {
return optionFunc(func(r *Recorder) {
r.middlewares = append(r.middlewares, middlewares...)
})
}
func New(opts ...Option) *Recorder {
r := &Recorder{
records: make(map[string][]Record),
}
for _, o := range opts {
o.apply(r)
}
return r
}
func (r *Recorder) NewStream(_ context.Context, overlay, protocolName, streamName, version string) (p2p.Stream, error) {
recordIn := newRecord()
recordOut := newRecord()
streamOut := newStream(recordIn, recordOut)
streamIn := newStream(recordOut, recordIn)
var handler p2p.HandlerFunc
for _, p := range r.protocols {
if p.Name == protocolName {
for _, s := range p.StreamSpecs {
if s.Name == streamName && s.Version == version {
handler = s.Handler
}
}
}
}
if handler == nil {
return nil, fmt.Errorf("unsupported protocol stream %q %q %q", protocolName, streamName, version)
}
for _, m := range r.middlewares {
handler = m(handler)
}
go func() {
if err := handler(p2p.Peer{Address: overlay}, streamIn); err != nil {
panic(err) // todo: store error and export error records for inspection
}
}()
id := overlay + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
r.records[id] = append(r.records[id], Record{in: recordIn, out: recordOut})
return streamOut, nil
}
func (r *Recorder) Records(peerID, protocolName, streamName, version string) ([]Record, error) {
id := peerID + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
records, ok := r.records[id]
if !ok {
return nil, fmt.Errorf("records not found for %q %q %q %q", peerID, protocolName, streamName, version)
}
return records, nil
}
type Record struct {
in *record
out *record
}
func (r *Record) In() []byte {
return r.in.bytes()
}
func (r *Record) Out() []byte {
return r.out.bytes()
}
type stream struct {
in io.WriteCloser
out io.ReadCloser
}
func newStream(in io.WriteCloser, out io.ReadCloser) *stream {
return &stream{in: in, out: out}
}
func (s *stream) Read(p []byte) (int, error) {
return s.out.Read(p)
}
func (s *stream) Write(p []byte) (int, error) {
return s.in.Write(p)
}
func (s *stream) Close() error {
if err := s.in.Close(); err != nil {
return err
}
if err := s.out.Close(); err != nil {
return err
}
return nil
}
type record struct {
b []byte
c int
closed bool
cond *sync.Cond
}
func newRecord() *record {
return &record{
cond: sync.NewCond(new(sync.Mutex)),
}
}
func (r *record) Read(p []byte) (n int, err error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
for r.c == len(r.b) || r.closed {
r.cond.Wait()
}
end := r.c + len(p)
if end > len(r.b) {
end = len(r.b)
}
n = copy(p, r.b[r.c:end])
r.c += n
if r.closed {
err = io.EOF
}
return n, err
}
func (r *record) Write(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
defer r.cond.Signal()
r.b = append(r.b, p...)
return len(p), nil
}
func (r *record) Close() error {
r.cond.L.Lock()
defer r.cond.L.Unlock()
defer r.cond.Broadcast()
r.closed = true
return nil
}
func (r *record) bytes() []byte {
r.cond.L.Lock()
defer r.cond.L.Unlock()
return r.b
}
type Option interface {
apply(*Recorder)
}
type optionFunc func(*Recorder)
func (f optionFunc) apply(r *Recorder) { f(r) }
...@@ -15,8 +15,8 @@ import ( ...@@ -15,8 +15,8 @@ import (
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/mock"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pingpong" "github.com/ethersphere/bee/pkg/pingpong"
"github.com/ethersphere/bee/pkg/pingpong/pb" "github.com/ethersphere/bee/pkg/pingpong/pb"
) )
...@@ -30,9 +30,9 @@ func TestPing(t *testing.T) { ...@@ -30,9 +30,9 @@ func TestPing(t *testing.T) {
}) })
// setup the stream recorder to record stream data // setup the stream recorder to record stream data
recorder := mock.NewRecorder( recorder := streamtest.New(
mock.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
mock.WithMiddlewares(func(f p2p.HandlerFunc) p2p.HandlerFunc { streamtest.WithMiddlewares(func(f p2p.HandlerFunc) p2p.HandlerFunc {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
// windows has a bit lower time resolution // windows has a bit lower time resolution
// so, slow down the handler with a middleware // so, slow down the handler with a middleware
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment