Commit c96ebdc2 authored by Petar Radovic's avatar Petar Radovic

Merge branch 'master' of github.com:ethersphere/bee

parents c71539e0 d2649013
...@@ -27,6 +27,11 @@ type Stream interface { ...@@ -27,6 +27,11 @@ type Stream interface {
io.Closer io.Closer
} }
// PeerSuggester suggests a peer to retrieve a chunk from
type PeerSuggester interface {
SuggestPeer(addr swarm.Address) (peerAddr swarm.Address, err error)
}
type ProtocolSpec struct { type ProtocolSpec struct {
Name string Name string
StreamSpecs []StreamSpec StreamSpecs []StreamSpec
......
...@@ -6,12 +6,17 @@ package streamtest ...@@ -6,12 +6,17 @@ package streamtest
import ( import (
"context" "context"
"fmt" "errors"
"github.com/ethersphere/bee/pkg/swarm"
"io" "io"
"sync" "sync"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/swarm"
)
var (
ErrRecordsNotFound = errors.New("records not found")
ErrStreamNotSupported = errors.New("stream not supported")
) )
type Recorder struct { type Recorder struct {
...@@ -60,10 +65,10 @@ func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName ...@@ -60,10 +65,10 @@ func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName
} }
} }
if handler == nil { if handler == nil {
return nil, fmt.Errorf("unsupported protocol stream %q %q %q", protocolName, streamName, version) return nil, ErrStreamNotSupported
} }
for _, m := range r.middlewares { for i := len(r.middlewares) - 1; i >= 0; i-- {
handler = m(handler) handler = r.middlewares[i](handler)
} }
record := &Record{in: recordIn, out: recordOut} record := &Record{in: recordIn, out: recordOut}
go func() { go func() {
...@@ -88,7 +93,7 @@ func (r *Recorder) Records(addr swarm.Address, protocolName, streamName, version ...@@ -88,7 +93,7 @@ func (r *Recorder) Records(addr swarm.Address, protocolName, streamName, version
records, ok := r.records[id] records, ok := r.records[id]
if !ok { if !ok {
return nil, fmt.Errorf("records not found for %q %q %q %q", addr, protocolName, streamName, version) return nil, ErrRecordsNotFound
} }
return records, nil return records, nil
} }
...@@ -166,7 +171,7 @@ func (r *record) Read(p []byte) (n int, err error) { ...@@ -166,7 +171,7 @@ func (r *record) Read(p []byte) (n int, err error) {
r.cond.L.Lock() r.cond.L.Lock()
defer r.cond.L.Unlock() defer r.cond.L.Unlock()
for r.c == len(r.b) || r.closed { for r.c == len(r.b) && !r.closed {
r.cond.Wait() r.cond.Wait()
} }
end := r.c + len(p) end := r.c + len(p)
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package streamtest_test
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"strings"
"testing"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestRecorder(t *testing.T) {
var answers = map[string]string{
"What is your name?": "Sir Lancelot of Camelot",
"What is your quest?": "To seek the Holy Grail.",
"What is your favorite color?": "Blue.",
"What is the air-speed velocity of an unladen swallow?": "What do you mean? An African or European swallow?",
}
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
for {
q, err := rw.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
return fmt.Errorf("read: %w", err)
}
q = strings.TrimRight(q, "\n")
if _, err = rw.WriteString(answers[q] + "\n"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
}
return nil
}),
),
)
ask := func(ctx context.Context, s p2p.Streamer, address swarm.Address, questions ...string) (answers []string, err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
return nil, fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
for _, q := range questions {
if _, err := rw.WriteString(q + "\n"); err != nil {
return nil, fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return nil, fmt.Errorf("flush: %w", err)
}
a, err := rw.ReadString('\n')
if err != nil {
return nil, fmt.Errorf("read: %w", err)
}
a = strings.TrimRight(a, "\n")
answers = append(answers, a)
}
return answers, nil
}
questions := []string{"What is your name?", "What is your quest?", "What is your favorite color?"}
aa, err := ask(context.Background(), recorder, swarm.ZeroAddress, questions...)
if err != nil {
t.Fatal(err)
}
for i, q := range questions {
if aa[i] != answers[q] {
t.Errorf("got answer %q for question %q, want %q", aa[i], q, answers[q])
}
}
_, err = recorder.Records(swarm.ZeroAddress, testProtocolName, "invalid stream name", testStreamVersion)
if err != streamtest.ErrRecordsNotFound {
t.Errorf("got error %v, want %v", err, streamtest.ErrRecordsNotFound)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"What is your name?\nWhat is your quest?\nWhat is your favorite color?\n",
"Sir Lancelot of Camelot\nTo seek the Holy Grail.\nBlue.\n",
},
})
}
func TestRecorder_errStreamNotSupported(t *testing.T) {
r := streamtest.New()
_, err := r.NewStream(context.Background(), swarm.ZeroAddress, "testing", "messages", "1.0.1")
if !errors.Is(err, streamtest.ErrStreamNotSupported) {
t.Fatalf("got error %v, want %v", err, streamtest.ErrStreamNotSupported)
}
}
func TestRecorder_closeAfterPartialWrite(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
// just try to read the message that it terminated with
// a new line character
_, err := bufio.NewReader(stream).ReadString('\n')
return err
}),
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) (err error) {
stream, err := s.NewStream(ctx, address, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
// write a message, but do not write a new line character for handler to
// know that it is complete
if _, err := rw.WriteString("unterminated message"); err != nil {
return fmt.Errorf("write: %w", err)
}
if err := rw.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
// deliberately close the stream before the new line character is
// written to the stream
if err := stream.Close(); err != nil {
return err
}
// stream should be closed and read should return EOF
if _, err := rw.ReadString('\n'); err != io.EOF {
return fmt.Errorf("got error %v, want %v", err, io.EOF)
}
return nil
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"unterminated message",
"",
},
})
}
func TestRecorder_withMiddlewares(t *testing.T) {
recorder := streamtest.New(
streamtest.WithProtocols(
newTestProtocol(func(peer p2p.Peer, stream p2p.Stream) error {
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.ReadString('\n'); err != nil {
return err
}
if _, err := rw.WriteString("handler, "); err != nil {
return err
}
if err := rw.Flush(); err != nil {
return err
}
return stream.Close()
}),
),
streamtest.WithMiddlewares(
func(h p2p.HandlerFunc) p2p.HandlerFunc {
return func(peer p2p.Peer, stream p2p.Stream) error {
if _, err := stream.Write([]byte("pre 1, ")); err != nil {
return err
}
if err := h(peer, stream); err != nil {
return err
}
if _, err := stream.Write([]byte("post 1, ")); err != nil {
return err
}
return nil
}
},
func(h p2p.HandlerFunc) p2p.HandlerFunc {
return func(peer p2p.Peer, stream p2p.Stream) error {
if _, err := stream.Write([]byte("pre 2, ")); err != nil {
return err
}
if err := h(peer, stream); err != nil {
return err
}
if _, err := stream.Write([]byte("post 2, ")); err != nil {
return err
}
return nil
}
},
),
streamtest.WithMiddlewares(
func(h p2p.HandlerFunc) p2p.HandlerFunc {
return func(peer p2p.Peer, stream p2p.Stream) error {
if _, err := stream.Write([]byte("pre 3, ")); err != nil {
return err
}
if err := h(peer, stream); err != nil {
return err
}
if _, err := stream.Write([]byte("post 3, ")); err != nil {
return err
}
return nil
}
},
),
)
request := func(ctx context.Context, s p2p.Streamer, address swarm.Address) error {
stream, err := s.NewStream(ctx, address, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
rw := bufio.NewReadWriter(bufio.NewReader(stream), bufio.NewWriter(stream))
if _, err := rw.WriteString("test\n"); err != nil {
return err
}
if err := rw.Flush(); err != nil {
return err
}
_, err = ioutil.ReadAll(rw)
return err
}
err := request(context.Background(), recorder, swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
records, err := recorder.Records(swarm.ZeroAddress, testProtocolName, testStreamName, testStreamVersion)
if err != nil {
t.Fatal(err)
}
testRecords(t, records, [][2]string{
{
"test\n",
"pre 1, pre 2, pre 3, handler, post 3, post 2, post 1, ",
},
})
}
const (
testProtocolName = "testing"
testStreamName = "messages"
testStreamVersion = "1.0.1"
)
func newTestProtocol(h p2p.HandlerFunc) p2p.ProtocolSpec {
return p2p.ProtocolSpec{
Name: testProtocolName,
StreamSpecs: []p2p.StreamSpec{
{
Name: testStreamName,
Version: testStreamVersion,
Handler: h,
},
},
}
}
func testRecords(t *testing.T, records []*streamtest.Record, want [][2]string) {
lr := len(records)
lw := len(want)
if lr != lw {
t.Fatalf("got %v records, want %v", lr, lw)
}
for i := 0; i < lr; i++ {
record := records[i]
if err := record.Err(); err != nil {
t.Fatalf("got error from record %v, want nil", err)
}
w := want[i]
gotIn := string(record.In())
if gotIn != w[0] {
t.Errorf("got stream in %q, want %q", gotIn, w[0])
}
gotOut := string(record.Out())
if gotOut != w[1] {
t.Errorf("got stream out %q, want %q", gotOut, w[1])
}
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate sh -c "protoc -I . -I \"$(go list -f '{{ .Dir }}' -m github.com/gogo/protobuf)/protobuf\" --gogofaster_out=. retrieval.proto"
package pb
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: retrieval.proto
package pb
import (
fmt "fmt"
io "io"
math "math"
math_bits "math/bits"
proto "github.com/gogo/protobuf/proto"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type Request struct {
Addr []byte `protobuf:"bytes,1,opt,name=Addr,proto3" json:"Addr,omitempty"`
}
func (m *Request) Reset() { *m = Request{} }
func (m *Request) String() string { return proto.CompactTextString(m) }
func (*Request) ProtoMessage() {}
func (*Request) Descriptor() ([]byte, []int) {
return fileDescriptor_fcade0a564e5dcd4, []int{0}
}
func (m *Request) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Request) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Request.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Request) XXX_Merge(src proto.Message) {
xxx_messageInfo_Request.Merge(m, src)
}
func (m *Request) XXX_Size() int {
return m.Size()
}
func (m *Request) XXX_DiscardUnknown() {
xxx_messageInfo_Request.DiscardUnknown(m)
}
var xxx_messageInfo_Request proto.InternalMessageInfo
func (m *Request) GetAddr() []byte {
if m != nil {
return m.Addr
}
return nil
}
type Delivery struct {
Data []byte `protobuf:"bytes,1,opt,name=Data,proto3" json:"Data,omitempty"`
}
func (m *Delivery) Reset() { *m = Delivery{} }
func (m *Delivery) String() string { return proto.CompactTextString(m) }
func (*Delivery) ProtoMessage() {}
func (*Delivery) Descriptor() ([]byte, []int) {
return fileDescriptor_fcade0a564e5dcd4, []int{1}
}
func (m *Delivery) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Delivery) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Delivery.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Delivery) XXX_Merge(src proto.Message) {
xxx_messageInfo_Delivery.Merge(m, src)
}
func (m *Delivery) XXX_Size() int {
return m.Size()
}
func (m *Delivery) XXX_DiscardUnknown() {
xxx_messageInfo_Delivery.DiscardUnknown(m)
}
var xxx_messageInfo_Delivery proto.InternalMessageInfo
func (m *Delivery) GetData() []byte {
if m != nil {
return m.Data
}
return nil
}
func init() {
proto.RegisterType((*Request)(nil), "retrieval.Request")
proto.RegisterType((*Delivery)(nil), "retrieval.Delivery")
}
func init() { proto.RegisterFile("retrieval.proto", fileDescriptor_fcade0a564e5dcd4) }
var fileDescriptor_fcade0a564e5dcd4 = []byte{
// 127 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2f, 0x4a, 0x2d, 0x29,
0xca, 0x4c, 0x2d, 0x4b, 0xcc, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x84, 0x0b, 0x28,
0xc9, 0x72, 0xb1, 0x07, 0xa5, 0x16, 0x96, 0xa6, 0x16, 0x97, 0x08, 0x09, 0x71, 0xb1, 0x38, 0xa6,
0xa4, 0x14, 0x49, 0x30, 0x2a, 0x30, 0x6a, 0xf0, 0x04, 0x81, 0xd9, 0x4a, 0x72, 0x5c, 0x1c, 0x2e,
0xa9, 0x39, 0x99, 0x65, 0xa9, 0x45, 0x95, 0x20, 0x79, 0x97, 0xc4, 0x92, 0x44, 0x98, 0x3c, 0x88,
0xed, 0x24, 0x71, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f, 0x1e, 0xc9, 0x31, 0x4e,
0x78, 0x2c, 0xc7, 0x70, 0xe1, 0xb1, 0x1c, 0xc3, 0x8d, 0xc7, 0x72, 0x0c, 0x49, 0x6c, 0x60, 0xab,
0x8c, 0x01, 0x01, 0x00, 0x00, 0xff, 0xff, 0x00, 0x18, 0xd7, 0x30, 0x7d, 0x00, 0x00, 0x00,
}
func (m *Request) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Request) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Request) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.Addr) > 0 {
i -= len(m.Addr)
copy(dAtA[i:], m.Addr)
i = encodeVarintRetrieval(dAtA, i, uint64(len(m.Addr)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func (m *Delivery) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Delivery) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Delivery) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.Data) > 0 {
i -= len(m.Data)
copy(dAtA[i:], m.Data)
i = encodeVarintRetrieval(dAtA, i, uint64(len(m.Data)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func encodeVarintRetrieval(dAtA []byte, offset int, v uint64) int {
offset -= sovRetrieval(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
func (m *Request) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.Addr)
if l > 0 {
n += 1 + l + sovRetrieval(uint64(l))
}
return n
}
func (m *Delivery) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.Data)
if l > 0 {
n += 1 + l + sovRetrieval(uint64(l))
}
return n
}
func sovRetrieval(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
func sozRetrieval(x uint64) (n int) {
return sovRetrieval(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *Request) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowRetrieval
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Request: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Request: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Addr", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowRetrieval
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthRetrieval
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthRetrieval
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Addr = append(m.Addr[:0], dAtA[iNdEx:postIndex]...)
if m.Addr == nil {
m.Addr = []byte{}
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipRetrieval(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthRetrieval
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthRetrieval
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *Delivery) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowRetrieval
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Delivery: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Delivery: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Data", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowRetrieval
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthRetrieval
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthRetrieval
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Data = append(m.Data[:0], dAtA[iNdEx:postIndex]...)
if m.Data == nil {
m.Data = []byte{}
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipRetrieval(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthRetrieval
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthRetrieval
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipRetrieval(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
depth := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowRetrieval
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowRetrieval
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
case 1:
iNdEx += 8
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowRetrieval
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLengthRetrieval
}
iNdEx += length
case 3:
depth++
case 4:
if depth == 0 {
return 0, ErrUnexpectedEndOfGroupRetrieval
}
depth--
case 5:
iNdEx += 4
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
if iNdEx < 0 {
return 0, ErrInvalidLengthRetrieval
}
if depth == 0 {
return iNdEx, nil
}
}
return 0, io.ErrUnexpectedEOF
}
var (
ErrInvalidLengthRetrieval = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowRetrieval = fmt.Errorf("proto: integer overflow")
ErrUnexpectedEndOfGroupRetrieval = fmt.Errorf("proto: unexpected end of group")
)
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
syntax = "proto3";
package pb;
message Request {
bytes Addr = 1;
}
message Delivery {
bytes Data = 1;
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package retrieval
import (
"context"
"fmt"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
pb "github.com/ethersphere/bee/pkg/retrieval/pb"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
)
const (
protocolName = "retrieval"
streamName = "retrieval"
streamVersion = "1.0.0"
)
type Service struct {
streamer p2p.Streamer
peerSuggester p2p.PeerSuggester
storer storage.Storer
logger logging.Logger
}
type Options struct {
Streamer p2p.Streamer
PeerSuggester p2p.PeerSuggester
Storer storage.Storer
Logger logging.Logger
}
type Storer interface {
}
func New(o Options) *Service {
return &Service{
streamer: o.Streamer,
peerSuggester: o.PeerSuggester,
storer: o.Storer,
logger: o.Logger,
}
}
func (s *Service) Protocol() p2p.ProtocolSpec {
return p2p.ProtocolSpec{
Name: protocolName,
StreamSpecs: []p2p.StreamSpec{
{
Name: streamName,
Version: streamVersion,
Handler: s.Handler,
},
},
}
}
func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data []byte, err error) {
peerID, err := s.peerSuggester.SuggestPeer(addr)
if err != nil {
return nil, err
}
stream, err := s.streamer.NewStream(ctx, peerID, protocolName, streamName, streamVersion)
if err != nil {
return nil, fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
w, r := protobuf.NewWriterAndReader(stream)
if err := w.WriteMsg(&pb.Request{
Addr: addr.Bytes(),
}); err != nil {
return nil, fmt.Errorf("stream write: %w", err)
}
var d pb.Delivery
if err := r.ReadMsg(&d); err != nil {
return nil, err
}
return d.Data, nil
}
func (s *Service) Handler(p p2p.Peer, stream p2p.Stream) error {
w, r := protobuf.NewWriterAndReader(stream)
defer stream.Close()
var req pb.Request
if err := r.ReadMsg(&req); err != nil {
return err
}
data, err := s.storer.Get(context.TODO(), swarm.NewAddress(req.Addr))
if err != nil {
return err
}
if err := w.WriteMsg(&pb.Delivery{
Data: data,
}); err != nil {
return err
}
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package retrieval_test
import (
"bytes"
"context"
"encoding/hex"
"io/ioutil"
"testing"
"time"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/retrieval"
pb "github.com/ethersphere/bee/pkg/retrieval/pb"
storemock "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
)
var testTimeout = 5 * time.Second
// TestDelivery tests that a naive request -> delivery flow works.
func TestDelivery(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
mockStorer := storemock.NewStorer()
reqAddr, err := swarm.ParseHexAddress("00112233")
if err != nil {
t.Fatal(err)
}
reqData := []byte("data data data")
// put testdata in the mock store of the server
_ = mockStorer.Put(context.TODO(), reqAddr, reqData)
// create the server that will handle the request and will serve the response
server := retrieval.New(retrieval.Options{
Storer: mockStorer,
Logger: logger,
})
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
)
// client mock storer does not store any data at this point
// but should be checked at at the end of the test for the
// presence of the reqAddr key and value to ensure delivery
// was successful
clientMockStorer := storemock.NewStorer()
ps := mockPeerSuggester{spFunc: func(_ swarm.Address) (swarm.Address, error) {
v, err := swarm.ParseHexAddress("9ee7add7")
return v, err
}}
client := retrieval.New(retrieval.Options{
Streamer: recorder,
PeerSuggester: ps,
Storer: clientMockStorer,
Logger: logger,
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
v, err := client.RetrieveChunk(ctx, reqAddr)
if err != nil {
return
}
if !bytes.Equal(v, reqData) {
t.Fatalf("request and response data not equal. got %s want %s", v, reqData)
}
peerID, _ := ps.SuggestPeer(swarm.ZeroAddress)
records, err := recorder.Records(peerID, "retrieval", "retrieval", "1.0.0")
if err != nil {
t.Fatal(err)
}
if l := len(records); l != 1 {
t.Fatalf("got %v records, want %v", l, 1)
}
record := records[0]
messages, err := protobuf.ReadMessages(
bytes.NewReader(record.In()),
func() protobuf.Message { return new(pb.Request) },
)
if err != nil {
t.Fatal(err)
}
var reqs []string
for _, m := range messages {
reqs = append(reqs, hex.EncodeToString(m.(*pb.Request).Addr))
}
if len(reqs) != 1 {
t.Fatalf("got too many requests. want 1 got %d", len(reqs))
}
messages, err = protobuf.ReadMessages(
bytes.NewReader(record.Out()),
func() protobuf.Message { return new(pb.Delivery) },
)
if err != nil {
t.Fatal(err)
}
var gotDeliveries []string
for _, m := range messages {
gotDeliveries = append(gotDeliveries, string(m.(*pb.Delivery).Data))
}
if len(gotDeliveries) != 1 {
t.Fatalf("got too many deliveries. want 1 got %d", len(gotDeliveries))
}
}
type mockPeerSuggester struct {
spFunc func(swarm.Address) (swarm.Address, error)
}
func (v mockPeerSuggester) SuggestPeer(addr swarm.Address) (swarm.Address, error) {
return v.spFunc(addr)
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"context"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
)
type mockStorer struct {
store map[string][]byte
}
func NewStorer() storage.Storer {
s := &mockStorer{
store: make(map[string][]byte),
}
return s
}
func (m *mockStorer) Get(ctx context.Context, addr swarm.Address) (data []byte, err error) {
v, has := m.store[addr.String()]
if !has {
return nil, storage.ErrNotFound
}
return v, nil
}
func (m *mockStorer) Put(ctx context.Context, addr swarm.Address, data []byte) error {
m.store[addr.String()] = data
return nil
}
package mock_test
import (
"bytes"
"context"
"testing"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestMockStorer(t *testing.T) {
s := mock.NewStorer()
keyFound, err := swarm.ParseHexAddress("aabbcc")
if err != nil {
t.Fatal(err)
}
keyNotFound, err := swarm.ParseHexAddress("bbccdd")
if err != nil {
t.Fatal(err)
}
valueFound := []byte("data data data")
ctx := context.Background()
if _, err := s.Get(ctx, keyFound); err != storage.ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
if _, err := s.Get(ctx, keyNotFound); err != storage.ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
if err := s.Put(ctx, keyFound, valueFound); err != nil {
t.Fatalf("expected not error but got: %v", err)
}
if data, err := s.Get(ctx, keyFound); err != nil {
t.Fatalf("expected not error but got: %v", err)
} else {
if !bytes.Equal(data, valueFound) {
t.Fatalf("expected value %s but got %s", valueFound, data)
}
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.package storage
package storage
import (
"context"
"errors"
"github.com/ethersphere/bee/pkg/swarm"
)
var ErrNotFound = errors.New("storage: not found")
type Storer interface {
Get(ctx context.Context, addr swarm.Address) (data []byte, err error)
Put(ctx context.Context, addr swarm.Address, data []byte) error
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment