Commit 1f11fab5 authored by acud, committed by GitHub

pull sync: initial version (#193)

* pullsync, puller, pullstorage: initial pull sync protocol implementation
parent 3ab2f7a9
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bitvector
import (
"errors"
)
var errInvalidLength = errors.New("invalid length")
// BitVector is a convenience object for manipulating and representing bit vectors
type BitVector struct {
len int
b []byte
}
// New creates a new bit vector with the given length
func New(l int) (*BitVector, error) {
return NewFromBytes(make([]byte, l/8+1), l)
}
// NewFromBytes creates a bit vector from the passed byte slice.
//
// Bit i of the vector maps to the i%8-th least-significant bit of byte i/8,
// so index 0 is the lowest-order bit of the first byte.
func NewFromBytes(b []byte, l int) (*BitVector, error) {
if l <= 0 {
return nil, errInvalidLength
}
if len(b)*8 < l {
return nil, errInvalidLength
}
return &BitVector{
len: l,
b: b,
}, nil
}
// Get returns the bit at the given index
func (bv *BitVector) Get(i int) bool {
bi := i / 8
return bv.b[bi]&(0x1<<uint(i%8)) != 0
}
// set sets or clears the bit at the given index, depending on v
func (bv *BitVector) set(i int, v bool) {
bi := i / 8
cv := bv.Get(i)
if cv != v {
bv.b[bi] ^= 0x1 << uint8(i%8)
}
}
// Set sets the bit corresponding to the index in the bitvector, counted from left to right
func (bv *BitVector) Set(i int) {
bv.set(i, true)
}
// Unset clears the bit at the given index
func (bv *BitVector) Unset(i int) {
bv.set(i, false)
}
// SetBytes sets all bits in the bitvector that are set in the argument
//
// The argument must have the same length as the bitvector
func (bv *BitVector) SetBytes(bs []byte) error {
if len(bs) != bv.len {
return errors.New("invalid length")
}
for i := 0; i < bv.len*8; i++ {
bi := i / 8
if bs[bi]&(0x01<<uint(i%8)) > 0 {
bv.set(i, true)
}
}
return nil
}
// UnsetBytes clears all bits in the bitvector that are set in the argument
//
// The argument must have the same length as the bitvector
func (bv *BitVector) UnsetBytes(bs []byte) error {
if len(bs) != bv.len {
return errors.New("invalid length")
}
for i := 0; i < bv.len*8; i++ {
bi := i / 8
if bs[bi]&(0x01<<uint(i%8)) > 0 {
bv.set(i, false)
}
}
return nil
}
// String implements Stringer interface
func (bv *BitVector) String() (s string) {
for i := 0; i < bv.len*8; i++ {
if bv.Get(i) {
s += "1"
} else {
s += "0"
}
}
return s
}
// Bytes retrieves the underlying bytes of the bitvector
func (bv *BitVector) Bytes() []byte {
return bv.b
}
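For orientation, here is a minimal usage sketch of this package (a hypothetical standalone program; the import path assumes the bee module layout). Note the bit order: bit i lives at the i%8-th least-significant position of byte i/8, which is how pullsync later maps offer-hash indexes onto Want bits.

package main

import (
	"fmt"

	"github.com/ethersphere/bee/pkg/bitvector"
)

func main() {
	bv, err := bitvector.New(8) // track 8 offered hashes
	if err != nil {
		panic(err)
	}
	bv.Set(1) // mark hashes 1 and 3 as wanted
	bv.Set(3)
	fmt.Printf("%08b\n", bv.Bytes()[0]) // prints 00001010
	fmt.Println(bv.Get(3))              // prints true
}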
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package bitvector
import "testing"
// TestBitvectorNew checks that the length-argument validation in the constructors works
func TestBitvectorNew(t *testing.T) {
_, err := New(0)
if err != errInvalidLength {
t.Errorf("expected err %v, got %v", errInvalidLength, err)
}
_, err = NewFromBytes(nil, 0)
if err != errInvalidLength {
t.Errorf("expected err %v, got %v", errInvalidLength, err)
}
_, err = NewFromBytes([]byte{0}, 9)
if err != errInvalidLength {
t.Errorf("expected err %v, got %v", errInvalidLength, err)
}
_, err = NewFromBytes(make([]byte, 8), 8)
if err != nil {
t.Error(err)
}
}
// TestBitvectorGetSet tests correctness of individual Set and Get commands
func TestBitvectorGetSet(t *testing.T) {
for _, length := range []int{
1,
2,
4,
8,
9,
15,
16,
} {
bv, err := New(length)
if err != nil {
t.Errorf("error for length %v: %v", length, err)
}
for i := 0; i < length; i++ {
if bv.Get(i) {
t.Errorf("expected false for element on index %v", i)
}
}
func() {
defer func() {
if err := recover(); err == nil {
t.Errorf("expecting panic")
}
}()
bv.Get(length + 8)
}()
for i := 0; i < length; i++ {
bv.Set(i)
for j := 0; j < length; j++ {
if j == i {
if !bv.Get(j) {
t.Errorf("element on index %v is not set to true", i)
}
} else {
if bv.Get(j) {
t.Errorf("element on index %v is not false", i)
}
}
}
bv.Unset(i)
if bv.Get(i) {
t.Errorf("element on index %v is not set to false", i)
}
}
}
}
// TestBitvectorNewFromBytesGet tests that bit vector is initialized correctly from underlying byte slice
func TestBitvectorNewFromBytesGet(t *testing.T) {
bv, err := NewFromBytes([]byte{8}, 8)
if err != nil {
t.Error(err)
}
if !bv.Get(3) {
t.Fatalf("element 3 is not set to true: state %08b", bv.b[0])
}
}
// TestBitVectorString tests that string representation of bit vector is correct
func TestBitVectorString(t *testing.T) {
b := []byte{0xa5, 0x81}
expect := "1010010110000001"
bv, err := NewFromBytes(b, 2)
if err != nil {
t.Fatal(err)
}
if bv.String() != expect {
t.Fatalf("bitvector string fail: got %s, expect %s", bv.String(), expect)
}
}
// TestBitVectorSetUnsetBytes tests that setting and unsetting by byte slice modifies the bit vector correctly
func TestBitVectorSetBytes(t *testing.T) {
b := []byte{0xff, 0xff}
cb := []byte{0xa5, 0x81}
expectUnset := "0101101001111110"
expectReset := "1111111111111111"
bv, err := NewFromBytes(b, 2)
if err != nil {
t.Fatal(err)
}
err = bv.UnsetBytes(cb)
if err != nil {
t.Fatal(err)
}
if bv.String() != expectUnset {
t.Fatalf("bitvector unset bytes fail: got %s, expect %s", bv.String(), expectUnset)
}
err = bv.SetBytes(cb)
if err != nil {
t.Fatal(err)
}
if bv.String() != expectReset {
t.Fatalf("bitvector reset bytes fail: got %s, expect %s", bv.String(), expectReset)
}
}
@@ -58,6 +58,8 @@ type Kad struct {
manageC chan struct{} // trigger the manage forever loop to connect to new peers
waitNext map[string]retryInfo // sanction connections to a peer, key is overlay string and value is a retry information
waitNextMu sync.Mutex // synchronize map
peerSig []chan struct{}
peerSigMtx sync.Mutex
logger logging.Logger // logger
quit chan struct{} // quit channel
done chan struct{} // signal that `manage` has quit
@@ -165,6 +167,8 @@ func (k *Kad) manage() {
k.logger.Debugf("connected to peer: %s old depth: %d new depth: %d", peer, currentDepth, k.NeighborhoodDepth())
k.notifyPeerSig()
select {
case <-k.quit:
return true, false, nil
@@ -347,6 +351,8 @@ func (k *Kad) Connected(ctx context.Context, addr swarm.Address) error {
k.depth = k.recalcDepth()
k.depthMu.Unlock()
k.notifyPeerSig()
select {
case k.manageC <- struct{}{}:
default:
@@ -370,6 +376,22 @@ func (k *Kad) Disconnected(addr swarm.Address) {
case k.manageC <- struct{}{}:
default:
}
k.notifyPeerSig()
}
func (k *Kad) notifyPeerSig() {
k.peerSigMtx.Lock()
defer k.peerSigMtx.Unlock()
for _, c := range k.peerSig {
// Every peerSig channel has a buffer capacity of 1,
// so every receiver will get the signal even if the
// select statement has the default case to avoid blocking.
select {
case c <- struct{}{}:
default:
}
}
}
// ClosestPeer returns the closest peer to a given address.
@@ -408,6 +430,44 @@ func (k *Kad) ClosestPeer(addr swarm.Address) (swarm.Address, error) {
return closest, nil
}
// EachPeer iterates from closest bin to farthest
func (k *Kad) EachPeer(f topology.EachPeerFunc) error {
return k.connectedPeers.EachBin(f)
}
// EachPeerRev iterates from farthest bin to closest
func (k *Kad) EachPeerRev(f topology.EachPeerFunc) error {
return k.connectedPeers.EachBinRev(f)
}
// SubscribePeersChange returns the channel that signals when the connected peers
// set changes. Returned function is safe to be called multiple times.
func (k *Kad) SubscribePeersChange() (c <-chan struct{}, unsubscribe func()) {
channel := make(chan struct{}, 1)
var closeOnce sync.Once
k.peerSigMtx.Lock()
defer k.peerSigMtx.Unlock()
k.peerSig = append(k.peerSig, channel)
unsubscribe = func() {
k.peerSigMtx.Lock()
defer k.peerSigMtx.Unlock()
for i, c := range k.peerSig {
if c == channel {
k.peerSig = append(k.peerSig[:i], k.peerSig[i+1:]...)
break
}
}
closeOnce.Do(func() { close(channel) })
}
return channel, unsubscribe
}
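A hedged consumer sketch (watchPeers is a hypothetical helper, not part of this diff; the kademlia import path is assumed from the bee layout): every subscriber channel is buffered with capacity 1 and written with a non-blocking send, so one receive may stand for several coalesced changes, and a consumer should re-read the topology rather than count signals.

package example

import "github.com/ethersphere/bee/pkg/topology/kademlia"

func watchPeers(kad *kademlia.Kad, quit <-chan struct{}) {
	c, unsubscribe := kad.SubscribePeersChange()
	defer unsubscribe()
	for {
		select {
		case <-c:
			// the connected peer set changed at least once since the last
			// receive; rescan it, e.g. via kad.EachPeer(...)
		case <-quit:
			return
		}
	}
}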
// NeighborhoodDepth returns the current Kademlia depth.
func (k *Kad) NeighborhoodDepth() uint8 {
k.depthMu.RLock()
...
@@ -501,6 +501,104 @@ func TestClosestPeer(t *testing.T) {
}
}
func TestKademlia_SubscribePeersChange(t *testing.T) {
testSignal := func(t *testing.T, k *kademlia.Kad, c <-chan struct{}) {
t.Helper()
select {
case _, ok := <-c:
if !ok {
t.Error("closed signal channel")
}
case <-time.After(1 * time.Second):
t.Error("timeout")
}
}
t.Run("single subscription", func(t *testing.T) {
base, kad, ab, _, sg := newTestKademlia(nil, nil, nil)
c, u := kad.SubscribePeersChange()
defer u()
addr := test.RandomAddressAt(base, 9)
addOne(t, sg, kad, ab, addr)
testSignal(t, kad, c)
})
t.Run("single subscription, remove peer", func(t *testing.T) {
base, kad, ab, _, sg := newTestKademlia(nil, nil, nil)
c, u := kad.SubscribePeersChange()
defer u()
addr := test.RandomAddressAt(base, 9)
addOne(t, sg, kad, ab, addr)
testSignal(t, kad, c)
removeOne(kad, addr)
testSignal(t, kad, c)
})
t.Run("multiple subscriptions", func(t *testing.T) {
base, kad, ab, _, sg := newTestKademlia(nil, nil, nil)
c1, u1 := kad.SubscribePeersChange()
defer u1()
c2, u2 := kad.SubscribePeersChange()
defer u2()
for i := 0; i < 4; i++ {
addr := test.RandomAddressAt(base, i)
addOne(t, sg, kad, ab, addr)
}
testSignal(t, kad, c1)
testSignal(t, kad, c2)
})
t.Run("multiple changes", func(t *testing.T) {
base, kad, ab, _, sg := newTestKademlia(nil, nil, nil)
c, u := kad.SubscribePeersChange()
defer u()
for i := 0; i < 4; i++ {
addr := test.RandomAddressAt(base, i)
addOne(t, sg, kad, ab, addr)
}
testSignal(t, kad, c)
for i := 0; i < 4; i++ {
addr := test.RandomAddressAt(base, i)
addOne(t, sg, kad, ab, addr)
}
testSignal(t, kad, c)
})
t.Run("no depth change", func(t *testing.T) {
_, kad, _, _, _ := newTestKademlia(nil, nil, nil)
c, u := kad.SubscribePeersChange()
defer u()
select {
case _, ok := <-c:
if !ok {
t.Error("closed signal channel")
}
t.Error("signal received")
case <-time.After(1 * time.Second):
// all fine
}
})
}
func TestMarshal(t *testing.T) {
var (
_, kad, ab, _, signer = newTestKademlia(nil, nil, nil)
...
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"context"
"sync"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
)
type AddrTuple struct {
Addr swarm.Address // the peer address
PO uint8 // the po
}
func WithEachPeerRevCalls(addrs ...AddrTuple) Option {
return optionFunc(func(m *Mock) {
for _, a := range addrs {
a := a
m.eachPeerRev = append(m.eachPeerRev, a)
}
})
}
func WithDepth(d uint8) Option {
return optionFunc(func(m *Mock) {
m.depth = d
})
}
func WithDepthCalls(d ...uint8) Option {
return optionFunc(func(m *Mock) {
m.depthReplies = d
})
}
type Mock struct {
mtx sync.Mutex
peers []swarm.Address
eachPeerRev []AddrTuple
depth uint8
depthReplies []uint8
depthCalls int
trigs []chan struct{}
trigMtx sync.Mutex
}
func NewMockKademlia(o ...Option) *Mock {
m := &Mock{}
for _, v := range o {
v.apply(m)
}
return m
}
// AddPeer is called when a peer is added to the topology backlog
// for further processing by connectivity strategy.
func (m *Mock) AddPeer(ctx context.Context, addr swarm.Address) error {
panic("not implemented") // TODO: Implement
}
func (m *Mock) ClosestPeer(addr swarm.Address) (peerAddr swarm.Address, err error) {
panic("not implemented") // TODO: Implement
}
// EachPeer iterates from closest bin to farthest
func (m *Mock) EachPeer(f topology.EachPeerFunc) error {
m.mtx.Lock()
defer m.mtx.Unlock()
for i := len(m.peers) - 1; i >= 0; i-- {
stop, _, err := f(m.peers[i], uint8(i))
if stop {
return nil
}
if err != nil {
return err
}
}
return nil
}
// EachPeerRev iterates from farthest bin to closest
func (m *Mock) EachPeerRev(f topology.EachPeerFunc) error {
m.mtx.Lock()
defer m.mtx.Unlock()
for _, v := range m.eachPeerRev {
stop, _, err := f(v.Addr, v.PO)
if stop {
return nil
}
if err != nil {
return err
}
}
return nil
}
func (m *Mock) NeighborhoodDepth() uint8 {
m.mtx.Lock()
defer m.mtx.Unlock()
m.depthCalls++
// scripted replies are consumed in order; fall back to the static depth
if m.depthCalls <= len(m.depthReplies) {
return m.depthReplies[m.depthCalls-1]
}
return m.depth
}
// Connected is called when a peer dials in.
func (m *Mock) Connected(_ context.Context, addr swarm.Address) error {
m.mtx.Lock()
m.peers = append(m.peers, addr)
m.mtx.Unlock()
m.Trigger()
return nil
}
// Disconnected is called when a peer disconnects.
func (m *Mock) Disconnected(_ swarm.Address) {
m.Trigger()
}
func (m *Mock) SubscribePeersChange() (c <-chan struct{}, unsubscribe func()) {
channel := make(chan struct{}, 1)
var closeOnce sync.Once
m.trigMtx.Lock()
defer m.trigMtx.Unlock()
m.trigs = append(m.trigs, channel)
unsubscribe = func() {
m.trigMtx.Lock()
defer m.trigMtx.Unlock()
for i, c := range m.trigs {
if c == channel {
m.trigs = append(m.trigs[:i], m.trigs[i+1:]...)
break
}
}
closeOnce.Do(func() { close(channel) })
}
return channel, unsubscribe
}
func (m *Mock) Trigger() {
m.trigMtx.Lock()
defer m.trigMtx.Unlock()
for _, c := range m.trigs {
select {
case c <- struct{}{}:
default:
}
}
}
func (m *Mock) ResetPeers() {
m.mtx.Lock()
defer m.mtx.Unlock()
m.peers = nil
m.eachPeerRev = nil
}
func (m *Mock) Close() error {
panic("not implemented") // TODO: Implement
}
type Option interface {
apply(*Mock)
}
type optionFunc func(*Mock)
func (f optionFunc) apply(r *Mock) { f(r) }
@@ -36,7 +36,7 @@ import (
// function will terminate current and further iterations without errors, and also close the returned channel.
// Make sure that you check the second returned parameter from the channel to stop iteration when its value
// is false.
func (db *DB) SubscribePull(ctx context.Context, bin uint8, since, until uint64) (c <-chan storage.Descriptor, closed <-chan struct{}, stop func()) {
db.metrics.SubscribePull.Inc()
chunkDescriptors := make(chan storage.Descriptor)
@@ -176,7 +176,7 @@ func (db *DB) SubscribePull(ctx context.Context, bin uint8, since, until uint64)
}
}
return chunkDescriptors, db.close, stop
}
// LastPullSubscriptionBinID returns chunk bin id of the latest Chunk
...
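To illustrate the changed signature, a hedged consumer sketch (drainBin is a hypothetical function, not part of this commit): the extra closed channel lets a caller tell a database shutdown apart from a normally terminated subscription, which is exactly how pullstorage uses it further down in this diff.

package example

import (
	"context"
	"errors"

	"github.com/ethersphere/bee/pkg/localstore"
)

func drainBin(db *localstore.DB, bin uint8) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	ch, dbClosed, stop := db.SubscribePull(ctx, bin, 0, 0)
	defer stop()
	for {
		select {
		case d, ok := <-ch:
			if !ok {
				return nil // iterator finished normally
			}
			_ = d.Address // process the descriptor
		case <-dbClosed:
			return errors.New("db closed") // shutdown, not completion
		}
	}
}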
@@ -54,7 +54,7 @@ func TestDB_SubscribePull_first(t *testing.T) {
since := chunksInGivenBin + 1
go func() {
ch, _, stop := db.SubscribePull(context.TODO(), bin, since, 0)
defer stop()
chnk := <-ch
@@ -100,7 +100,7 @@ func TestDB_SubscribePull(t *testing.T) {
errChan := make(chan error)
for bin := uint8(0); bin <= swarm.MaxPO; bin++ {
ch, _, stop := db.SubscribePull(ctx, bin, 0, 0)
defer stop()
// receive and validate addresses from the subscription
@@ -149,7 +149,7 @@ func TestDB_SubscribePull_multiple(t *testing.T) {
// that all of them will write every address error to errChan
for j := 0; j < subsCount; j++ {
for bin := uint8(0); bin <= swarm.MaxPO; bin++ {
ch, _, stop := db.SubscribePull(ctx, bin, 0, 0)
defer stop()
// receive and validate addresses from the subscription
@@ -237,7 +237,7 @@ func TestDB_SubscribePull_since(t *testing.T) {
if !ok {
continue
}
ch, _, stop := db.SubscribePull(ctx, bin, since, 0)
defer stop()
// receive and validate addresses from the subscription
@@ -313,7 +313,7 @@ func TestDB_SubscribePull_until(t *testing.T) {
if !ok {
continue
}
ch, _, stop := db.SubscribePull(ctx, bin, 0, until)
defer stop()
// receive and validate addresses from the subscription
@@ -404,7 +404,7 @@ func TestDB_SubscribePull_sinceAndUntil(t *testing.T) {
// skip this bin from testing
continue
}
ch, _, stop := db.SubscribePull(ctx, bin, since, until)
defer stop()
// receive and validate addresses from the subscription
@@ -491,7 +491,7 @@ func TestDB_SubscribePull_rangeOnRemovedChunks(t *testing.T) {
// ignore this bin if it has only one chunk left
continue
}
ch, _, stop := db.SubscribePull(ctx, bin, since, until)
defer stop()
// the returned channel should be closed
...
@@ -33,6 +33,9 @@ import (
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p"
"github.com/ethersphere/bee/pkg/pingpong"
"github.com/ethersphere/bee/pkg/puller"
"github.com/ethersphere/bee/pkg/pullsync"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/pusher" "github.com/ethersphere/bee/pkg/pusher"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
...@@ -59,6 +62,8 @@ type Bee struct { ...@@ -59,6 +62,8 @@ type Bee struct {
localstoreCloser io.Closer localstoreCloser io.Closer
topologyCloser io.Closer topologyCloser io.Closer
pusherCloser io.Closer pusherCloser io.Closer
pullerCloser io.Closer
pullSyncCloser io.Closer
}
type Options struct {
@@ -258,6 +263,27 @@ func NewBee(o Options) (*Bee, error) {
})
b.pusherCloser = pushSyncPusher
pullStorage := pullstorage.New(storer)
pullSync := pullsync.New(pullsync.Options{
Streamer: p2ps,
Storage: pullStorage,
Logger: logger,
})
b.pullSyncCloser = pullSync
if err = p2ps.AddProtocol(pullSync.Protocol()); err != nil {
return nil, fmt.Errorf("pullsync protocol: %w", err)
}
puller := puller.New(puller.Options{
StateStore: stateStore,
Topology: topologyDriver,
PullSync: pullSync,
Logger: logger,
})
b.pullerCloser = puller
var apiService api.Service
if o.APIAddr != "" {
// API server
@@ -446,6 +472,14 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(fmt.Errorf("pusher: %w", err))
}
if err := b.pullerCloser.Close(); err != nil {
errs.add(fmt.Errorf("puller: %w", err))
}
if err := b.pullSyncCloser.Close(); err != nil {
errs.add(fmt.Errorf("pull sync: %w", err))
}
b.p2pCancel()
if err := b.p2pService.Close(); err != nil {
errs.add(fmt.Errorf("p2p server: %w", err))
...
package puller
var (
PeerIntervalKey = peerIntervalKey
Bins = &bins
ShallowBinPeers = &shallowBinPeers
IsSyncing = isSyncing
)
This diff is collapsed.
This diff is collapsed.
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pullsync
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pullsync
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"context"
"io"
"math"
"sync"
"github.com/ethersphere/bee/pkg/pullsync"
"github.com/ethersphere/bee/pkg/swarm"
)
var _ pullsync.Interface = (*PullSyncMock)(nil)
func WithCursors(v []uint64) Option {
return optionFunc(func(p *PullSyncMock) {
p.cursors = v
})
}
// WithAutoReply means that the pull syncer will automatically reply
// to incoming range requests with topmost = from+limit-1 (capped at to).
// This is in order to force the requester to request a subsequent range.
func WithAutoReply() Option {
return optionFunc(func(p *PullSyncMock) {
p.autoReply = true
})
}
// WithLiveSyncBlock makes the protocol mock block on incoming live
// sync requests (identified by the math.MaxUint64 `to` field).
func WithLiveSyncBlock() Option {
return optionFunc(func(p *PullSyncMock) {
p.blockLiveSync = true
})
}
func WithLiveSyncReplies(r ...uint64) Option {
return optionFunc(func(p *PullSyncMock) {
p.liveSyncReplies = r
})
}
func WithLateSyncReply(r ...SyncReply) Option {
return optionFunc(func(p *PullSyncMock) {
p.lateReply = true
p.lateSyncReplies = r
})
}
const limit = 50
type SyncCall struct {
Peer swarm.Address
Bin uint8
From, To uint64
Live bool
}
type SyncReply struct {
bin uint8
from uint64
topmost uint64
block bool
}
func NewReply(bin uint8, from, top uint64, block bool) SyncReply {
return SyncReply{
bin: bin,
from: from,
topmost: top,
block: block,
}
}
type PullSyncMock struct {
mtx sync.Mutex
syncCalls []SyncCall
cursors []uint64
getCursorsPeers []swarm.Address
autoReply bool
blockLiveSync bool
liveSyncReplies []uint64
liveSyncCalls int
lateReply bool
lateCond *sync.Cond
lateChange bool
lateSyncReplies []SyncReply
quit chan struct{}
}
func NewPullSync(opts ...Option) *PullSyncMock {
s := &PullSyncMock{
lateCond: sync.NewCond(new(sync.Mutex)),
quit: make(chan struct{}),
}
for _, v := range opts {
v.apply(s)
}
return s
}
func (p *PullSyncMock) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, err error) {
isLive := to == math.MaxUint64
call := SyncCall{
Peer: peer,
Bin: bin,
From: from,
To: to,
Live: isLive,
}
p.mtx.Lock()
p.syncCalls = append(p.syncCalls, call)
p.mtx.Unlock()
if isLive && p.lateReply {
p.lateCond.L.Lock()
for !p.lateChange {
p.lateCond.Wait()
}
p.lateCond.L.Unlock()
select {
case <-p.quit:
return 0, context.Canceled
case <-ctx.Done():
return 0, ctx.Err()
default:
}
found := false
var sr SyncReply
p.mtx.Lock()
for i, v := range p.lateSyncReplies {
if v.bin == bin && v.from == from {
sr = v
found = true
p.lateSyncReplies = append(p.lateSyncReplies[:i], p.lateSyncReplies[i+1:]...)
}
}
p.mtx.Unlock()
if found {
if sr.block {
select {
case <-p.quit:
return 0, context.Canceled
case <-ctx.Done():
return 0, ctx.Err()
}
}
return sr.topmost, nil
}
panic("not found")
}
if isLive && p.blockLiveSync {
// don't respond, wait for quit
<-p.quit
return 0, io.EOF
}
if isLive && len(p.liveSyncReplies) > 0 {
p.mtx.Lock()
if p.liveSyncCalls >= len(p.liveSyncReplies) {
p.mtx.Unlock()
<-p.quit
return
}
v := p.liveSyncReplies[p.liveSyncCalls]
p.liveSyncCalls++
p.mtx.Unlock()
return v, nil
}
if p.autoReply {
t := from + limit - 1
// floor to the cursor
if t > to {
t = to
}
return t, nil
}
return to, nil
}
func (p *PullSyncMock) GetCursors(_ context.Context, peer swarm.Address) ([]uint64, error) {
p.mtx.Lock()
defer p.mtx.Unlock()
p.getCursorsPeers = append(p.getCursorsPeers, peer)
return p.cursors, nil
}
func (p *PullSyncMock) SyncCalls(peer swarm.Address) (res []SyncCall) {
p.mtx.Lock()
defer p.mtx.Unlock()
for _, v := range p.syncCalls {
if v.Peer.Equal(peer) && !v.Live {
res = append(res, v)
}
}
return res
}
func (p *PullSyncMock) LiveSyncCalls(peer swarm.Address) (res []SyncCall) {
p.mtx.Lock()
defer p.mtx.Unlock()
for _, v := range p.syncCalls {
if v.Peer.Equal(peer) && v.Live {
res = append(res, v)
}
}
return res
}
func (p *PullSyncMock) CursorsCalls(peer swarm.Address) bool {
p.mtx.Lock()
defer p.mtx.Unlock()
for _, v := range p.getCursorsPeers {
if v.Equal(peer) {
return true
}
}
return false
}
func (p *PullSyncMock) TriggerChange() {
p.lateCond.L.Lock()
p.lateChange = true
p.lateCond.L.Unlock()
p.lateCond.Broadcast()
}
func (p *PullSyncMock) Close() error {
close(p.quit)
p.lateCond.L.Lock()
p.lateChange = true
p.lateCond.L.Unlock()
p.lateCond.Broadcast()
return nil
}
type Option interface {
apply(*PullSyncMock)
}
type optionFunc func(*PullSyncMock)
func (f optionFunc) apply(r *PullSyncMock) { f(r) }
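The late-reply machinery is the subtlest part of this mock, so a hedged usage sketch (a hypothetical snippet; the mock import path is assumed to be pkg/pullsync/mock): a live request (to == math.MaxUint64) parks on the condition variable until TriggerChange publishes the reply registered for bin 0, from 0.

package example

import (
	"context"
	"fmt"
	"math"
	"time"

	"github.com/ethersphere/bee/pkg/pullsync/mock"
	"github.com/ethersphere/bee/pkg/swarm"
)

func lateReplyDemo() {
	ps := mock.NewPullSync(mock.WithLateSyncReply(mock.NewReply(0, 0, 10, false)))
	defer ps.Close()
	go func() {
		time.Sleep(100 * time.Millisecond)
		ps.TriggerChange() // wakes the parked SyncInterval below
	}()
	top, err := ps.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, math.MaxUint64)
	fmt.Println(top, err) // 10 <nil>
}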
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate sh -c "protoc -I . -I \"$(go list -f '{{ .Dir }}' -m github.com/gogo/protobuf)/protobuf\" --gogofaster_out=. pullsync.proto"
package pb
This diff is collapsed.
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
syntax = "proto3";
package pullsync;
option go_package = "pb";
message Syn {}
message Ack {
repeated uint64 Cursors = 1;
}
message GetRange {
int32 Bin = 1;
uint64 From = 2;
uint64 To = 3;
}
message Offer {
uint64 Topmost = 1;
bytes Hashes = 2;
}
message Want {
bytes BitVector = 1;
}
message Delivery {
bytes Address = 1;
bytes Data = 2;
}
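Reading the messages in wire order may help; a hedged sketch with illustrative field values (a hypothetical program; pb is the generated package for the proto above). Syn/Ack travel on the cursors stream, the remaining messages on the pullsync stream.

package main

import (
	"fmt"

	"github.com/ethersphere/bee/pkg/pullsync/pb"
)

func main() {
	// cursors stream: an empty Syn is answered by an Ack carrying one
	// cursor (last BinID) per bin
	syn, ack := &pb.Syn{}, &pb.Ack{Cursors: []uint64{42, 17}}
	// pullsync stream: the client requests bin 1, BinIDs 0..50
	rng := &pb.GetRange{Bin: 1, From: 0, To: 50}
	// the server offers the concatenated 32-byte hashes it holds in that
	// interval, plus the topmost BinID the offer covers
	offer := &pb.Offer{Topmost: 50, Hashes: make([]byte, 2*32)}
	// bit i of the Want bitvector set means "deliver offer hash i"
	want := &pb.Want{BitVector: []byte{0x03}}
	// the server replies with one Delivery per wanted hash
	del := &pb.Delivery{Address: make([]byte, 32), Data: []byte("payload")}
	fmt.Println(syn, ack, rng, offer, want, del)
}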
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pullstorage
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pullstorage
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"context"
"sync"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
)
var _ pullstorage.Storer = (*PullStorage)(nil)
type chunksResponse struct {
chunks []swarm.Address
topmost uint64
err error
}
// WithIntervalsResp mocks a desired response when calling the IntervalChunks method.
// Different responses for subsequent calls are possible
// (i.e. the first call yields a,b,c, the second call yields d,e,f).
// The mock keeps track of the current call with the chunksCalls counter.
func WithIntervalsResp(addrs []swarm.Address, top uint64, err error) Option {
return optionFunc(func(p *PullStorage) {
p.intervalChunksResponses = append(p.intervalChunksResponses, chunksResponse{chunks: addrs, topmost: top, err: err})
})
}
// WithChunks mocks the set of chunks that the store is aware of (used in Get and Has calls).
func WithChunks(chs ...swarm.Chunk) Option {
return optionFunc(func(p *PullStorage) {
for _, c := range chs {
p.chunks[c.Address().String()] = c.Data()
}
})
}
// WithEvilChunk injects a malicious chunk (requesting a certain chunk address
// returns a different chunk), in order to mock unsolicited chunk delivery.
func WithEvilChunk(addr swarm.Address, ch swarm.Chunk) Option {
return optionFunc(func(p *PullStorage) {
p.evilAddr = addr
p.evilChunk = ch
})
}
func WithCursors(c []uint64) Option {
return optionFunc(func(p *PullStorage) {
p.cursors = c
})
}
func WithCursorsErr(e error) Option {
return optionFunc(func(p *PullStorage) {
p.cursorsErr = e
})
}
type PullStorage struct {
mtx sync.Mutex
chunksCalls int
putCalls int
setCalls int
chunks map[string][]byte
evilAddr swarm.Address
evilChunk swarm.Chunk
cursors []uint64
cursorsErr error
intervalChunksResponses []chunksResponse
}
// NewPullStorage returns a new PullStorage mock.
func NewPullStorage(opts ...Option) *PullStorage {
s := &PullStorage{
chunks: make(map[string][]byte),
}
for _, v := range opts {
v.apply(s)
}
return s
}
// IntervalChunks returns a set of chunks in a requested interval.
func (s *PullStorage) IntervalChunks(_ context.Context, bin uint8, from, to uint64, limit int) (chunks []swarm.Address, topmost uint64, err error) {
s.mtx.Lock()
defer s.mtx.Unlock()
r := s.intervalChunksResponses[s.chunksCalls]
s.chunksCalls++
return r.chunks, r.topmost, r.err
}
func (s *PullStorage) Cursors(ctx context.Context) (curs []uint64, err error) {
return s.cursors, s.cursorsErr
}
// PutCalls returns the number of times Put was called.
func (s *PullStorage) PutCalls() int {
s.mtx.Lock()
defer s.mtx.Unlock()
return s.putCalls
}
// SetCalls returns the number of times Set was called.
func (s *PullStorage) SetCalls() int {
s.mtx.Lock()
defer s.mtx.Unlock()
return s.setCalls
}
// Get chunks.
func (s *PullStorage) Get(_ context.Context, _ storage.ModeGet, addrs ...swarm.Address) (chs []swarm.Chunk, err error) {
for _, a := range addrs {
if s.evilAddr.Equal(a) {
// inject the malicious chunk instead
chs = append(chs, s.evilChunk)
continue
}
v, ok := s.chunks[a.String()]
if !ok {
return nil, storage.ErrNotFound
}
chs = append(chs, swarm.NewChunk(a, v))
}
return chs, nil
}
// Put chunks.
func (s *PullStorage) Put(_ context.Context, _ storage.ModePut, chs ...swarm.Chunk) error {
s.mtx.Lock()
defer s.mtx.Unlock()
for _, c := range chs {
s.chunks[c.Address().String()] = c.Data()
}
s.putCalls++
return nil
}
// Set chunks.
func (s *PullStorage) Set(ctx context.Context, mode storage.ModeSet, addrs ...swarm.Address) error {
s.mtx.Lock()
defer s.mtx.Unlock()
s.setCalls++
return nil
}
// Has chunks.
func (s *PullStorage) Has(_ context.Context, addr swarm.Address) (bool, error) {
if _, ok := s.chunks[addr.String()]; !ok {
return false, nil
}
return true, nil
}
type Option interface {
apply(*PullStorage)
}
type optionFunc func(*PullStorage)
func (f optionFunc) apply(r *PullStorage) { f(r) }
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pullstorage
import (
"context"
"errors"
"time"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
)
var (
_ Storer = (*ps)(nil)
// ErrDbClosed is used to signal the underlying database was closed
ErrDbClosed = errors.New("db closed")
// batchTimeout is how long to wait for more chunks before returning a non-empty batch
batchTimeout = 500 * time.Millisecond
)
// Storer is a thin wrapper around storage.Storer.
// It is used in order to collect and provide information about chunks
// currently present in the local store.
type Storer interface {
// IntervalChunks collects chunks for a requested interval.
IntervalChunks(ctx context.Context, bin uint8, from, to uint64, limit int) (chunks []swarm.Address, topmost uint64, err error)
// Cursors gets the last BinID for every bin in the local storage
Cursors(ctx context.Context) ([]uint64, error)
// Get chunks.
Get(ctx context.Context, mode storage.ModeGet, addrs ...swarm.Address) ([]swarm.Chunk, error)
// Put chunks.
Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) error
// Set chunks.
Set(ctx context.Context, mode storage.ModeSet, addrs ...swarm.Address) error
// Has chunks.
Has(ctx context.Context, addr swarm.Address) (bool, error)
}
// ps wraps storage.Storer.
type ps struct {
storage.Storer
}
// New returns a new pullstorage Storer instance.
func New(storer storage.Storer) Storer {
return &ps{
Storer: storer,
}
}
// IntervalChunks collects chunks for a requested interval.
func (s *ps) IntervalChunks(ctx context.Context, bin uint8, from, to uint64, limit int) (chs []swarm.Address, topmost uint64, err error) {
// call iterator, iterate either until upper bound or limit reached
// return addresses, topmost is the topmost bin ID
var (
timer *time.Timer
timerC <-chan time.Time
)
ch, dbClosed, stop := s.SubscribePull(ctx, bin, from, to)
defer func() {
stop()
if timer != nil {
timer.Stop()
}
}()
var nomore bool
LOOP:
for limit > 0 {
select {
case v, ok := <-ch:
if !ok {
nomore = true
break LOOP
}
chs = append(chs, v.Address)
if v.BinID > topmost {
topmost = v.BinID
}
limit--
if timer == nil {
timer = time.NewTimer(batchTimeout)
} else {
if !timer.Stop() {
<-timer.C
}
timer.Reset(batchTimeout)
}
timerC = timer.C
case <-ctx.Done():
return nil, 0, ctx.Err()
case <-timerC:
// return batch if new chunks are not received after some time
break LOOP
}
}
select {
case <-ctx.Done():
return nil, 0, ctx.Err()
case <-dbClosed:
return nil, 0, ErrDbClosed
default:
}
if nomore {
// end of interval reached. no more chunks so interval is complete
// return requested `to`. it could be that len(chs) == 0 if the interval
// is empty
topmost = to
}
return chs, topmost, nil
}
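A hedged call sketch (fetchBatch is a hypothetical helper): the loop above returns when the limit is reached, when the subscription channel closes (interval complete, topmost forced to `to`), or batchTimeout after the last descriptor arrived (partial batch, topmost is the highest BinID seen).

package example

import (
	"context"
	"math"

	"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
	"github.com/ethersphere/bee/pkg/swarm"
)

func fetchBatch(ctx context.Context, ps pullstorage.Storer) ([]swarm.Address, uint64, error) {
	// bin 3, everything after cursor 10, no upper bound (live sync),
	// at most 50 chunk addresses per batch
	return ps.IntervalChunks(ctx, 3, 10, math.MaxUint64, 50)
}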
// Cursors gets the last BinID for every bin in the local storage
func (s *ps) Cursors(ctx context.Context) (curs []uint64, err error) {
curs = make([]uint64, 16)
for i := uint8(0); i < 16; i++ {
binID, err := s.Storer.LastPullSubscriptionBinID(i)
if err != nil {
return nil, err
}
curs[i] = binID
}
return curs, nil
}
// Get chunks.
func (s *ps) Get(ctx context.Context, mode storage.ModeGet, addrs ...swarm.Address) ([]swarm.Chunk, error) {
return s.Storer.GetMulti(ctx, mode, addrs...)
}
// Put chunks.
func (s *ps) Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) error {
_, err := s.Storer.Put(ctx, mode, chs...)
return err
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pullstorage_test
import (
"context"
"crypto/rand"
"errors"
"io/ioutil"
"testing"
"time"
"github.com/ethersphere/bee/pkg/localstore"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
stesting "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm"
)
var (
addrs = []swarm.Address{
swarm.MustParseHexAddress("0001"),
swarm.MustParseHexAddress("0002"),
swarm.MustParseHexAddress("0003"),
swarm.MustParseHexAddress("0004"),
swarm.MustParseHexAddress("0005"),
swarm.MustParseHexAddress("0006"),
}
limit = 5
)
func someAddrs(i ...int) (r []swarm.Address) {
for _, v := range i {
r = append(r, addrs[v])
}
return r
}
func someDescriptors(i ...int) (d []storage.Descriptor) {
for _, v := range i {
d = append(d, storage.Descriptor{Address: addrs[v], BinID: uint64(v + 1)})
}
return d
}
// TestIntervalChunks tests that the IntervalChunks method always returns
// an upper bound of N chunks for a certain range, and the topmost equals:
// - to the To argument of the function (in case there are no chunks in the interval)
// - to the To argument of the function (in case the number of chunks in interval <= N)
// - to BinID of the last chunk in the returned collection in case number of chunks in interval > N
func TestIntervalChunks(t *testing.T) {
// we need to check the following cases of the subscribe pull iterator:
// - no chunks in the interval
// - fewer chunks reported than what is in the interval (but interval still intact, probably old chunks GCd)
// - as many chunks as the size of the interval
// - more chunks than what's in the interval (return lower topmost value)
// - fewer chunks in the interval, but since we're at the top of the interval, block and wait for new chunks
for _, tc := range []struct {
desc string
from, to uint64 // request from, to
mockAddrs []int // which addresses should the mock return
addrs []byte // the expected returned chunk address byte slice
topmost uint64 // expected topmost
}{
{desc: "no chunks in interval", from: 0, to: 5, topmost: 5},
{desc: "interval full", from: 0, to: 5, mockAddrs: []int{0, 1, 2, 3, 4}, topmost: 5},
{desc: "some in the middle", from: 0, to: 5, mockAddrs: []int{1, 3}, topmost: 5},
{desc: "at the edges", from: 0, to: 5, mockAddrs: []int{0, 4}, topmost: 5},
{desc: "at the edges and the middle", from: 0, to: 5, mockAddrs: []int{0, 2, 4}, topmost: 5},
{desc: "more than interval", from: 0, to: 5, mockAddrs: []int{0, 1, 2, 3, 4, 5}, topmost: 5},
} {
t.Run(tc.desc, func(t *testing.T) {
b := someAddrs(tc.mockAddrs...)
desc := someDescriptors(tc.mockAddrs...)
ps, _ := newPullStorage(t, mock.WithSubscribePullChunks(desc...))
ctx, cancel := context.WithCancel(context.Background())
addresses, topmost, err := ps.IntervalChunks(ctx, 0, tc.from, tc.to, limit)
if err != nil {
t.Fatal(err)
}
cancel()
checkAinB(t, addresses, b)
if topmost != tc.topmost {
t.Fatalf("expected topmost %d but got %d", tc.topmost, topmost)
}
})
}
}
// Get some descriptors from the chunk channel, block for a while, then add
// more chunks to the subscribe pull iterator and make sure the loop
// exits correctly.
func TestIntervalChunks_GetChunksLater(t *testing.T) {
desc := someDescriptors(0, 2)
ps, db := newPullStorage(t, mock.WithSubscribePullChunks(desc...), mock.WithPartialInterval(true))
go func() {
<-time.After(200 * time.Millisecond)
// add chunks to subscribe pull on the storage mock
db.MorePull(someDescriptors(1, 3, 4)...)
}()
addrs, topmost, err := ps.IntervalChunks(context.Background(), 0, 0, 5, limit)
if err != nil {
t.Fatal(err)
}
if l := len(addrs); l != 5 {
t.Fatalf("want %d addrs but got %d", 5, l)
}
// highest chunk we sent had BinID 5
exp := uint64(5)
if topmost != exp {
t.Fatalf("expected topmost %d but got %d", exp, topmost)
}
}
// Get some descriptors, then let the iterator time out and return just what we received in the beginning
func TestIntervalChunks_NoChunksLater(t *testing.T) {
desc := someDescriptors(0, 2)
ps, db := newPullStorage(t, mock.WithSubscribePullChunks(desc...), mock.WithPartialInterval(true))
go func() {
<-time.After(600 * time.Millisecond)
// add chunks to subscribe pull on the storage mock
db.MorePull(someDescriptors(1, 3, 4)...)
}()
addrs, topmost, err := ps.IntervalChunks(context.Background(), 0, 0, 5, limit)
if err != nil {
t.Fatal(err)
}
if l := len(addrs); l != 2 {
t.Fatalf("want %d addrs but got %d", 2, l)
}
// highest chunk we sent had BinID 3
exp := uint64(3)
if topmost != exp {
t.Fatalf("expected topmost %d but got %d", exp, topmost)
}
}
func TestIntervalChunks_Blocking(t *testing.T) {
desc := someDescriptors(0, 2)
ps, _ := newPullStorage(t, mock.WithSubscribePullChunks(desc...), mock.WithPartialInterval(true))
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-time.After(100 * time.Millisecond)
cancel()
}()
_, _, err := ps.IntervalChunks(ctx, 0, 0, 5, limit)
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, context.Canceled) {
t.Fatal(err)
}
}
func TestIntervalChunks_DbShutdown(t *testing.T) {
ps, db := newPullStorage(t, mock.WithPartialInterval(true))
go func() {
<-time.After(100 * time.Millisecond)
db.Close()
}()
_, _, err := ps.IntervalChunks(context.Background(), 0, 0, 5, limit)
if err == nil {
t.Fatal("expected error but got none")
}
if !errors.Is(err, pullstorage.ErrDbClosed) {
t.Fatal(err)
}
}
// TestIntervalChunks_Localstore is an integration test with a real
// localstore instance.
func TestIntervalChunks_Localstore(t *testing.T) {
fill := func(f, t int) (ints []int) {
for i := f; i <= t; i++ {
ints = append(ints, i)
}
return ints
}
for _, tc := range []struct {
name string
chunks int
f, t uint64
limit int
expect int // chunks
top uint64 // topmost
addrs []int // indexes of the generated chunk slice
}{
{
name: "0-1, expect 1 chunk", // intervals always >0
chunks: 50,
f: 0, t: 1,
limit: 50,
expect: 1, top: 1, addrs: fill(1, 1),
},
{
name: "1-1, expect 1 chunk",
chunks: 50,
f: 1, t: 1,
limit: 50,
expect: 1, top: 1, addrs: fill(1, 1),
},
{
name: "2-2, expect 1 chunk",
chunks: 50,
f: 2, t: 2,
limit: 50,
expect: 1, top: 2, addrs: fill(2, 2),
},
{
name: "0-10, expect 10 chunks", // intervals always >0
chunks: 50,
f: 0, t: 10,
limit: 50,
expect: 10, top: 10, addrs: fill(1, 10),
},
{
name: "1-10, expect 10 chunks",
chunks: 50,
f: 0, t: 10,
limit: 50,
expect: 10, top: 10, addrs: fill(1, 10),
},
{
name: "0-50, expect 50 chunks", // intervals always >0
chunks: 50,
f: 0, t: 50,
limit: 50,
expect: 50, top: 50, addrs: fill(1, 50),
},
{
name: "1-50, expect 50 chunks",
chunks: 50,
f: 1, t: 50,
limit: 50,
expect: 50, top: 50, addrs: fill(1, 50),
},
{
name: "0-60, expect 50 chunks", // hit the limit
chunks: 50,
f: 0, t: 60,
limit: 50,
expect: 50, top: 50, addrs: fill(1, 50),
},
{
name: "1-60, expect 50 chunks", // hit the limit
chunks: 50,
f: 0, t: 60,
limit: 50,
expect: 50, top: 50, addrs: fill(1, 50),
},
} {
t.Run(tc.name, func(t *testing.T) {
base, db := newTestDB(t, nil)
ps := pullstorage.New(db)
var chunks []swarm.Chunk
for i := 1; i <= tc.chunks; {
c := stesting.GenerateTestRandomChunk()
po := swarm.Proximity(c.Address().Bytes(), base)
if po == 1 {
chunks = append(chunks, c)
i++
}
}
ctx := context.Background()
_, err := db.Put(ctx, storage.ModePutUpload, chunks...)
if err != nil {
t.Fatal(err)
}
//always bin 1
chs, topmost, err := ps.IntervalChunks(ctx, 1, tc.f, tc.t, tc.limit)
if err != nil {
t.Fatal(err)
}
checkAddrs := make([]swarm.Address, len(tc.addrs))
for i, v := range tc.addrs {
checkAddrs[i] = chunks[v-1].Address()
}
for i, c := range chs {
if !c.Equal(checkAddrs[i]) {
t.Fatalf("chunk %d address mismatch", i)
}
}
if topmost != tc.top {
t.Fatalf("topmost mismatch, got %d want %d", topmost, tc.top)
}
if l := len(chs); l != tc.expect {
t.Fatalf("expected %d chunks but got %d", tc.expect, l)
}
})
}
}
func newPullStorage(t *testing.T, o ...mock.Option) (pullstorage.Storer, *mock.MockStorer) {
db := mock.NewStorer(o...)
ps := pullstorage.New(db)
return ps, db
}
func newTestDB(t testing.TB, o *localstore.Options) (baseKey []byte, db *localstore.DB) {
t.Helper()
baseKey = make([]byte, 32)
if _, err := rand.Read(baseKey); err != nil {
t.Fatal(err)
}
logger := logging.New(ioutil.Discard, 0)
db, err := localstore.New("", baseKey, o, logger)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
err := db.Close()
if err != nil {
t.Error(err)
}
})
return baseKey, db
}
// checkAinB checks that every address in a exists in b
func checkAinB(t *testing.T, a, b []swarm.Address) {
t.Helper()
for _, v := range a {
if !isIn(v, b) {
t.Fatalf("address %s not found in slice %s", v, b)
}
}
}
func isIn(a swarm.Address, b []swarm.Address) bool {
for _, v := range b {
if a.Equal(v) {
return true
}
}
return false
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pullsync
import (
"context"
"errors"
"fmt"
"io"
"time"
"github.com/ethersphere/bee/pkg/bitvector"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/pullsync/pb"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
)
const (
protocolName = "pullsync"
protocolVersion = "1.0.0"
streamName = "pullsync"
cursorStreamName = "cursors"
)
var (
ErrUnsolicitedChunk = errors.New("peer sent unsolicited chunk")
)
// maxPage is the maximum number of chunks in a batch
var maxPage = 50
type Interface interface {
SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, err error)
GetCursors(ctx context.Context, peer swarm.Address) ([]uint64, error)
}
type Syncer struct {
streamer p2p.Streamer
logger logging.Logger
storage pullstorage.Storer
Interface
io.Closer
}
type Options struct {
Streamer p2p.Streamer
Storage pullstorage.Storer
Logger logging.Logger
}
func New(o Options) *Syncer {
return &Syncer{
streamer: o.Streamer,
storage: o.Storage,
logger: o.Logger,
}
}
func (s *Syncer) Protocol() p2p.ProtocolSpec {
return p2p.ProtocolSpec{
Name: protocolName,
Version: protocolVersion,
StreamSpecs: []p2p.StreamSpec{
{
Name: streamName,
Handler: s.handler,
},
{
Name: cursorStreamName,
Handler: s.cursorHandler,
},
},
}
}
// SyncInterval syncs a requested interval from the given peer.
// It returns the BinID of the highest chunk that was synced from the given
// interval. If the requested interval is too large, the downstream peer
// has the liberty to provide fewer chunks than requested.
func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, err error) {
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil {
return 0, fmt.Errorf("new stream: %w", err)
}
defer func() {
if err != nil {
_ = stream.FullClose()
return
}
_ = stream.Close()
}()
w, r := protobuf.NewWriterAndReader(stream)
rangeMsg := &pb.GetRange{Bin: int32(bin), From: from, To: to}
if err = w.WriteMsgWithContext(ctx, rangeMsg); err != nil {
return 0, fmt.Errorf("write get range: %w", err)
}
var offer pb.Offer
if err = r.ReadMsgWithContext(ctx, &offer); err != nil {
return 0, fmt.Errorf("read offer: %w", err)
}
if len(offer.Hashes)%swarm.HashSize != 0 {
return 0, fmt.Errorf("inconsistent hash length")
}
// empty interval (no chunks present in interval).
// return the end of the requested range as topmost.
if len(offer.Hashes) == 0 {
return offer.Topmost, nil
}
var (
bvLen = len(offer.Hashes) / swarm.HashSize
wantChunks = make(map[string]struct{})
ctr = 0
)
bv, err := bitvector.New(bvLen)
if err != nil {
return 0, fmt.Errorf("new bitvector: %w", err)
}
for i := 0; i < len(offer.Hashes); i += swarm.HashSize {
a := swarm.NewAddress(offer.Hashes[i : i+swarm.HashSize])
if a.Equal(swarm.ZeroAddress) {
// keep this check around so we notice in the logs if zero addresses ever appear in an offer
s.logger.Errorf("syncer got a zero address hash on offer")
return 0, fmt.Errorf("zero address on offer")
}
have, err := s.storage.Has(ctx, a)
if err != nil {
return 0, fmt.Errorf("storage has: %w", err)
}
if !have {
wantChunks[a.String()] = struct{}{}
ctr++
bv.Set(i / swarm.HashSize)
}
}
wantMsg := &pb.Want{BitVector: bv.Bytes()}
if err = w.WriteMsgWithContext(ctx, wantMsg); err != nil {
return 0, fmt.Errorf("write want: %w", err)
}
// if ctr is zero, it means we don't want any chunk in the batch
// thus, the following loop will not get executed and the method
// returns immediately with the topmost value on the offer, which
// will seal the interval and request the next one
for ; ctr > 0; ctr-- {
var delivery pb.Delivery
if err = r.ReadMsgWithContext(ctx, &delivery); err != nil {
return 0, fmt.Errorf("read delivery: %w", err)
}
addr := swarm.NewAddress(delivery.Address)
if _, ok := wantChunks[addr.String()]; !ok {
return 0, ErrUnsolicitedChunk
}
delete(wantChunks, addr.String())
if err = s.storage.Put(ctx, storage.ModePutSync, swarm.NewChunk(addr, delivery.Data)); err != nil {
return 0, fmt.Errorf("delivery put: %w", err)
}
}
return offer.Topmost, nil
}
// handler handles an incoming request to sync an interval
func (s *Syncer) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) error {
w, r := protobuf.NewWriterAndReader(stream)
defer stream.Close()
var rn pb.GetRange
if err := r.ReadMsgWithContext(ctx, &rn); err != nil {
return fmt.Errorf("read get range: %w", err)
}
s.logger.Debugf("got range peer %s bin %d from %d to %d", p.Address.String(), rn.Bin, rn.From, rn.To)
// make an offer to the upstream peer in return for the requested range
offer, addrs, err := s.makeOffer(ctx, rn)
if err != nil {
return fmt.Errorf("make offer: %w", err)
}
if err := w.WriteMsgWithContext(ctx, offer); err != nil {
return fmt.Errorf("write offer: %w", err)
}
// we don't have any hashes to offer in this range (the
// interval is empty). nothing more to do
if len(offer.Hashes) == 0 {
return nil
}
var want pb.Want
if err := r.ReadMsgWithContext(ctx, &want); err != nil {
return fmt.Errorf("read want: %w", err)
}
// empty bitvector implies downstream peer does not want
// any chunks (it has them already). mark chunks as synced
if len(want.BitVector) == 0 {
return s.setChunks(ctx, addrs...)
}
chs, err := s.processWant(ctx, offer, &want)
if err != nil {
return fmt.Errorf("process want: %w", err)
}
for _, v := range chs {
deliver := pb.Delivery{Address: v.Address().Bytes(), Data: v.Data()}
if err := w.WriteMsgWithContext(ctx, &deliver); err != nil {
return fmt.Errorf("write delivery: %w", err)
}
}
err = s.setChunks(ctx, addrs...)
if err != nil {
return err
}
time.Sleep(50 * time.Millisecond) // TODO: give the last delivery time to flush; without it tests get a premature EOF
return nil
}
func (s *Syncer) setChunks(ctx context.Context, addrs ...swarm.Address) error {
return s.storage.Set(ctx, storage.ModeSetSyncPull, addrs...)
}
// makeOffer tries to assemble an offer for a given requested interval.
func (s *Syncer) makeOffer(ctx context.Context, rn pb.GetRange) (o *pb.Offer, addrs []swarm.Address, err error) {
s.logger.Tracef("syncer make offer for bin %d from %d to %d maxpage %d", rn.Bin, rn.From, rn.To, maxPage)
chs, top, err := s.storage.IntervalChunks(ctx, uint8(rn.Bin), rn.From, rn.To, maxPage)
if err != nil {
return o, nil, err
}
o = new(pb.Offer)
o.Topmost = top
o.Hashes = make([]byte, 0)
for _, v := range chs {
o.Hashes = append(o.Hashes, v.Bytes()...)
}
return o, chs, nil
}
// processWant compares a received Want to a sent Offer and returns
// the appropriate chunks from the local store.
func (s *Syncer) processWant(ctx context.Context, o *pb.Offer, w *pb.Want) ([]swarm.Chunk, error) {
l := len(o.Hashes) / swarm.HashSize
bv, err := bitvector.NewFromBytes(w.BitVector, l)
if err != nil {
return nil, err
}
var addrs []swarm.Address
for i := 0; i < len(o.Hashes); i += swarm.HashSize {
if bv.Get(i / swarm.HashSize) {
a := swarm.NewAddress(o.Hashes[i : i+swarm.HashSize])
addrs = append(addrs, a)
}
}
return s.storage.Get(ctx, storage.ModeGetSync, addrs...)
}
func (s *Syncer) GetCursors(ctx context.Context, peer swarm.Address) ([]uint64, error) {
s.logger.Debugf("syncer get cursors from peer %s", peer)
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, cursorStreamName)
if err != nil {
return nil, fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
w, r := protobuf.NewWriterAndReader(stream)
syn := &pb.Syn{}
if err = w.WriteMsgWithContext(ctx, syn); err != nil {
return nil, fmt.Errorf("write syn: %w", err)
}
var ack pb.Ack
if err = r.ReadMsgWithContext(ctx, &ack); err != nil {
return nil, fmt.Errorf("read ack: %w", err)
}
s.logger.Debugf("syncer peer %s cursors %s", peer, ack.Cursors)
return ack.Cursors, nil
}
func (s *Syncer) cursorHandler(ctx context.Context, p p2p.Peer, stream p2p.Stream) error {
w, r := protobuf.NewWriterAndReader(stream)
defer stream.Close()
var syn pb.Syn
if err := r.ReadMsgWithContext(ctx, &syn); err != nil {
return fmt.Errorf("read syn: %w", err)
}
var ack pb.Ack
ints, err := s.storage.Cursors(ctx)
if err != nil {
_ = stream.FullClose()
return err
}
ack.Cursors = ints
if err = w.WriteMsgWithContext(ctx, &ack); err != nil {
return fmt.Errorf("write ack: %w", err)
}
return nil
}
func (s *Syncer) Close() error {
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pullsync_test
import (
"context"
"crypto/rand"
"errors"
"io"
"io/ioutil"
"testing"
"time"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pullsync"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage/mock"
"github.com/ethersphere/bee/pkg/swarm"
)
var (
addrs = []swarm.Address{
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000001"),
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000002"),
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000003"),
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000004"),
swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000005"),
}
chunks []swarm.Chunk
)
func someChunks(i ...int) (c []swarm.Chunk) {
for _, v := range i {
c = append(c, chunks[v])
}
return c
}
func init() {
chunks = make([]swarm.Chunk, 5)
for i := 0; i < 5; i++ {
data := make([]byte, swarm.ChunkSize)
_, _ = rand.Read(data)
chunks[i] = swarm.NewChunk(addrs[i], data)
}
}
// TestIncoming tests that an incoming request for an interval
// is handled correctly when no chunks are available in the interval.
// This means the interval exists but chunks are not there (GCd).
// Expected behavior is that an offer message with the requested
// To value is returned to the requester, but offer.Hashes is empty.
func TestIncoming_WantEmptyInterval(t *testing.T) {
var (
mockTopmost = uint64(5)
ps, serverDb = newPullSync(nil, mock.WithIntervalsResp([]swarm.Address{}, mockTopmost, nil))
recorder = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
psClient, clientDb = newPullSync(recorder)
)
topmost, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 1, 0, 5)
if err != nil {
t.Fatal(err)
}
if topmost != mockTopmost {
t.Fatalf("got offer topmost %d but want %d", topmost, mockTopmost)
}
if clientDb.PutCalls() > 0 {
t.Fatal("too many puts")
}
waitSet(t, serverDb, 0)
}
func TestIncoming_WantNone(t *testing.T) {
var (
mockTopmost = uint64(5)
ps, serverDb = newPullSync(nil, mock.WithIntervalsResp(addrs, mockTopmost, nil), mock.WithChunks(chunks...))
recorder = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
psClient, clientDb = newPullSync(recorder, mock.WithChunks(chunks...))
)
topmost, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
if err != nil {
t.Fatal(err)
}
if topmost != mockTopmost {
t.Fatalf("got offer topmost %d but want %d", topmost, mockTopmost)
}
if clientDb.PutCalls() > 0 {
t.Fatal("too many puts")
}
waitSet(t, serverDb, 1)
}
func TestIncoming_WantOne(t *testing.T) {
var (
mockTopmost = uint64(5)
ps, serverDb = newPullSync(nil, mock.WithIntervalsResp(addrs, mockTopmost, nil), mock.WithChunks(chunks...))
recorder = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
psClient, clientDb = newPullSync(recorder, mock.WithChunks(someChunks(1, 2, 3, 4)...))
)
topmost, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
if err != nil {
t.Fatal(err)
}
if topmost != mockTopmost {
t.Fatalf("got offer topmost %d but want %d", topmost, mockTopmost)
}
// should have all
haveChunks(t, clientDb, addrs...)
if clientDb.PutCalls() > 1 {
t.Fatal("too many puts")
}
waitSet(t, serverDb, 1)
}
func TestIncoming_WantAll(t *testing.T) {
var (
mockTopmost = uint64(5)
ps, serverDb = newPullSync(nil, mock.WithIntervalsResp(addrs, mockTopmost, nil), mock.WithChunks(chunks...))
recorder = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
psClient, clientDb = newPullSync(recorder)
)
topmost, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
if err != nil {
t.Fatal(err)
}
if topmost != mockTopmost {
t.Fatalf("got offer topmost %d but want %d", topmost, mockTopmost)
}
// should have all
haveChunks(t, clientDb, addrs...)
if p := clientDb.PutCalls(); p != 5 {
t.Fatalf("want %d puts but got %d", 5, p)
}
waitSet(t, serverDb, 1)
}
func TestIncoming_UnsolicitedChunk(t *testing.T) {
evilAddr := swarm.MustParseHexAddress("0000000000000000000000000000000000000000000000000000000000000666")
evilData := []byte{0x66, 0x66, 0x66}
evil := swarm.NewChunk(evilAddr, evilData)
var (
mockTopmost = uint64(5)
ps, _ = newPullSync(nil, mock.WithIntervalsResp(addrs, mockTopmost, nil), mock.WithChunks(chunks...), mock.WithEvilChunk(addrs[4], evil))
recorder = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
psClient, _ = newPullSync(recorder)
)
_, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
if !errors.Is(err, pullsync.ErrUnsolicitedChunk) {
t.Fatalf("expected ErrUnsolicitedChunk but got %v", err)
}
}
func TestGetCursors(t *testing.T) {
var (
mockCursors = []uint64{100, 101, 102, 103}
ps, _ = newPullSync(nil, mock.WithCursors(mockCursors))
recorder = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
psClient, _ = newPullSync(recorder)
)
curs, err := psClient.GetCursors(context.Background(), swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
if len(curs) != len(mockCursors) {
t.Fatalf("length mismatch got %d want %d", len(curs), len(mockCursors))
}
for i, v := range mockCursors {
if curs[i] != v {
t.Errorf("cursor mismatch. index %d want %d got %d", i, v, curs[i])
}
}
}
func TestGetCursorsError(t *testing.T) {
var (
e = errors.New("erring")
ps, _ = newPullSync(nil, mock.WithCursorsErr(e))
recorder = streamtest.New(streamtest.WithProtocols(ps.Protocol()))
psClient, _ = newPullSync(recorder)
)
_, err := psClient.GetCursors(context.Background(), swarm.ZeroAddress)
if err == nil {
t.Fatal("expected error but got none")
}
// the server-side cursors error closes the stream before an Ack is
// written, so the client observes io.EOF rather than the original error
if !errors.Is(err, io.EOF) {
t.Fatalf("expected io.EOF but got '%v'", err)
}
}
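// haveChunks asserts that the given storage holds all of the given addresses.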
func haveChunks(t *testing.T, s *mock.PullStorage, addrs ...swarm.Address) {
t.Helper()
for _, a := range addrs {
have, err := s.Has(context.Background(), a)
if err != nil {
t.Fatal(err)
}
if !have {
t.Errorf("storage does not have chunk %s", a)
}
}
}
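// newPullSync constructs a Syncer over the given streamer and returns it
// together with the mock pull storage backing it.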
func newPullSync(s p2p.Streamer, o ...mock.Option) (*pullsync.Syncer, *mock.PullStorage) {
storage := mock.NewPullStorage(o...)
logger := logging.New(ioutil.Discard, 0)
return pullsync.New(pullsync.Options{Streamer: s, Storage: storage, Logger: logger}), storage
}
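// waitSet polls the mock storage until v Set calls are observed, failing
// fast if more than v calls are seen or the count is not reached in time.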
func waitSet(t *testing.T, db *mock.PullStorage, v int) {
time.Sleep(10 * time.Millisecond) // give leeway for the case where v==0
var s int
for i := 0; i < 10; i++ {
s = db.SetCalls()
switch {
case s > v:
t.Fatalf("too many Set calls: got %d want %d", s, v)
case s == v:
return
default:
time.Sleep(10 * time.Millisecond)
}
}
t.Fatalf("timed out waiting for set to be called. got %d calls want %d", s, v)
}
@@ -88,7 +88,7 @@ func NewDB(path string) (db *DB, err error) {
 }
 
 // Put wraps LevelDB Put method to increment metrics counter.
-func (db *DB) Put(key []byte, value []byte) (err error) {
+func (db *DB) Put(key, value []byte) (err error) {
 	err = db.ldb.Put(key, value, nil)
 	if err != nil {
 		db.metrics.PutFailCounter.Inc()
...
@@ -23,16 +23,44 @@ type MockStorer struct {
 	pinnedAddress   []swarm.Address // Stores the pinned address
 	pinnedCounter   []uint64        // and its respective counter. These are stored as slices to preserve the order.
 	pinSetMu        sync.Mutex
+	subpull         []storage.Descriptor
+	partialInterval bool
 	validator       swarm.ChunkValidator
 	tags            *tags.Tags
+	morePull        chan struct{}
+	mtx             sync.Mutex
+	quit            chan struct{}
 }
 
-func NewStorer() storage.Storer {
-	return &MockStorer{
+func WithSubscribePullChunks(chs ...storage.Descriptor) Option {
+	return optionFunc(func(m *MockStorer) {
+		m.subpull = make([]storage.Descriptor, len(chs))
+		for i, v := range chs {
+			m.subpull[i] = v
+		}
+	})
+}
+
+func WithPartialInterval(v bool) Option {
+	return optionFunc(func(m *MockStorer) {
+		m.partialInterval = v
+	})
+}
+
+func NewStorer(opts ...Option) *MockStorer {
+	s := &MockStorer{
 		store:     make(map[string][]byte),
 		modeSet:   make(map[string]storage.ModeSet),
 		modeSetMu: sync.Mutex{},
+		morePull:  make(chan struct{}),
+		quit:      make(chan struct{}),
+	}
+	for _, v := range opts {
+		v.apply(s)
 	}
+	return s
 }
 
 func NewValidatingStorer(v swarm.ChunkValidator, tags *tags.Tags) *MockStorer {
@@ -47,6 +75,9 @@ func NewValidatingStorer(v swarm.ChunkValidator, tags *tags.Tags) *MockStorer {
 }
 
 func (m *MockStorer) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
 	v, has := m.store[addr.String()]
 	if !has {
 		return nil, storage.ErrNotFound
@@ -55,6 +86,9 @@ func (m *MockStorer) Get(ctx context.Context, mode storage.ModeGet, addr swarm.A
 }
 
 func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err error) {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
 	for _, ch := range chs {
 		if m.validator != nil {
 			if !m.validator.Validate(ch) {
@@ -62,7 +96,7 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm
 			}
 		}
 		m.store[ch.Address().String()] = ch.Data()
-		yes, err := m.Has(ctx, ch.Address())
+		yes, err := m.has(ctx, ch.Address())
 		if err != nil {
 			exist = append(exist, false)
 			continue
@@ -81,11 +115,17 @@ func (m *MockStorer) GetMulti(ctx context.Context, mode storage.ModeGet, addrs .
 	panic("not implemented") // TODO: Implement
 }
 
-func (m *MockStorer) Has(ctx context.Context, addr swarm.Address) (yes bool, err error) {
+func (m *MockStorer) has(ctx context.Context, addr swarm.Address) (yes bool, err error) {
 	_, has := m.store[addr.String()]
 	return has, nil
 }
 
+func (m *MockStorer) Has(ctx context.Context, addr swarm.Address) (yes bool, err error) {
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+	return m.has(ctx, addr)
+}
+
 func (m *MockStorer) HasMulti(ctx context.Context, addrs ...swarm.Address) (yes []bool, err error) {
 	panic("not implemented") // TODO: Implement
 }
@@ -148,8 +188,72 @@ func (m *MockStorer) LastPullSubscriptionBinID(bin uint8) (id uint64, err error)
 	panic("not implemented") // TODO: Implement
 }
 
-func (m *MockStorer) SubscribePull(ctx context.Context, bin uint8, since uint64, until uint64) (c <-chan storage.Descriptor, stop func()) {
-	panic("not implemented") // TODO: Implement
+func (m *MockStorer) SubscribePull(ctx context.Context, bin uint8, since, until uint64) (<-chan storage.Descriptor, <-chan struct{}, func()) {
+	c := make(chan storage.Descriptor)
+	done := make(chan struct{})
+	stop := func() {
+		close(done)
+	}
+
+	go func() {
+		defer close(c)
+		m.mtx.Lock()
+		for _, ch := range m.subpull {
+			select {
+			case c <- ch:
+			case <-done:
+				return
+			case <-ctx.Done():
+				return
+			case <-m.quit:
+				return
+			}
+		}
+		m.mtx.Unlock()
+
+		if m.partialInterval {
+			// block since we're at the top of the bin and waiting for new chunks
+			select {
+			case <-done:
+				return
+			case <-m.quit:
+				return
+			case <-ctx.Done():
+				return
+			case <-m.morePull:
+			}
+		}
+
+		m.mtx.Lock()
+		defer m.mtx.Unlock()
+
+		// iterate on what we have in the iterator
+		for _, ch := range m.subpull {
+			select {
+			case c <- ch:
+			case <-done:
+				return
+			case <-ctx.Done():
+				return
+			case <-m.quit:
+				return
+			}
+		}
+	}()
+
+	return c, m.quit, stop
+}
+
+func (m *MockStorer) MorePull(d ...storage.Descriptor) {
+	// clear out what we already have in subpull
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
+
+	m.subpull = make([]storage.Descriptor, len(d))
+	for i, v := range d {
+		m.subpull[i] = v
+	}
+	close(m.morePull)
+}
 
 func (m *MockStorer) SubscribePush(ctx context.Context) (c <-chan swarm.Chunk, stop func()) {
@@ -184,5 +288,13 @@ func (m *MockStorer) PinInfo(address swarm.Address) (uint64, error) {
 }
 
 func (m *MockStorer) Close() error {
-	panic("not implemented") // TODO: Implement
+	close(m.quit)
+	return nil
+}
+
+type Option interface {
+	apply(*MockStorer)
 }
+
+type optionFunc func(*MockStorer)
+func (f optionFunc) apply(r *MockStorer) { f(r) }
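To illustrate the new subscription flow, a short sketch (not part of this commit), assuming storage.Descriptor exposes Address and BinID fields: with WithPartialInterval(true) the subscription drains the preset descriptors, blocks at the top of the bin, and resumes once MorePull delivers a fresh batch.

package example

import (
	"context"
	"fmt"

	"github.com/ethersphere/bee/pkg/storage"
	"github.com/ethersphere/bee/pkg/storage/mock"
	"github.com/ethersphere/bee/pkg/swarm"
)

func demoPartialInterval() {
	d1 := storage.Descriptor{Address: swarm.MustParseHexAddress("01"), BinID: 1}
	d2 := storage.Descriptor{Address: swarm.MustParseHexAddress("02"), BinID: 2}

	s := mock.NewStorer(mock.WithSubscribePullChunks(d1), mock.WithPartialInterval(true))
	defer s.Close()

	c, _, stop := s.SubscribePull(context.Background(), 0, 0, 10)
	defer stop()

	fmt.Println((<-c).BinID) // 1: the preset descriptor is delivered first

	go s.MorePull(d2)        // replace the backlog and unblock the subscription
	fmt.Println((<-c).BinID) // 2: the fresh descriptor arrives after MorePull
}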
@@ -3,9 +3,10 @@ package mock_test
 import (
 	"bytes"
 	"context"
-	"github.com/ethersphere/bee/pkg/tags"
 	"testing"
 
+	"github.com/ethersphere/bee/pkg/tags"
+
 	"github.com/ethersphere/bee/pkg/storage"
 	"github.com/ethersphere/bee/pkg/storage/mock"
 	"github.com/ethersphere/bee/pkg/storage/mock/validator"
@@ -41,11 +42,10 @@ func TestMockStorer(t *testing.T) {
 	if chunk, err := s.Get(ctx, storage.ModeGetRequest, keyFound); err != nil {
 		t.Fatalf("expected no error but got: %v", err)
-	} else {
-		if !bytes.Equal(chunk.Data(), valueFound) {
-			t.Fatalf("expected value %s but got %s", valueFound, chunk.Data())
-		}
+	} else if !bytes.Equal(chunk.Data(), valueFound) {
+		t.Fatalf("expected value %s but got %s", valueFound, chunk.Data())
 	}
+
 	has, err := s.Has(ctx, keyFound)
 	if err != nil {
 		t.Fatal(err)
@@ -83,11 +83,9 @@ func TestMockValidatingStorer(t *testing.T) {
 	if chunk, err := s.Get(ctx, storage.ModeGetRequest, validAddress); err != nil {
 		t.Fatalf("got error on get but expected none: %v", err)
-	} else {
-		if !bytes.Equal(chunk.Data(), validContent) {
-			t.Fatal("stored content not identical to input data")
-		}
+	} else if !bytes.Equal(chunk.Data(), validContent) {
+		t.Fatal("stored content not identical to input data")
 	}
 
 	if _, err := s.Get(ctx, storage.ModeGetRequest, invalidAddress); err == nil {
 		t.Fatal("got no error on get but expected one")
...