Commit 1fc07169 authored by metacertain, committed by GitHub

feat: cache on forwarding in retrieval (#2234)

parent cd58784b
@@ -70,6 +70,7 @@ const (
 	optionNameBlockTime        = "block-time"
 	optionWarmUpTime           = "warmup-time"
 	optionNameMainNet          = "mainnet"
+	optionNameRetrievalCaching = "cache-retrieval"
 )

 func init() {
@@ -247,6 +248,7 @@ func (c *command) setAllFlags(cmd *cobra.Command) {
 	cmd.Flags().String(optionNameSwapDeploymentGasPrice, "", "gas price in wei to use for deployment and funding")
 	cmd.Flags().Duration(optionWarmUpTime, time.Minute*20, "time to warmup the node before pull/push protocols can be kicked off.")
 	cmd.Flags().Bool(optionNameMainNet, false, "triggers connect to main net bootnodes.")
+	cmd.Flags().Bool(optionNameRetrievalCaching, true, "enable forwarded content caching")
 }

 func newLogger(cmd *cobra.Command, verbosity string) (logging.Logger, error) {
...
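Since the flag defaults to true, forwarding nodes cache the content they relay out of the box. An operator who wants the previous behaviour would disable it at startup, e.g. `bee start --cache-retrieval=false` (the `start` entrypoint is assumed here; the flag name and default come from this change).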
@@ -198,6 +198,7 @@ inability to use, or your interaction with other nodes or the software.`)
 		DeployGasPrice:   c.config.GetString(optionNameSwapDeploymentGasPrice),
 		WarmupTime:       c.config.GetDuration(optionWarmUpTime),
 		ChainID:          networkConfig.chainID,
+		RetrievalCaching: c.config.GetBool(optionNameRetrievalCaching),
 	})
 	if err != nil {
 		return err
...
@@ -14,6 +14,7 @@ import (
 	"fmt"

 	"github.com/ethersphere/bee/pkg/logging"
+	"github.com/ethersphere/bee/pkg/postage"
 	"github.com/ethersphere/bee/pkg/recovery"
 	"github.com/ethersphere/bee/pkg/retrieval"
 	"github.com/ethersphere/bee/pkg/sctx"
@@ -25,7 +26,7 @@ type store struct {
 	storage.Storer
 	retrieval retrieval.Interface
 	logger    logging.Logger
-	validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error)
+	validStamp postage.ValidStampFn
 	recoveryCallback recovery.Callback // this is the callback to be executed when a chunk fails to be retrieved
 }
@@ -34,7 +35,7 @@ var (
 )

 // New returns a new NetStore that wraps a given Storer.
-func New(s storage.Storer, validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), rcb recovery.Callback, r retrieval.Interface, logger logging.Logger) storage.Storer {
+func New(s storage.Storer, validStamp postage.ValidStampFn, rcb recovery.Callback, r retrieval.Interface, logger logging.Logger) storage.Storer {
 	return &store{Storer: s, validStamp: validStamp, recoveryCallback: rcb, retrieval: r, logger: logger}
 }
...
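For illustration, a minimal sketch of constructing a NetStore against the narrowed signature with a no-op validator. The storage mock import path is an assumption based on the `mock.NewStorer()` usage in the tests below; everything else follows the signatures in this diff.

```go
package main

import (
	"io/ioutil"

	"github.com/ethersphere/bee/pkg/logging"
	"github.com/ethersphere/bee/pkg/netstore"
	"github.com/ethersphere/bee/pkg/postage"
	"github.com/ethersphere/bee/pkg/storage/mock" // assumed path for the mock storer
	"github.com/ethersphere/bee/pkg/swarm"
)

func main() {
	// Any func with the matching signature converts implicitly to the
	// named postage.ValidStampFn type; here, a no-op validator.
	var validStamp postage.ValidStampFn = func(ch swarm.Chunk, _ []byte) (swarm.Chunk, error) {
		return ch, nil
	}

	logger := logging.New(ioutil.Discard, 0)
	// nil recovery callback and nil retrieval interface: lookups stay local.
	ns := netstore.New(mock.NewStorer(), validStamp, nil, nil, logger)
	_ = ns
}
```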
@@ -16,6 +16,7 @@ import (
 	"github.com/ethersphere/bee/pkg/logging"
 	"github.com/ethersphere/bee/pkg/netstore"
+	"github.com/ethersphere/bee/pkg/postage"
 	postagetesting "github.com/ethersphere/bee/pkg/postage/testing"
 	"github.com/ethersphere/bee/pkg/pss"
 	"github.com/ethersphere/bee/pkg/recovery"
@@ -186,7 +187,7 @@ func TestInvalidPostageStamp(t *testing.T) {
 }

 // returns a mock retrieval protocol, a mock local storage and a netstore
-func newRetrievingNetstore(rec recovery.Callback, validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error)) (ret *retrievalMock, mockStore *mock.MockStorer, ns storage.Storer) {
+func newRetrievingNetstore(rec recovery.Callback, validStamp postage.ValidStampFn) (ret *retrievalMock, mockStore *mock.MockStorer, ns storage.Storer) {
 	retrieve := &retrievalMock{}
 	store := mock.NewStorer()
 	logger := logging.New(ioutil.Discard, 0)
...
@@ -136,6 +136,7 @@ type Options struct {
 	PaymentTolerance       string
 	PaymentEarly           string
 	ResolverConnectionCfgs []multiresolver.ConnectionConfig
+	RetrievalCaching       bool
 	GatewayMode            bool
 	BootnodeMode           bool
 	SwapEndpoint           string
@@ -609,7 +610,7 @@ func NewBee(addr string, publicKey *ecdsa.PublicKey, signer crypto.Signer, netwo
 	pricing.SetPaymentThresholdObserver(acc)

-	retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer)
+	retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer, o.RetrievalCaching, validStamp)
 	tagService := tags.NewTags(stateStore, logger)
 	b.tagsCloser = tagService
...
@@ -113,8 +113,10 @@ func toSignDigest(addr, batchId, index, timestamp []byte) ([]byte, error) {
 	return h.Sum(nil), nil
 }

+type ValidStampFn func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error)
+
 // ValidStamp returns a stampvalidator function passed to protocols with chunk entrypoints.
-func ValidStamp(batchStore Storer) func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
+func ValidStamp(batchStore Storer) ValidStampFn {
 	return func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
 		stamp := new(Stamp)
 		err := stamp.UnmarshalBinary(stampBytes)
...
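The named type keeps the long constructor signatures above readable. As a hedged sketch, here is a custom ValidStampFn that only unmarshals the stamp bytes and attaches the stamp to the chunk, mirroring the outer contract of the closure returned by `postage.ValidStamp` (the batch lookup and validation it performs are omitted); `attachOnly` is an illustrative name, not part of the codebase.

```go
package main

import (
	"github.com/ethersphere/bee/pkg/postage"
	"github.com/ethersphere/bee/pkg/swarm"
)

// attachOnly parses the stamp bytes and returns the chunk with the stamp
// attached; a real validator would additionally verify the stamp against
// the batch store, as postage.ValidStamp does.
var attachOnly postage.ValidStampFn = func(ch swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
	stamp := new(postage.Stamp)
	if err := stamp.UnmarshalBinary(stampBytes); err != nil {
		return nil, err
	}
	return ch.WithStamp(stamp), nil
}

func main() { _ = attachOnly }
```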
@@ -21,6 +21,7 @@ import (
 	"github.com/ethersphere/bee/pkg/logging"
 	"github.com/ethersphere/bee/pkg/p2p"
 	"github.com/ethersphere/bee/pkg/p2p/protobuf"
+	"github.com/ethersphere/bee/pkg/postage"
 	"github.com/ethersphere/bee/pkg/pullsync/pb"
 	"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
 	"github.com/ethersphere/bee/pkg/soc"
@@ -67,7 +68,7 @@ type Syncer struct {
 	quit       chan struct{}
 	wg         sync.WaitGroup
 	unwrap     func(swarm.Chunk)
-	validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error)
+	validStamp postage.ValidStampFn

 	ruidMtx sync.Mutex
 	ruidCtx map[uint32]func()
@@ -76,7 +77,7 @@ type Syncer struct {
 	io.Closer
 }

-func New(streamer p2p.Streamer, storage pullstorage.Storer, unwrap func(swarm.Chunk), validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), logger logging.Logger) *Syncer {
+func New(streamer p2p.Streamer, storage pullstorage.Storer, unwrap func(swarm.Chunk), validStamp postage.ValidStampFn, logger logging.Logger) *Syncer {
 	return &Syncer{
 		streamer: streamer,
 		storage:  storage,
...
@@ -72,7 +72,7 @@ type PushSync struct {
 	pricer         pricer.Interface
 	metrics        metrics
 	tracer         *tracing.Tracer
-	validStamp     func(swarm.Chunk, []byte) (swarm.Chunk, error)
+	validStamp     postage.ValidStampFn
 	signer         crypto.Signer
 	isFullNode     bool
 	warmupPeriod   time.Time
@@ -83,7 +83,7 @@ var defaultTTL = 20 * time.Second // request time to live
 var timeToWaitForPushsyncToNeighbor = 3 * time.Second // time to wait to get a receipt for a chunk
 var nPeersToPushsync = 3 // number of peers to replicate to as receipt is sent upstream

-func New(address swarm.Address, blockHash []byte, streamer p2p.StreamerDisconnecter, storer storage.Putter, topology topology.Driver, tagger *tags.Tags, isFullNode bool, unwrap func(swarm.Chunk), validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, signer crypto.Signer, tracer *tracing.Tracer, warmupTime time.Duration) *PushSync {
+func New(address swarm.Address, blockHash []byte, streamer p2p.StreamerDisconnecter, storer storage.Putter, topology topology.Driver, tagger *tags.Tags, isFullNode bool, unwrap func(swarm.Chunk), validStamp postage.ValidStampFn, logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, signer crypto.Signer, tracer *tracing.Tracer, warmupTime time.Duration) *PushSync {
 	ps := &PushSync{
 		address:   address,
 		blockHash: blockHash,
...
@@ -233,12 +233,12 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store
 		return nil
 	}}

-	server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps0, logger, serverMockAccounting, pricerMock, nil)
+	server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps0, logger, serverMockAccounting, pricerMock, nil, false, noopStampValidator)
 	recorder := streamtest.New(
 		streamtest.WithProtocols(server.Protocol()),
 		streamtest.WithBaseAddr(peerID),
 	)
-	retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil)
+	retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil, false, noopStampValidator)
 	validStamp := func(ch swarm.Chunk, stamp []byte) (swarm.Chunk, error) {
 		return ch.WithStamp(postage.NewStamp(nil, nil, nil, nil)), nil
 	}
@@ -267,3 +267,7 @@ func (mp *mockPssSender) Send(ctx context.Context, topic pss.Topic, payload []by
 	mp.callbackC <- true
 	return nil
 }
+
+var noopStampValidator = func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
+	return chunk, nil
+}
@@ -64,9 +64,11 @@ type Service struct {
 	metrics    metrics
 	pricer     pricer.Interface
 	tracer     *tracing.Tracer
+	caching    bool
+	validStamp postage.ValidStampFn
 }

-func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, tracer *tracing.Tracer) *Service {
+func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, tracer *tracing.Tracer, forwarderCaching bool, validStamp postage.ValidStampFn) *Service {
 	return &Service{
 		addr:     addr,
 		streamer: streamer,
@@ -77,6 +79,8 @@ func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunk
 		pricer:     pricer,
 		metrics:    newMetrics(),
 		tracer:     tracer,
+		caching:    forwarderCaching,
+		validStamp: validStamp,
 	}
 }
@@ -416,6 +420,8 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
 	ctx = context.WithValue(ctx, requestSourceContextKey{}, p.Address.String())
 	addr := swarm.NewAddress(req.Addr)

+	forwarded := false
+
 	chunk, err := s.storer.Get(ctx, storage.ModeGetRequest, addr)
 	if err != nil {
 		if errors.Is(err, storage.ErrNotFound) {
@@ -424,11 +430,11 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
 			if err != nil {
 				return fmt.Errorf("retrieve chunk: %w", err)
 			}
+			forwarded = true
 		} else {
 			return fmt.Errorf("get from store: %w", err)
 		}
 	}

 	stamp, err := chunk.Stamp().MarshalBinary()
 	if err != nil {
 		return fmt.Errorf("stamp marshal: %w", err)
@@ -449,6 +455,28 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
 	}

 	s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String())

 	// debit price from p's balance
-	return debit.Apply()
+	if err := debit.Apply(); err != nil {
+		return fmt.Errorf("apply debit: %w", err)
+	}
+
+	// cache the request last, so that putting to the localstore does not slow down the request flow
+	if s.caching && forwarded {
+		putMode := storage.ModePutRequest
+
+		cch, err := s.validStamp(chunk, stamp)
+		if err != nil {
+			// if a chunk with an invalid postage stamp was received
+			// we force it into the cache.
+			putMode = storage.ModePutRequestCache
+			cch = chunk
+		}
+
+		_, err = s.storer.Put(ctx, putMode, cch)
+		if err != nil {
+			return fmt.Errorf("retrieve cache put: %w", err)
+		}
+	}
+
+	return nil
 }
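The ordering in this hunk is deliberate: the debit is applied before the put, so caching a forwarded chunk never delays the response or the accounting, and a failed stamp validation downgrades the put to cache-only mode instead of aborting a request that has already been served. For illustration, a self-contained sketch of the same decision logic; all names here, such as `cacheForwarded` and the two put-mode constants, are illustrative stand-ins rather than the bee types.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

type PutMode int

const (
	ModePutRequest      PutMode = iota // stamp validated: store as a regular request put
	ModePutRequestCache                // stamp invalid: force the chunk into the cache only
)

type chunk struct{ data, stamp []byte }

// cacheForwarded mirrors the handler's tail: only forwarded chunks are
// cached, and a stamp-validation failure downgrades the put mode.
func cacheForwarded(ctx context.Context, caching, forwarded bool,
	validStamp func(chunk, []byte) (chunk, error),
	put func(context.Context, PutMode, chunk) error,
	ch chunk) error {

	if !caching || !forwarded {
		return nil
	}
	putMode := ModePutRequest
	cch, err := validStamp(ch, ch.stamp)
	if err != nil {
		// invalid postage stamp: still cache the content, cache-only.
		putMode = ModePutRequestCache
		cch = ch
	}
	if err := put(ctx, putMode, cch); err != nil {
		return fmt.Errorf("retrieve cache put: %w", err)
	}
	return nil
}

func main() {
	ok := func(c chunk, _ []byte) (chunk, error) { return c, nil }
	bad := func(chunk, []byte) (chunk, error) { return chunk{}, errors.New("invalid stamp") }
	put := func(_ context.Context, m PutMode, _ chunk) error {
		fmt.Println("put mode:", m)
		return nil
	}
	c := chunk{stamp: []byte{0x01}}
	_ = cacheForwarded(context.Background(), true, true, ok, put, c)  // put mode: 0
	_ = cacheForwarded(context.Background(), true, true, bad, put, c) // put mode: 1
}
```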
@@ -61,7 +61,7 @@ func TestDelivery(t *testing.T) {
 	}

 	// create the server that will handle the request and will serve the response
-	server := retrieval.New(swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, nil)
+	server := retrieval.New(swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, nil, false, noopStampValidator)
 	recorder := streamtest.New(
 		streamtest.WithProtocols(server.Protocol()),
 		streamtest.WithBaseAddr(clientAddr),
@@ -78,7 +78,7 @@ func TestDelivery(t *testing.T) {
 		return nil
 	}}

-	client := retrieval.New(clientAddr, clientMockStorer, recorder, ps, logger, clientMockAccounting, pricerMock, nil)
+	client := retrieval.New(clientAddr, clientMockStorer, recorder, ps, logger, clientMockAccounting, pricerMock, nil, false, noopStampValidator)
 	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 	defer cancel()
 	v, err := client.RetrieveChunk(ctx, chunk.Address(), true)
@@ -167,14 +167,14 @@ func TestRetrieveChunk(t *testing.T) {
 		t.Fatal(err)
 	}

-	server := retrieval.New(serverAddress, serverStorer, nil, nil, logger, accountingmock.NewAccounting(), pricer, nil)
+	server := retrieval.New(serverAddress, serverStorer, nil, nil, logger, accountingmock.NewAccounting(), pricer, nil, false, noopStampValidator)
 	recorder := streamtest.New(streamtest.WithProtocols(server.Protocol()))
 	clientSuggester := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
 		_, _, _ = f(serverAddress, 0)
 		return nil
 	}}
-	client := retrieval.New(clientAddress, nil, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, nil)
+	client := retrieval.New(clientAddress, nil, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, nil, false, noopStampValidator)

 	got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 	if err != nil {
@@ -207,11 +207,15 @@ func TestRetrieveChunk(t *testing.T) {
 			accountingmock.NewAccounting(),
 			pricer,
 			nil,
+			false,
+			noopStampValidator,
 		)

+		forwarderStore := storemock.NewStorer()
+
 		forwarder := retrieval.New(
 			forwarderAddress,
-			storemock.NewStorer(), // no chunk in forwarder's store
+			forwarderStore, // no chunk in forwarder's store
 			streamtest.New(streamtest.WithProtocols(server.Protocol())), // connect to server
 			mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
 				_, _, _ = f(serverAddress, 0) // suggest server's address
@@ -221,6 +225,8 @@ func TestRetrieveChunk(t *testing.T) {
 			accountingmock.NewAccounting(),
 			pricer,
 			nil,
+			true, // note explicit caching
+			noopStampValidator,
 		)

 		client := retrieval.New(
@@ -235,8 +241,14 @@ func TestRetrieveChunk(t *testing.T) {
 			accountingmock.NewAccounting(),
 			pricer,
 			nil,
+			false,
+			noopStampValidator,
 		)

+		if got, _ := forwarderStore.Has(context.Background(), chunk.Address()); got {
+			t.Fatalf("forwarder node already has chunk")
+		}
+
 		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
 			t.Fatal(err)
@@ -244,6 +256,11 @@ func TestRetrieveChunk(t *testing.T) {
 		if !bytes.Equal(got.Data(), chunk.Data()) {
 			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
 		}
+
+		if got, _ := forwarderStore.Has(context.Background(), chunk.Address()); !got {
+			t.Fatalf("forwarder did not cache chunk")
+		}
 	})
 }
@@ -301,8 +318,8 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 		return peerSuggester
 	}

-	server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)
-	server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)
+	server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)
+	server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)

 	t.Run("peer not reachable", func(t *testing.T) {
 		ranOnce := true
@@ -330,7 +347,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 			streamtest.WithBaseAddr(clientAddress),
 		)

-		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil)
+		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)

 		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
@@ -366,7 +383,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 			),
 		)

-		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil)
+		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)

 		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
@@ -395,8 +412,8 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 		server1MockAccounting := accountingmock.NewAccounting()
 		server2MockAccounting := accountingmock.NewAccounting()

-		server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, server1MockAccounting, pricerMock, nil)
-		server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, server2MockAccounting, pricerMock, nil)
+		server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, server1MockAccounting, pricerMock, nil, false, noopStampValidator)
+		server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, server2MockAccounting, pricerMock, nil, false, noopStampValidator)

 		// NOTE: must be more than retry duration
 		// (here one second more)
@@ -430,7 +447,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 		clientMockAccounting := accountingmock.NewAccounting()

-		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, clientMockAccounting, pricerMock, nil)
+		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, clientMockAccounting, pricerMock, nil, false, noopStampValidator)

 		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
@@ -468,21 +485,25 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 	t.Run("peer forwards request", func(t *testing.T) {
 		// server 2 has the chunk
-		server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)
+		server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)

 		server1Recorder := streamtest.New(
 			streamtest.WithProtocols(server2.Protocol()),
 		)

 		// server 1 will forward request to server 2
-		server1 := retrieval.New(serverAddress1, serverStorer1, server1Recorder, peerSuggesterFn(serverAddress2), logger, accountingmock.NewAccounting(), pricerMock, nil)
+		server1 := retrieval.New(serverAddress1, serverStorer1, server1Recorder, peerSuggesterFn(serverAddress2), logger, accountingmock.NewAccounting(), pricerMock, nil, true, noopStampValidator)

 		clientRecorder := streamtest.New(
 			streamtest.WithProtocols(server1.Protocol()),
 		)

 		// client only knows about server 1
-		client := retrieval.New(clientAddress, nil, clientRecorder, peerSuggesterFn(serverAddress1), logger, accountingmock.NewAccounting(), pricerMock, nil)
+		client := retrieval.New(clientAddress, nil, clientRecorder, peerSuggesterFn(serverAddress1), logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)

+		if got, _ := serverStorer1.Has(context.Background(), chunk.Address()); got {
+			t.Fatalf("forwarder node already has chunk")
+		}
+
 		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
@@ -492,6 +513,10 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 		if !bytes.Equal(got.Data(), chunk.Data()) {
 			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
 		}
+
+		if got, _ := serverStorer1.Has(context.Background(), chunk.Address()); !got {
+			t.Fatalf("forwarder node does not have chunk")
+		}
 	})
 }
@@ -505,3 +530,7 @@ func (s mockPeerSuggester) EachPeer(topology.EachPeerFunc) error {
 func (s mockPeerSuggester) EachPeerRev(f topology.EachPeerFunc) error {
 	return s.eachPeerRevFunc(f)
 }
+
+var noopStampValidator = func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
+	return chunk, nil
+}