Commit 18bdc869 authored by metacertain, committed by GitHub

feat: Retrial behavior in retrieval (#1780)

add forwarder reattempt behaviour on overdrafts
parent 165d1f67
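
In short: `RetrieveChunk` gains an `origin` flag. The node where a request originates (API or netstore) may now run up to five rounds of eight peer selections, clearing overdrafted peers from its skip list between rounds, while a forwarding node only retries past overdrafts and gives up after its first counted request failure. A rough sketch of that cadence, with hypothetical `tryPeers`/`resetOverdrafts` stand-ins for the real `retrieveChunk` and `skipPeers.Reset` (the real loop additionally runs attempts concurrently and separates overdrafted selections from counted failures):

```go
package sketch

import (
	"context"
	"errors"
	"time"
)

const (
	maxRequestRounds = 5 // selection rounds an origin node may run
	maxSelects       = 8 // peer selections per round
)

func retrieveCadence(ctx context.Context, origin bool,
	tryPeers func(n int) ([]byte, error), resetOverdrafts func()) ([]byte, error) {

	selects := 1 // a forwarder effectively gets a single counted attempt
	if origin {
		selects = maxSelects
	}
	lastReset := time.Now().Unix()
	for round := 0; round < maxRequestRounds; round++ {
		if chunk, err := tryPeers(selects); err == nil {
			return chunk, nil
		}
		if !origin {
			break // forwarders do not start new rounds
		}
		// clear overdrafted peers at most once per wall-clock second;
		// within the same second, back off for 600ms instead
		if now := time.Now().Unix(); now > lastReset {
			lastReset = now
			resetOverdrafts()
		} else {
			select {
			case <-time.After(600 * time.Millisecond):
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		}
	}
	return nil, errors.New("retrieval: not found")
}
```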
@@ -47,7 +47,7 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address
 	if err != nil {
 		if errors.Is(err, storage.ErrNotFound) {
 			// request from network
-			ch, err = s.retrieval.RetrieveChunk(ctx, addr)
+			ch, err = s.retrieval.RetrieveChunk(ctx, addr, true)
 			if err != nil {
 				targets := sctx.GetTargets(ctx)
 				if targets == nil || s.recoveryCallback == nil {
@@ -200,7 +200,7 @@ type retrievalMock struct {
 	addr    swarm.Address
 }
 
-func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
+func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address, orig bool) (chunk swarm.Chunk, err error) {
 	if r.failure {
 		return nil, fmt.Errorf("chunk not found")
 	}
@@ -219,7 +219,7 @@ func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets pss.Targets
 	mr.callbackC <- true
 }
 
-func (r *mockRecovery) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
+func (r *mockRecovery) RetrieveChunk(ctx context.Context, addr swarm.Address, orig bool) (chunk swarm.Chunk, err error) {
 	return nil, fmt.Errorf("chunk not found")
 }
@@ -227,9 +227,16 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.Callback) storage.Store
 		_, _, _ = f(peerID, 0)
 		return nil
 	}}
-	server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps, logger, serverMockAccounting, pricerMock, nil)
+	ps0 := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
+		// not calling peer iterator on server as it would cause dereference of non existing streamer
+		return nil
+	}}
+
+	server := retrieval.New(swarm.ZeroAddress, mockStorer, nil, ps0, logger, serverMockAccounting, pricerMock, nil)
 	recorder := streamtest.New(
 		streamtest.WithProtocols(server.Protocol()),
 		streamtest.WithBaseAddr(peerID),
 	)
 	retrieve := retrieval.New(swarm.ZeroAddress, mockStorer, recorder, ps, logger, serverMockAccounting, pricerMock, nil)
 	validStamp := func(ch swarm.Chunk, stamp []byte) (swarm.Chunk, error) {
@@ -29,7 +29,7 @@ import (
 	"github.com/ethersphere/bee/pkg/topology"
 	"github.com/ethersphere/bee/pkg/tracing"
 	"github.com/opentracing/opentracing-go"
-	"golang.org/x/sync/singleflight"
+	"resenje.org/singleflight"
 )
 
 type requestSourceContextKey struct{}
@@ -43,7 +43,14 @@ const (
 var _ Interface = (*Service)(nil)
 
 type Interface interface {
-	RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error)
+	RetrieveChunk(ctx context.Context, addr swarm.Address, origin bool) (chunk swarm.Chunk, err error)
 }
 
+type retrievalResult struct {
+	chunk     swarm.Chunk
+	peer      swarm.Address
+	err       error
+	retrieved bool
+}
+
 type Service struct {
@@ -87,18 +94,29 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
 }
 
 const (
 	maxPeers = 5
 	retrieveChunkTimeout          = 10 * time.Second
 	retrieveRetryIntervalDuration = 5 * time.Second
+	maxRequestRounds              = 5
+	maxSelects                    = 8
+	originSuffix                  = "_origin"
 )
 
-func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
+func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address, origin bool) (swarm.Chunk, error) {
 	s.metrics.RequestCounter.Inc()
 
-	v, err, _ := s.singleflight.Do(addr.String(), func() (interface{}, error) {
-		span, logger, ctx := s.tracer.StartSpanFromContext(ctx, "retrieve-chunk", s.logger, opentracing.Tag{Key: "address", Value: addr.String()})
-		defer span.Finish()
+	flightRoute := addr.String()
+	if origin {
+		flightRoute = addr.String() + originSuffix
+	}
+
+	// topCtx is passing the tracing span to the first singleflight call
+	topCtx := ctx
+	v, _, err := s.singleflight.Do(ctx, flightRoute, func(ctx context.Context) (interface{}, error) {
+		maxPeers := 1
+		if origin {
+			maxPeers = maxSelects
+		}
 
 		sp := newSkipPeers()
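
The new `Do` signature above comes from the swapped import: `golang.org/x/sync/singleflight` is not context-aware, whereas `resenje.org/singleflight` hands the callback its own context, kept alive while at least one caller still waits, and returns `(value, shared, err)`. The `_origin` suffix in `flightRoute` additionally keeps origin and forwarded lookups of the same chunk in separate flights, since their retry budgets differ. A minimal usage sketch against the non-generic API this commit uses (later releases of the package are generic):

```go
package main

import (
	"context"
	"fmt"

	"resenje.org/singleflight"
)

func main() {
	var g singleflight.Group

	// Two different keys for the same chunk address, mirroring flightRoute:
	// an origin retrieval never piggybacks on a forwarded flight and vice versa.
	for _, key := range []string{"ca1e", "ca1e_origin"} {
		v, shared, err := g.Do(context.Background(), key,
			func(ctx context.Context) (interface{}, error) {
				// ctx is canceled only once every waiting caller is gone,
				// so one impatient caller cannot abort the shared fetch.
				return "chunk for " + key, nil
			})
		fmt.Println(v, shared, err)
	}
}
```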
@@ -108,51 +126,97 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.
 		var (
 			peerAttempt  int
 			peersResults int
-			resultC      = make(chan swarm.Chunk, maxPeers)
-			errC         = make(chan error, maxPeers)
+			resultC      = make(chan retrievalResult, maxSelects)
 		)
 
-		for {
-			if peerAttempt < maxPeers {
-				peerAttempt++
-				s.metrics.PeerRequestCounter.Inc()
+		requestAttempt := 0
+		lastTime := time.Now().Unix()
+
+		for requestAttempt < maxRequestRounds {
+			if peerAttempt < maxSelects {
+				// create a new context without cancelation but
+				// set the tracing span to the new context from the context of the first caller
+				ctx := tracing.WithContext(context.Background(), tracing.FromContext(topCtx))
+
+				// get the tracing span
+				span, _, ctx := s.tracer.StartSpanFromContext(ctx, "retrieve-chunk", s.logger, opentracing.Tag{Key: "address", Value: addr.String()})
+				defer span.Finish()
+
+				peerAttempt++
+				s.metrics.PeerRequestCounter.Inc()
 
 				go func() {
-					chunk, peer, err := s.retrieveChunk(ctx, addr, sp)
-					if err != nil {
-						if !peer.IsZero() {
-							logger.Debugf("retrieval: failed to get chunk %s from peer %s: %v", addr, peer, err)
-						}
-						errC <- err
-						return
-					}
+					// cancel the goroutine just with the timeout
+					ctx, cancel := context.WithTimeout(ctx, retrieveChunkTimeout)
+					defer cancel()
 
-					resultC <- chunk
+					chunk, peer, requested, err := s.retrieveChunk(ctx, addr, sp)
+					resultC <- retrievalResult{
+						chunk:     chunk,
+						peer:      peer,
+						err:       err,
+						retrieved: requested,
+					}
 				}()
 			} else {
 				ticker.Stop()
+				resultC <- retrievalResult{}
 			}
 
 			select {
 			case <-ticker.C:
 				// break
-			case chunk := <-resultC:
-				return chunk, nil
-			case <-errC:
-				peersResults++
+			case res := <-resultC:
+				if res.retrieved {
+					if res.err != nil {
+						if !res.peer.IsZero() {
+							s.logger.Debugf("retrieval: failed to get chunk %s from peer %s: %v", addr, res.peer, res.err)
+						}
+						peersResults++
+					} else {
+						return res.chunk, nil
+					}
+				}
 			case <-ctx.Done():
-				logger.Tracef("retrieval: failed to get chunk %s: %v", addr, ctx.Err())
+				s.logger.Tracef("retrieval: failed to get chunk %s: %v", addr, ctx.Err())
 				return nil, fmt.Errorf("retrieval: %w", ctx.Err())
 			}
 
-			// all results received
+			// all results received, only successfully attempted requests are counted
 			if peersResults >= maxPeers {
-				logger.Tracef("retrieval: failed to get chunk %s", addr)
+				s.logger.Tracef("retrieval: failed to get chunk %s", addr)
 				return nil, storage.ErrNotFound
 			}
+
+			// if we have not counted enough successful attempts but out of selection amount, reset
+			if peerAttempt >= maxSelects {
+				if !origin {
+					return nil, storage.ErrNotFound
+				}
+
+				requestAttempt++
+				timeNow := time.Now().Unix()
+				if timeNow > lastTime {
+					lastTime = timeNow
+					peerAttempt = 0
+					sp.Reset()
+				} else {
+					select {
+					case <-time.After(600 * time.Millisecond):
+					case <-ctx.Done():
+						s.logger.Tracef("retrieval: failed to get chunk %s: %v", addr, ctx.Err())
+						return nil, fmt.Errorf("retrieval: %w", ctx.Err())
+					}
+				}
+			}
 		}
+
+		// if we have not managed to get results after 5 (maxRequestRounds) rounds of peer selections, give up
+		return nil, storage.ErrNotFound
 	})
 	if err != nil {
 		return nil, err
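
One detail of the select above that is easy to miss: a result only consumes the failure budget (`peersResults`) when the attempt actually got as far as requesting the chunk. A hypothetical helper naming that rule, with the struct from this diff repeated for context:

```go
package sketch

// retrievalResult mirrors the struct introduced in this diff.
type retrievalResult struct {
	err       error
	retrieved bool
}

// countsAsFailure states the rule the select applies inline: only attempts
// that reached the point of requesting the chunk (retrieved == true) count
// toward peersResults; overdraft refusals and dial failures are free retries.
func countsAsFailure(res retrievalResult) bool {
	return res.retrieved && res.err != nil
}
```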
@@ -161,9 +225,8 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.
 	return v.(swarm.Chunk), nil
 }
 
-func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *skipPeers) (chunk swarm.Chunk, peer swarm.Address, err error) {
+func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *skipPeers) (chunk swarm.Chunk, peer swarm.Address, requested bool, err error) {
 	startTimer := time.Now()
-
 	v := ctx.Value(requestSourceContextKey{})
 	sourcePeerAddr := swarm.Address{}
 	// allow upstream requests if this node is the source of the request
@@ -184,7 +247,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
 	defer cancel()
 
 	peer, err = s.closestPeer(addr, sp.All(), allowUpstream)
 	if err != nil {
-		return nil, peer, fmt.Errorf("get closest for address %s, allow upstream %v: %w", addr.String(), allowUpstream, err)
+		return nil, peer, false, fmt.Errorf("get closest for address %s, allow upstream %v: %w", addr.String(), allowUpstream, err)
 	}
 
 	peerPO := swarm.Proximity(s.addr.Bytes(), peer.Bytes())
@@ -201,16 +264,25 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
 			Inc()
 	}
 
-	sp.Add(peer)
-
 	// compute the peer's price for this chunk for price header
 	chunkPrice := s.pricer.PeerPrice(peer, addr)
 
+	// Reserve to see whether we can request the chunk
+	err = s.accounting.Reserve(ctx, peer, chunkPrice)
+	if err != nil {
+		sp.AddOverdraft(peer)
+		return nil, peer, false, err
+	}
+	defer s.accounting.Release(peer, chunkPrice)
+
+	sp.Add(peer)
+
 	s.logger.Tracef("retrieval: requesting chunk %s from peer %s", addr, peer)
 	stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
 	if err != nil {
 		s.metrics.TotalErrors.Inc()
-		return nil, peer, fmt.Errorf("new stream: %w", err)
+		return nil, peer, false, fmt.Errorf("new stream: %w", err)
 	}
 
 	defer func() {
@@ -221,25 +293,18 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
 		}
 	}()
 
-	// Reserve to see whether we can request the chunk
-	err = s.accounting.Reserve(ctx, peer, chunkPrice)
-	if err != nil {
-		return nil, peer, err
-	}
-	defer s.accounting.Release(peer, chunkPrice)
-
 	w, r := protobuf.NewWriterAndReader(stream)
 	if err := w.WriteMsgWithContext(ctx, &pb.Request{
 		Addr: addr.Bytes(),
 	}); err != nil {
 		s.metrics.TotalErrors.Inc()
-		return nil, peer, fmt.Errorf("write request: %w peer %s", err, peer.String())
+		return nil, peer, false, fmt.Errorf("write request: %w peer %s", err, peer.String())
 	}
 
 	var d pb.Delivery
 	if err := r.ReadMsgWithContext(ctx, &d); err != nil {
 		s.metrics.TotalErrors.Inc()
-		return nil, peer, fmt.Errorf("read delivery: %w peer %s", err, peer.String())
+		return nil, peer, true, fmt.Errorf("read delivery: %w peer %s", err, peer.String())
 	}
 
 	s.metrics.RetrieveChunkPeerPOTimer.
 		WithLabelValues(strconv.Itoa(int(peerPO))).
@@ -249,25 +314,24 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *ski
 	stamp := new(postage.Stamp)
 	err = stamp.UnmarshalBinary(d.Stamp)
 	if err != nil {
-		return nil, peer, fmt.Errorf("stamp unmarshal: %w", err)
+		return nil, peer, true, fmt.Errorf("stamp unmarshal: %w", err)
 	}
 	chunk = swarm.NewChunk(addr, d.Data).WithStamp(stamp)
 	if !cac.Valid(chunk) {
 		if !soc.Valid(chunk) {
 			s.metrics.InvalidChunkRetrieved.Inc()
 			s.metrics.TotalErrors.Inc()
-			return nil, peer, swarm.ErrInvalidChunk
+			return nil, peer, true, swarm.ErrInvalidChunk
 		}
 	}
 
 	// credit the peer after successful delivery
 	err = s.accounting.Credit(peer, chunkPrice)
 	if err != nil {
-		return nil, peer, err
+		return nil, peer, true, err
 	}
 	s.metrics.ChunkPrice.Observe(float64(chunkPrice))
-
-	return chunk, peer, err
+	return chunk, peer, true, err
 }
 
 // closestPeer returns address of the peer that is closest to the chunk with
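
Following the new fourth return value through the hunks above: every exit up to and including the request write reports `requested == false`, and every exit from the delivery read onward reports `true`. A small sketch of that mapping (the `stage` labels are invented for illustration):

```go
package sketch

// stage labels the exit points of retrieveChunk in the hunks above.
type stage int

const (
	stageClosestPeer  stage = iota // no suitable peer found
	stageReserve                   // overdraft at accounting.Reserve
	stageNewStream                 // could not open a stream to the peer
	stageWriteRequest              // request never left this node
	stageReadDelivery              // peer was asked but gave no answer
	stageValidate                  // delivery received but stamp/chunk invalid
	stageCredit                    // accounting failure after delivery
)

// requestedAt reports the bool retrieveChunk returns from each exit point:
// false means the peer was never really asked, so the caller's loop may
// retry without burning its failure budget.
func requestedAt(s stage) bool {
	return s >= stageReadDelivery
}
```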
@@ -349,7 +413,7 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
 	if err != nil {
 		if errors.Is(err, storage.ErrNotFound) {
 			// forward the request
-			chunk, err = s.RetrieveChunk(ctx, addr)
+			chunk, err = s.RetrieveChunk(ctx, addr, false)
 			if err != nil {
 				return fmt.Errorf("retrieve chunk: %w", err)
 			}
@@ -375,7 +439,6 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
 	}
 
 	s.logger.Tracef("retrieval protocol debiting peer %s", p.Address.String())
-
 	// debit price from p's balance
 	return debit.Apply()
 }
@@ -81,7 +81,7 @@ func TestDelivery(t *testing.T) {
 	client := retrieval.New(clientAddr, clientMockStorer, recorder, ps, logger, clientMockAccounting, pricerMock, nil)
 	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
 	defer cancel()
-	v, err := client.RetrieveChunk(ctx, chunk.Address())
+	v, err := client.RetrieveChunk(ctx, chunk.Address(), true)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -176,7 +176,7 @@ func TestRetrieveChunk(t *testing.T) {
 	}}
 	client := retrieval.New(clientAddress, nil, recorder, clientSuggester, logger, accountingmock.NewAccounting(), pricer, nil)
 
-	got, err := client.RetrieveChunk(context.Background(), chunk.Address())
+	got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -237,7 +237,7 @@ func TestRetrieveChunk(t *testing.T) {
 		nil,
 	)
 
-	got, err := client.RetrieveChunk(context.Background(), chunk.Address())
+	got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -332,7 +332,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 
 		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil)
 
-		got, err := client.RetrieveChunk(context.Background(), chunk.Address())
+		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -368,7 +368,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 
 		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil)
 
-		got, err := client.RetrieveChunk(context.Background(), chunk.Address())
+		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -432,7 +432,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 
 		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, clientMockAccounting, pricerMock, nil)
 
-		got, err := client.RetrieveChunk(context.Background(), chunk.Address())
+		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -484,7 +484,7 @@ func TestRetrievePreemptiveRetry(t *testing.T) {
 		// client only knows about server 1
 		client := retrieval.New(clientAddress, nil, clientRecorder, peerSuggesterFn(serverAddress1), logger, accountingmock.NewAccounting(), pricerMock, nil)
 
-		got, err := client.RetrieveChunk(context.Background(), chunk.Address())
+		got, err := client.RetrieveChunk(context.Background(), chunk.Address(), true)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -11,6 +11,7 @@ import (
 )
 
 type skipPeers struct {
+	overdraftAddresses []swarm.Address
 	addresses          []swarm.Address
 	mu                 sync.Mutex
 }
@@ -23,7 +24,13 @@ func (s *skipPeers) All() []swarm.Address {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	return append(s.addresses[:0:0], s.addresses...)
+	return append(append(s.addresses[:0:0], s.addresses...), s.overdraftAddresses...)
 }
 
+func (s *skipPeers) Reset() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.overdraftAddresses = []swarm.Address{}
+}
+
 func (s *skipPeers) Add(address swarm.Address) {
@@ -38,3 +45,16 @@ func (s *skipPeers) Add(address swarm.Address) {
 
 	s.addresses = append(s.addresses, address)
 }
+
+func (s *skipPeers) AddOverdraft(address swarm.Address) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	for _, a := range s.overdraftAddresses {
+		if a.Equal(address) {
+			return
+		}
+	}
+
+	s.overdraftAddresses = append(s.overdraftAddresses, address)
+}
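
Taken together, the skip-list changes give `skipPeers` two tiers: `Add` is a hard skip for the rest of the retrieval, `AddOverdraft` a soft skip that `Reset` forgives between rounds, and `All` reports both. A self-contained analogue with plain string addresses (the real type uses `swarm.Address`):

```go
package main

import (
	"fmt"
	"sync"
)

// skiplist mirrors the two-tier skipPeers above.
type skiplist struct {
	mu         sync.Mutex
	addresses  []string // hard skips: tried and failed
	overdrafts []string // soft skips: no credit right now
}

func (s *skiplist) Add(a string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.addresses = append(s.addresses, a)
}

func (s *skiplist) AddOverdraft(a string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.overdrafts = append(s.overdrafts, a)
}

// All reports hard-skipped and overdrafted peers alike, as the changed
// skipPeers.All does.
func (s *skiplist) All() []string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return append(append([]string{}, s.addresses...), s.overdrafts...)
}

// Reset forgives only the overdrafted peers, so the next selection round
// may try them again.
func (s *skiplist) Reset() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.overdrafts = nil
}

func main() {
	var sp skiplist
	sp.Add("peerA")          // failed delivery: skip for good
	sp.AddOverdraft("peerB") // overdraft: skip only for this round
	fmt.Println(sp.All())    // [peerA peerB]
	sp.Reset()
	fmt.Println(sp.All())    // [peerA]
}
```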