Commit 3c564de4 authored by Janoš Guljaš's avatar Janoš Guljaš Committed by GitHub

decouple retrieval forwarding from netstore (#805)

parent 2918ec2e
...@@ -343,7 +343,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -343,7 +343,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
chunkvalidator := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator()) chunkvalidator := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator())
retrieve := retrieval.New(swarmAddress, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator, tracer) retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator, tracer)
tagg := tags.NewTags(stateStore, logger) tagg := tags.NewTags(stateStore, logger)
b.tagsCloser = tagg b.tagsCloser = tagg
...@@ -368,7 +368,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey, ...@@ -368,7 +368,6 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
} else { } else {
ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator) ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator)
} }
retrieve.SetStorer(ns)
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer) pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
......
...@@ -226,13 +226,11 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.S ...@@ -226,13 +226,11 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.S
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
return nil return nil
}} }}
server := retrieval.New(swarm.ZeroAddress, nil, nil, logger, serverMockAccounting, nil, nil, nil) server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps, logger, serverMockAccounting, nil, nil, nil)
server.SetStorer(mockStorer)
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
) )
retrieve := retrieval.New(swarm.ZeroAddress, recorder, ps, logger, serverMockAccounting, pricerMock, nil, nil) retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil, nil)
retrieve.SetStorer(mockStorer)
ns := netstore.New(storer, recoveryFunc, retrieve, logger, nil) ns := netstore.New(storer, recoveryFunc, retrieve, logger, nil)
return ns return ns
} }
......
...@@ -6,6 +6,7 @@ package retrieval ...@@ -6,6 +6,7 @@ package retrieval
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
...@@ -49,11 +50,12 @@ type Service struct { ...@@ -49,11 +50,12 @@ type Service struct {
tracer *tracing.Tracer tracer *tracing.Tracer
} }
func New(addr swarm.Address, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, validator swarm.Validator, tracer *tracing.Tracer) *Service { func New(addr swarm.Address, storer storage.Storer, streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, validator swarm.Validator, tracer *tracing.Tracer) *Service {
return &Service{ return &Service{
addr: addr, addr: addr,
streamer: streamer, streamer: streamer,
peerSuggester: chunkPeerer, peerSuggester: chunkPeerer,
storer: storer,
logger: logger, logger: logger,
accounting: accounting, accounting: accounting,
pricer: pricer, pricer: pricer,
...@@ -258,9 +260,18 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -258,9 +260,18 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
defer span.Finish() defer span.Finish()
ctx = context.WithValue(ctx, requestSourceContextKey{}, p.Address.String()) ctx = context.WithValue(ctx, requestSourceContextKey{}, p.Address.String())
chunk, err := s.storer.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(req.Addr)) addr := swarm.NewAddress(req.Addr)
chunk, err := s.storer.Get(ctx, storage.ModeGetRequest, addr)
if err != nil { if err != nil {
return fmt.Errorf("get from store: %w peer %s", err, p.Address.String()) if errors.Is(err, storage.ErrNotFound) {
// forward the request
chunk, err = s.RetrieveChunk(ctx, addr)
if err != nil {
return fmt.Errorf("retrieve chunk: %w", err)
}
} else {
return fmt.Errorf("get from store: %w", err)
}
} }
if err := w.WriteMsgWithContext(ctx, &pb.Delivery{ if err := w.WriteMsgWithContext(ctx, &pb.Delivery{
...@@ -278,8 +289,3 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -278,8 +289,3 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
return nil return nil
} }
// SetStorer sets the storer. This call is not goroutine safe.
func (s *Service) SetStorer(storer storage.Storer) {
s.storer = storer
}
...@@ -51,8 +51,7 @@ func TestDelivery(t *testing.T) { ...@@ -51,8 +51,7 @@ func TestDelivery(t *testing.T) {
pricerMock := accountingmock.NewPricer(price, price) pricerMock := accountingmock.NewPricer(price, price)
// create the server that will handle the request and will serve the response // create the server that will handle the request and will serve the response
server := retrieval.New(swarm.MustParseHexAddress("00112234"), nil, nil, logger, serverMockAccounting, pricerMock, mockValidator, nil) server := retrieval.New(swarm.MustParseHexAddress("00112234"), mockStorer, nil, nil, logger, serverMockAccounting, pricerMock, mockValidator, nil)
server.SetStorer(mockStorer)
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
) )
...@@ -70,8 +69,7 @@ func TestDelivery(t *testing.T) { ...@@ -70,8 +69,7 @@ func TestDelivery(t *testing.T) {
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
return nil return nil
}} }}
client := retrieval.New(swarm.MustParseHexAddress("9ee7add8"), recorder, ps, logger, clientMockAccounting, pricerMock, mockValidator, nil) client := retrieval.New(swarm.MustParseHexAddress("9ee7add8"), clientMockStorer, recorder, ps, logger, clientMockAccounting, pricerMock, mockValidator, nil)
client.SetStorer(clientMockStorer)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel() defer cancel()
v, err := client.RetrieveChunk(ctx, reqAddr) v, err := client.RetrieveChunk(ctx, reqAddr)
...@@ -153,8 +151,7 @@ func TestRetrieveChunk(t *testing.T) { ...@@ -153,8 +151,7 @@ func TestRetrieveChunk(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
server := retrieval.New(serverAddress, nil, nil, logger, accountingmock.NewAccounting(), pricer, mockValidator, nil) server := retrieval.New(serverAddress, serverStorer, nil, nil, logger, accountingmock.NewAccounting(), pricer, mockValidator, nil)
server.SetStorer(serverStorer)
recorder := streamtest.New(streamtest.WithProtocols(server.Protocol())) recorder := streamtest.New(streamtest.WithProtocols(server.Protocol()))
...@@ -162,7 +159,71 @@ func TestRetrieveChunk(t *testing.T) { ...@@ -162,7 +159,71 @@ func TestRetrieveChunk(t *testing.T) {
_, _, _ = f(serverAddress, 0) _, _, _ = f(serverAddress, 0)
return nil return nil
}} }}
client := retrieval.New(clientAddress, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, mockValidator, nil) client := retrieval.New(clientAddress, nil, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, mockValidator, nil)
got, err := client.RetrieveChunk(context.Background(), chunkAddress)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got.Data(), chunk.Data()) {
t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
}
})
t.Run("forward", func(t *testing.T) {
chunkAddress := swarm.MustParseHexAddress("00")
serverAddress := swarm.MustParseHexAddress("01")
forwarderAddress := swarm.MustParseHexAddress("02")
clientAddress := swarm.MustParseHexAddress("03")
serverStorer := storemock.NewStorer()
chunk := swarm.NewChunk(chunkAddress, []byte("some data"))
_, err := serverStorer.Put(context.Background(), storage.ModePutUpload, chunk)
if err != nil {
t.Fatal(err)
}
server := retrieval.New(
serverAddress,
serverStorer, // chunk is in sever's store
nil,
nil,
logger,
accountingmock.NewAccounting(),
pricer,
mockValidator,
nil,
)
forwarder := retrieval.New(
forwarderAddress,
storemock.NewStorer(), // no chunk in forwarder's store
streamtest.New(streamtest.WithProtocols(server.Protocol())), // connect to server
mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(serverAddress, 0) // suggest server's address
return nil
}},
logger,
accountingmock.NewAccounting(),
pricer,
mockValidator,
nil,
)
client := retrieval.New(
clientAddress,
storemock.NewStorer(), // no chunk in clients's store
streamtest.New(streamtest.WithProtocols(forwarder.Protocol())), // connect to forwarder
mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(forwarderAddress, 0) // suggest forwarder's address
return nil
}},
logger,
accountingmock.NewAccounting(),
pricer,
mockValidator,
nil,
)
got, err := client.RetrieveChunk(context.Background(), chunkAddress) got, err := client.RetrieveChunk(context.Background(), chunkAddress)
if err != nil { if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment