Commit 1fc07169 authored by metacertain's avatar metacertain Committed by GitHub

feat: cache on forwarding in retrieval (#2234)

parent cd58784b
......@@ -70,6 +70,7 @@ const (
optionNameBlockTime = "block-time"
optionWarmUpTime = "warmup-time"
optionNameMainNet = "mainnet"
optionNameRetrievalCaching = "cache-retrieval"
)
func init() {
......@@ -247,6 +248,7 @@ func (c *command) setAllFlags(cmd *cobra.Command) {
cmd.Flags().String(optionNameSwapDeploymentGasPrice, "", "gas price in wei to use for deployment and funding")
cmd.Flags().Duration(optionWarmUpTime, time.Minute*20, "time to warmup the node before pull/push protocols can be kicked off.")
cmd.Flags().Bool(optionNameMainNet, false, "triggers connect to main net bootnodes.")
cmd.Flags().Bool(optionNameRetrievalCaching, true, "enable forwarded content caching")
}
func newLogger(cmd *cobra.Command, verbosity string) (logging.Logger, error) {
......
......@@ -198,6 +198,7 @@ inability to use, or your interaction with other nodes or the software.`)
DeployGasPrice: c.config.GetString(optionNameSwapDeploymentGasPrice),
WarmupTime: c.config.GetDuration(optionWarmUpTime),
ChainID: networkConfig.chainID,
RetrievalCaching: c.config.GetBool(optionNameRetrievalCaching),
})
if err != nil {
return err
......
......@@ -14,6 +14,7 @@ import (
"fmt"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/postage"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/sctx"
......@@ -25,7 +26,7 @@ type store struct {
storage.Storer
retrieval retrieval.Interface
logger logging.Logger
validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error)
validStamp postage.ValidStampFn
recoveryCallback recovery.Callback // this is the callback to be executed when a chunk fails to be retrieved
}
......@@ -34,7 +35,7 @@ var (
)
// New returns a new NetStore that wraps a given Storer.
func New(s storage.Storer, validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), rcb recovery.Callback, r retrieval.Interface, logger logging.Logger) storage.Storer {
func New(s storage.Storer, validStamp postage.ValidStampFn, rcb recovery.Callback, r retrieval.Interface, logger logging.Logger) storage.Storer {
return &store{Storer: s, validStamp: validStamp, recoveryCallback: rcb, retrieval: r, logger: logger}
}
......
......@@ -16,6 +16,7 @@ import (
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/postage"
postagetesting "github.com/ethersphere/bee/pkg/postage/testing"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/recovery"
......@@ -186,7 +187,7 @@ func TestInvalidPostageStamp(t *testing.T) {
}
// returns a mock retrieval protocol, a mock local storage and a netstore
func newRetrievingNetstore(rec recovery.Callback, validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error)) (ret *retrievalMock, mockStore *mock.MockStorer, ns storage.Storer) {
func newRetrievingNetstore(rec recovery.Callback, validStamp postage.ValidStampFn) (ret *retrievalMock, mockStore *mock.MockStorer, ns storage.Storer) {
retrieve := &retrievalMock{}
store := mock.NewStorer()
logger := logging.New(ioutil.Discard, 0)
......
......@@ -136,6 +136,7 @@ type Options struct {
PaymentTolerance string
PaymentEarly string
ResolverConnectionCfgs []multiresolver.ConnectionConfig
RetrievalCaching bool
GatewayMode bool
BootnodeMode bool
SwapEndpoint string
......@@ -609,7 +610,7 @@ func NewBee(addr string, publicKey *ecdsa.PublicKey, signer crypto.Signer, netwo
pricing.SetPaymentThresholdObserver(acc)
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer)
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, pricer, tracer, o.RetrievalCaching, validStamp)
tagService := tags.NewTags(stateStore, logger)
b.tagsCloser = tagService
......
......@@ -113,8 +113,10 @@ func toSignDigest(addr, batchId, index, timestamp []byte) ([]byte, error) {
return h.Sum(nil), nil
}
type ValidStampFn func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error)
// ValidStamp returns a stampvalidator function passed to protocols with chunk entrypoints.
func ValidStamp(batchStore Storer) func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
func ValidStamp(batchStore Storer) ValidStampFn {
return func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
stamp := new(Stamp)
err := stamp.UnmarshalBinary(stampBytes)
......
......@@ -21,6 +21,7 @@ import (
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/postage"
"github.com/ethersphere/bee/pkg/pullsync/pb"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/soc"
......@@ -67,7 +68,7 @@ type Syncer struct {
quit chan struct{}
wg sync.WaitGroup
unwrap func(swarm.Chunk)
validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error)
validStamp postage.ValidStampFn
ruidMtx sync.Mutex
ruidCtx map[uint32]func()
......@@ -76,7 +77,7 @@ type Syncer struct {
io.Closer
}
func New(streamer p2p.Streamer, storage pullstorage.Storer, unwrap func(swarm.Chunk), validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), logger logging.Logger) *Syncer {
func New(streamer p2p.Streamer, storage pullstorage.Storer, unwrap func(swarm.Chunk), validStamp postage.ValidStampFn, logger logging.Logger) *Syncer {
return &Syncer{
streamer: streamer,
storage: storage,
......
......@@ -72,7 +72,7 @@ type PushSync struct {
pricer pricer.Interface
metrics metrics
tracer *tracing.Tracer
validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error)
validStamp postage.ValidStampFn
signer crypto.Signer
isFullNode bool
warmupPeriod time.Time
......@@ -83,7 +83,7 @@ var defaultTTL = 20 * time.Second // request time to live
var timeToWaitForPushsyncToNeighbor = 3 * time.Second // time to wait to get a receipt for a chunk
var nPeersToPushsync = 3 // number of peers to replicate to as receipt is sent upstream
func New(address swarm.Address, blockHash []byte, streamer p2p.StreamerDisconnecter, storer storage.Putter, topology topology.Driver, tagger *tags.Tags, isFullNode bool, unwrap func(swarm.Chunk), validStamp func(swarm.Chunk, []byte) (swarm.Chunk, error), logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, signer crypto.Signer, tracer *tracing.Tracer, warmupTime time.Duration) *PushSync {
func New(address swarm.Address, blockHash []byte, streamer p2p.StreamerDisconnecter, storer storage.Putter, topology topology.Driver, tagger *tags.Tags, isFullNode bool, unwrap func(swarm.Chunk), validStamp postage.ValidStampFn, logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, signer crypto.Signer, tracer *tracing.Tracer, warmupTime time.Duration) *PushSync {
ps := &PushSync{
address: address,
blockHash: blockHash,
......
......@@ -233,12 +233,12 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store
return nil
}}
server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps0, logger, serverMockAccounting, pricerMock, nil)
server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps0, logger, serverMockAccounting, pricerMock, nil, false, noopStampValidator)
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
streamtest.WithBaseAddr(peerID),
)
retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil)
retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil, false, noopStampValidator)
validStamp := func(ch swarm.Chunk, stamp []byte) (swarm.Chunk, error) {
return ch.WithStamp(postage.NewStamp(nil, nil, nil, nil)), nil
}
......@@ -267,3 +267,7 @@ func (mp *mockPssSender) Send(ctx context.Context, topic pss.Topic, payload []by
mp.callbackC <- true
return nil
}
var noopStampValidator = func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
return chunk, nil
}
......@@ -64,9 +64,11 @@ type Service struct {
metrics metrics
pricer pricer.Interface
tracer *tracing.Tracer
caching bool
validStamp postage.ValidStampFn
}
func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, tracer *tracing.Tracer) *Service {
func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer pricer.Interface, tracer *tracing.Tracer, forwarderCaching bool, validStamp postage.ValidStampFn) *Service {
return &Service{
addr: addr,
streamer: streamer,
......@@ -77,6 +79,8 @@ func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunk
pricer: pricer,
metrics: newMetrics(),
tracer: tracer,
caching: forwarderCaching,
validStamp: validStamp,
}
}
......@@ -416,6 +420,8 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
ctx = context.WithValue(ctx, requestSourceContextKey{}, p.Address.String())
addr := swarm.NewAddress(req.Addr)
forwarded := false
chunk, err := s.storer.Get(ctx, storage.ModeGetRequest, addr)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
......@@ -424,11 +430,11 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
if err != nil {
return fmt.Errorf("retrieve chunk: %w", err)
}
forwarded = true
} else {
return fmt.Errorf("get from store: %w", err)
}
}
stamp, err := chunk.Stamp().MarshalBinary()
if err != nil {
return fmt.Errorf("stamp marshal: %w", err)
......@@ -449,6 +455,28 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
}
s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String())
// debit price from p's balance
return debit.Apply()
if err := debit.Apply(); err != nil {
return fmt.Errorf("apply debit: %w", err)
}
// cache the request last, so that putting to the localstore does not slow down the request flow
if s.caching && forwarded {
putMode := storage.ModePutRequest
cch, err := s.validStamp(chunk, stamp)
if err != nil {
// if a chunk with an invalid postage stamp was received
// we force it into the cache.
putMode = storage.ModePutRequestCache
cch = chunk
}
_, err = s.storer.Put(ctx, putMode, cch)
if err != nil {
return fmt.Errorf("retrieve cache put: %w", err)
}
}
return nil
}
......@@ -61,7 +61,7 @@ func TestDelivery(t *testing.T) {
}
// create the server that will handle the request and will serve the response
server := retrieval.New(swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, nil)
server := retrieval.New(swarm.MustParseHexAddress("0034"), mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, nil, false, noopStampValidator)
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
streamtest.WithBaseAddr(clientAddr),
......@@ -78,7 +78,7 @@ func TestDelivery(t *testing.T) {
return nil
}}
client := retrieval.New(clientAddr, clientMockStorer, recorder, ps, logger, clientMockAccounting, pricerMock, nil)
client := retrieval.New(clientAddr, clientMockStorer, recorder, ps, logger, clientMockAccounting, pricerMock, nil, false, noopStampValidator)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
v, err := client.RetrieveChunk(ctx, chunk.Address(), true)
......@@ -167,14 +167,14 @@ func TestRetrieveChunk(t *testing.T) {
t.Fatal(err)
}
server := retrieval.New(serverAddress, serverStorer, nil, nil, logger, accountingmock.NewAccounting(), pricer, nil)
server := retrieval.New(serverAddress, serverStorer, nil, nil, logger, accountingmock.NewAccounting(), pricer, nil, false, noopStampValidator)
recorder := streamtest.New(streamtest.WithProtocols(server.Protocol()))
clientSuggester := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(serverAddress, 0)
return nil
}}
client := retrieval.New(clientAddress, nil, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, nil)
client := retrieval.New(clientAddress, nil, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, nil, false, noopStampValidator)
got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
if err != nil {
......@@ -207,11 +207,15 @@ func TestRetrieveChunk(t *testing.T) {
accountingmock.NewAccounting(),
pricer,
nil,
false,
noopStampValidator,
)
forwarderStore := storemock.NewStorer()
forwarder := retrieval.New(
forwarderAddress,
storemock.NewStorer(), // no chunk in forwarder's store
forwarderStore, // no chunk in forwarder's store
streamtest.New(streamtest.WithProtocols(server.Protocol())), // connect to server
mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(serverAddress, 0) // suggest server's address
......@@ -221,6 +225,8 @@ func TestRetrieveChunk(t *testing.T) {
accountingmock.NewAccounting(),
pricer,
nil,
true, // note explicit caching
noopStampValidator,
)
client := retrieval.New(
......@@ -235,8 +241,14 @@ func TestRetrieveChunk(t *testing.T) {
accountingmock.NewAccounting(),
pricer,
nil,
false,
noopStampValidator,
)
if got, _ := forwarderStore.Has(context.Background(), chunk.Address()); got {
t.Fatalf("forwarder node already has chunk")
}
got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
if err != nil {
t.Fatal(err)
......@@ -244,6 +256,11 @@ func TestRetrieveChunk(t *testing.T) {
if !bytes.Equal(got.Data(), chunk.Data()) {
t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
}
if got, _ := forwarderStore.Has(context.Background(), chunk.Address()); !got {
t.Fatalf("forwarder did not cache chunk")
}
})
}
......@@ -301,8 +318,8 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
return peerSuggester
}
server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)
server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)
server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)
server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)
t.Run("peer not reachable", func(t *testing.T) {
ranOnce := true
......@@ -330,7 +347,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
streamtest.WithBaseAddr(clientAddress),
)
client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil)
client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)
got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
if err != nil {
......@@ -366,7 +383,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
),
)
client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil)
client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)
got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
if err != nil {
......@@ -395,8 +412,8 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
server1MockAccounting := accountingmock.NewAccounting()
server2MockAccounting := accountingmock.NewAccounting()
server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, server1MockAccounting, pricerMock, nil)
server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, server2MockAccounting, pricerMock, nil)
server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, server1MockAccounting, pricerMock, nil, false, noopStampValidator)
server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, server2MockAccounting, pricerMock, nil, false, noopStampValidator)
// NOTE: must be more than retry duration
// (here one second more)
......@@ -430,7 +447,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
clientMockAccounting := accountingmock.NewAccounting()
client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, clientMockAccounting, pricerMock, nil)
client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, clientMockAccounting, pricerMock, nil, false, noopStampValidator)
got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
if err != nil {
......@@ -468,21 +485,25 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
t.Run("peer forwards request", func(t *testing.T) {
// server 2 has the chunk
server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)
server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)
server1Recorder := streamtest.New(
streamtest.WithProtocols(server2.Protocol()),
)
// server 1 will forward request to server 2
server1 := retrieval.New(serverAddress1, serverStorer1, server1Recorder, peerSuggesterFn(serverAddress2), logger, accountingmock.NewAccounting(), pricerMock, nil)
server1 := retrieval.New(serverAddress1, serverStorer1, server1Recorder, peerSuggesterFn(serverAddress2), logger, accountingmock.NewAccounting(), pricerMock, nil, true, noopStampValidator)
clientRecorder := streamtest.New(
streamtest.WithProtocols(server1.Protocol()),
)
// client only knows about server 1
client := retrieval.New(clientAddress, nil, clientRecorder, peerSuggesterFn(serverAddress1), logger, accountingmock.NewAccounting(), pricerMock, nil)
client := retrieval.New(clientAddress, nil, clientRecorder, peerSuggesterFn(serverAddress1), logger, accountingmock.NewAccounting(), pricerMock, nil, false, noopStampValidator)
if got, _ := serverStorer1.Has(context.Background(), chunk.Address()); got {
t.Fatalf("forwarder node already has chunk")
}
got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
if err != nil {
......@@ -492,6 +513,10 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
if !bytes.Equal(got.Data(), chunk.Data()) {
t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
}
if got, _ := serverStorer1.Has(context.Background(), chunk.Address()); !got {
t.Fatalf("forwarder node does not have chunk")
}
})
}
......@@ -505,3 +530,7 @@ func (s mockPeerSuggester) EachPeer(topology.EachPeerFunc) error {
func (s mockPeerSuggester) EachPeerRev(f topology.EachPeerFunc) error {
return s.eachPeerRevFunc(f)
}
var noopStampValidator = func(chunk swarm.Chunk, stampBytes []byte) (swarm.Chunk, error) {
return chunk, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment