Commit a31ea6e3 authored by Janos Guljas's avatar Janos Guljas

pull out streamtest package from p2p/mock and improve p2p/mock

parent 3839f3fa
......@@ -24,14 +24,12 @@ func TestConnect(t *testing.T) {
testErr := errors.New("test error")
client, cleanup := newTestServer(t, testServerOptions{
P2P: &mock.Service{
ConnectFunc: func(ctx context.Context, addr ma.Multiaddr) (string, error) {
P2P: mock.New(mock.WithConnectFunc(func(ctx context.Context, addr ma.Multiaddr) (string, error) {
if addr.String() == errorUnderlay {
return "", testErr
}
return overlay, nil
},
},
})),
})
defer cleanup()
......
......@@ -6,213 +6,58 @@ package mock
import (
"context"
"fmt"
"io"
"sync"
"github.com/ethersphere/bee/pkg/p2p"
ma "github.com/multiformats/go-multiaddr"
)
type Service struct {
AddProtocolFunc func(p2p.ProtocolSpec) error
ConnectFunc func(ctx context.Context, addr ma.Multiaddr) (overlay string, err error)
DisconnectFunc func(overlay string) error
addProtocolFunc func(p2p.ProtocolSpec) error
connectFunc func(ctx context.Context, addr ma.Multiaddr) (overlay string, err error)
disconnectFunc func(overlay string) error
}
func (s *Service) AddProtocol(spec p2p.ProtocolSpec) error {
return s.AddProtocolFunc(spec)
}
func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay string, err error) {
return s.ConnectFunc(ctx, addr)
}
func (s *Service) Disconnect(overlay string) error {
return s.DisconnectFunc(overlay)
}
type Recorder struct {
records map[string][]Record
recordsMu sync.Mutex
protocols []p2p.ProtocolSpec
middlewares []p2p.HandlerMiddleware
func WithAddProtocolFunc(f func(p2p.ProtocolSpec) error) Option {
return optionFunc(func(s *Service) {
s.addProtocolFunc = f
})
}
func WithProtocols(protocols ...p2p.ProtocolSpec) Option {
return optionFunc(func(r *Recorder) {
r.protocols = append(r.protocols, protocols...)
func WithConnectFunc(f func(ctx context.Context, addr ma.Multiaddr) (overlay string, err error)) Option {
return optionFunc(func(s *Service) {
s.connectFunc = f
})
}
func WithMiddlewares(middlewares ...p2p.HandlerMiddleware) Option {
return optionFunc(func(r *Recorder) {
r.middlewares = append(r.middlewares, middlewares...)
func WithDisconnectFunc(f func(overlay string) error) Option {
return optionFunc(func(s *Service) {
s.disconnectFunc = f
})
}
func NewRecorder(opts ...Option) *Recorder {
r := &Recorder{
records: make(map[string][]Record),
}
func New(opts ...Option) *Service {
s := new(Service)
for _, o := range opts {
o.apply(r)
}
return r
}
func (r *Recorder) NewStream(_ context.Context, overlay, protocolName, streamName, version string) (p2p.Stream, error) {
recordIn := newRecord()
recordOut := newRecord()
streamOut := newStream(recordIn, recordOut)
streamIn := newStream(recordOut, recordIn)
var handler p2p.HandlerFunc
for _, p := range r.protocols {
if p.Name == protocolName {
for _, s := range p.StreamSpecs {
if s.Name == streamName && s.Version == version {
handler = s.Handler
}
}
}
}
if handler == nil {
return nil, fmt.Errorf("unsupported protocol stream %q %q %q", protocolName, streamName, version)
}
for _, m := range r.middlewares {
handler = m(handler)
}
go func() {
if err := handler(p2p.Peer{Address: overlay}, streamIn); err != nil {
panic(err) // todo: store error and export error records for inspection
}
}()
id := overlay + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
r.records[id] = append(r.records[id], Record{in: recordIn, out: recordOut})
return streamOut, nil
}
func (r *Recorder) Records(peerID, protocolName, streamName, version string) ([]Record, error) {
id := peerID + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
records, ok := r.records[id]
if !ok {
return nil, fmt.Errorf("records not found for %q %q %q %q", peerID, protocolName, streamName, version)
}
return records, nil
}
type Record struct {
in *record
out *record
}
func (r *Record) In() []byte {
return r.in.bytes()
}
func (r *Record) Out() []byte {
return r.out.bytes()
}
type stream struct {
in io.WriteCloser
out io.ReadCloser
}
func newStream(in io.WriteCloser, out io.ReadCloser) *stream {
return &stream{in: in, out: out}
}
func (s *stream) Read(p []byte) (int, error) {
return s.out.Read(p)
}
func (s *stream) Write(p []byte) (int, error) {
return s.in.Write(p)
}
func (s *stream) Close() error {
if err := s.in.Close(); err != nil {
return err
o.apply(s)
}
if err := s.out.Close(); err != nil {
return err
}
return nil
}
type record struct {
b []byte
c int
closed bool
cond *sync.Cond
return s
}
func newRecord() *record {
return &record{
cond: sync.NewCond(new(sync.Mutex)),
}
}
func (r *record) Read(p []byte) (n int, err error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
for r.c == len(r.b) || r.closed {
r.cond.Wait()
}
end := r.c + len(p)
if end > len(r.b) {
end = len(r.b)
}
n = copy(p, r.b[r.c:end])
r.c += n
if r.closed {
err = io.EOF
}
return n, err
}
func (r *record) Write(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
defer r.cond.Signal()
r.b = append(r.b, p...)
return len(p), nil
func (s *Service) AddProtocol(spec p2p.ProtocolSpec) error {
return s.addProtocolFunc(spec)
}
func (r *record) Close() error {
r.cond.L.Lock()
defer r.cond.L.Unlock()
defer r.cond.Broadcast()
r.closed = true
return nil
func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (overlay string, err error) {
return s.connectFunc(ctx, addr)
}
func (r *record) bytes() []byte {
r.cond.L.Lock()
defer r.cond.L.Unlock()
return r.b
func (s *Service) Disconnect(overlay string) error {
return s.disconnectFunc(overlay)
}
type Option interface {
apply(*Recorder)
apply(*Service)
}
type optionFunc func(*Recorder)
type optionFunc func(*Service)
func (f optionFunc) apply(r *Recorder) { f(r) }
func (f optionFunc) apply(r *Service) { f(r) }
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package streamtest
import (
"context"
"fmt"
"io"
"sync"
"github.com/ethersphere/bee/pkg/p2p"
)
type Recorder struct {
records map[string][]Record
recordsMu sync.Mutex
protocols []p2p.ProtocolSpec
middlewares []p2p.HandlerMiddleware
}
func WithProtocols(protocols ...p2p.ProtocolSpec) Option {
return optionFunc(func(r *Recorder) {
r.protocols = append(r.protocols, protocols...)
})
}
func WithMiddlewares(middlewares ...p2p.HandlerMiddleware) Option {
return optionFunc(func(r *Recorder) {
r.middlewares = append(r.middlewares, middlewares...)
})
}
func New(opts ...Option) *Recorder {
r := &Recorder{
records: make(map[string][]Record),
}
for _, o := range opts {
o.apply(r)
}
return r
}
func (r *Recorder) NewStream(_ context.Context, overlay, protocolName, streamName, version string) (p2p.Stream, error) {
recordIn := newRecord()
recordOut := newRecord()
streamOut := newStream(recordIn, recordOut)
streamIn := newStream(recordOut, recordIn)
var handler p2p.HandlerFunc
for _, p := range r.protocols {
if p.Name == protocolName {
for _, s := range p.StreamSpecs {
if s.Name == streamName && s.Version == version {
handler = s.Handler
}
}
}
}
if handler == nil {
return nil, fmt.Errorf("unsupported protocol stream %q %q %q", protocolName, streamName, version)
}
for _, m := range r.middlewares {
handler = m(handler)
}
go func() {
if err := handler(p2p.Peer{Address: overlay}, streamIn); err != nil {
panic(err) // todo: store error and export error records for inspection
}
}()
id := overlay + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
r.records[id] = append(r.records[id], Record{in: recordIn, out: recordOut})
return streamOut, nil
}
func (r *Recorder) Records(peerID, protocolName, streamName, version string) ([]Record, error) {
id := peerID + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock()
defer r.recordsMu.Unlock()
records, ok := r.records[id]
if !ok {
return nil, fmt.Errorf("records not found for %q %q %q %q", peerID, protocolName, streamName, version)
}
return records, nil
}
type Record struct {
in *record
out *record
}
func (r *Record) In() []byte {
return r.in.bytes()
}
func (r *Record) Out() []byte {
return r.out.bytes()
}
type stream struct {
in io.WriteCloser
out io.ReadCloser
}
func newStream(in io.WriteCloser, out io.ReadCloser) *stream {
return &stream{in: in, out: out}
}
func (s *stream) Read(p []byte) (int, error) {
return s.out.Read(p)
}
func (s *stream) Write(p []byte) (int, error) {
return s.in.Write(p)
}
func (s *stream) Close() error {
if err := s.in.Close(); err != nil {
return err
}
if err := s.out.Close(); err != nil {
return err
}
return nil
}
type record struct {
b []byte
c int
closed bool
cond *sync.Cond
}
func newRecord() *record {
return &record{
cond: sync.NewCond(new(sync.Mutex)),
}
}
func (r *record) Read(p []byte) (n int, err error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
for r.c == len(r.b) || r.closed {
r.cond.Wait()
}
end := r.c + len(p)
if end > len(r.b) {
end = len(r.b)
}
n = copy(p, r.b[r.c:end])
r.c += n
if r.closed {
err = io.EOF
}
return n, err
}
func (r *record) Write(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
defer r.cond.Signal()
r.b = append(r.b, p...)
return len(p), nil
}
func (r *record) Close() error {
r.cond.L.Lock()
defer r.cond.L.Unlock()
defer r.cond.Broadcast()
r.closed = true
return nil
}
func (r *record) bytes() []byte {
r.cond.L.Lock()
defer r.cond.L.Unlock()
return r.b
}
type Option interface {
apply(*Recorder)
}
type optionFunc func(*Recorder)
func (f optionFunc) apply(r *Recorder) { f(r) }
......@@ -15,8 +15,8 @@ import (
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/mock"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pingpong"
"github.com/ethersphere/bee/pkg/pingpong/pb"
)
......@@ -30,9 +30,9 @@ func TestPing(t *testing.T) {
})
// setup the stream recorder to record stream data
recorder := mock.NewRecorder(
mock.WithProtocols(server.Protocol()),
mock.WithMiddlewares(func(f p2p.HandlerFunc) p2p.HandlerFunc {
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
streamtest.WithMiddlewares(func(f p2p.HandlerFunc) p2p.HandlerFunc {
if runtime.GOOS == "windows" {
// windows has a bit lower time resolution
// so, slow down the handler with a middleware
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment