Commit 1a8858ee authored by Nemanja Zbiljić's avatar Nemanja Zbiljić Committed by GitHub

Perform preemptive retry for chunk retrieval (#1096)

parent d1dcc1dd
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package retrieval
import (
"context"
"github.com/ethersphere/bee/pkg/p2p"
)
// Handler is the exported entry point for the retrieval protocol stream
// handler. It delegates directly to the unexported handler so that tests
// and protocol wiring outside this package can invoke it.
func (s *Service) Handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) error {
	err := s.handler(ctx, p, stream)
	return err
}
...@@ -83,46 +83,70 @@ func (s *Service) Protocol() p2p.ProtocolSpec { ...@@ -83,46 +83,70 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
const ( const (
maxPeers = 5 maxPeers = 5
retrieveChunkTimeout = 10 * time.Second retrieveChunkTimeout = 10 * time.Second
retrieveRetryIntervalDuration = 5 * time.Second
) )
func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) { func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
ctx, cancel := context.WithTimeout(ctx, maxPeers*retrieveChunkTimeout)
defer cancel()
s.metrics.RequestCounter.Inc() s.metrics.RequestCounter.Inc()
v, err, _ := s.singleflight.Do(addr.String(), func() (interface{}, error) { v, err, _ := s.singleflight.Do(addr.String(), func() (interface{}, error) {
span, logger, ctx := s.tracer.StartSpanFromContext(ctx, "retrieve-chunk", s.logger, opentracing.Tag{Key: "address", Value: addr.String()}) span, logger, ctx := s.tracer.StartSpanFromContext(ctx, "retrieve-chunk", s.logger, opentracing.Tag{Key: "address", Value: addr.String()})
defer span.Finish() defer span.Finish()
var skipPeers []swarm.Address sp := newSkipPeers()
ticker := time.NewTicker(retrieveRetryIntervalDuration)
defer ticker.Stop()
var (
peerAttempt int
peersResults int
resultC = make(chan swarm.Chunk, maxPeers)
errC = make(chan error, maxPeers)
)
for {
if peerAttempt < maxPeers {
peerAttempt++
s.metrics.PeerRequestCounter.Inc()
go func() {
chunk, peer, err := s.retrieveChunk(ctx, addr, sp)
if err != nil {
if !peer.IsZero() {
logger.Debugf("retrieval: failed to get chunk %s from peer %s: %v", addr, peer, err)
}
errC <- err
return
}
resultC <- chunk
}()
} else {
ticker.Stop()
}
LOOP:
for i := 0; i < maxPeers; i++ {
select { select {
case <-ticker.C:
// break
case chunk := <-resultC:
return chunk, nil
case <-errC:
peersResults++
case <-ctx.Done(): case <-ctx.Done():
break LOOP logger.Tracef("retrieval: failed to get chunk %s: %v", addr, ctx.Err())
default: return nil, fmt.Errorf("retrieval: %w", ctx.Err())
} }
s.metrics.PeerRequestCounter.Inc() // all results received
if peersResults >= maxPeers {
var peer swarm.Address logger.Tracef("retrieval: failed to get chunk %s", addr)
return nil, storage.ErrNotFound
chunk, peer, err := s.retrieveChunk(ctx, addr, skipPeers)
if err != nil {
if peer.IsZero() {
return nil, err
}
logger.Debugf("retrieval: failed to get chunk %s from peer %s: %v", addr, peer, err)
skipPeers = append(skipPeers, peer)
continue
} }
logger.Tracef("retrieval: got chunk %s from peer %s", addr, peer)
return chunk, nil
} }
logger.Tracef("retrieval: failed to get chunk %s: reached max peers of %v", addr, maxPeers)
return nil, storage.ErrNotFound
}) })
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -131,7 +155,7 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm. ...@@ -131,7 +155,7 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.
return v.(swarm.Chunk), nil return v.(swarm.Chunk), nil
} }
func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPeers []swarm.Address) (chunk swarm.Chunk, peer swarm.Address, err error) { func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, sp *skipPeers) (chunk swarm.Chunk, peer swarm.Address, err error) {
startTimer := time.Now() startTimer := time.Now()
v := ctx.Value(requestSourceContextKey{}) v := ctx.Value(requestSourceContextKey{})
...@@ -143,7 +167,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -143,7 +167,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
if src, ok := v.(string); ok { if src, ok := v.(string); ok {
sourcePeerAddr, err = swarm.ParseHexAddress(src) sourcePeerAddr, err = swarm.ParseHexAddress(src)
if err == nil { if err == nil {
skipPeers = append(skipPeers, sourcePeerAddr) sp.Add(sourcePeerAddr)
} }
// do not allow upstream requests if the request was forwarded to this node // do not allow upstream requests if the request was forwarded to this node
// to avoid the request loops // to avoid the request loops
...@@ -152,8 +176,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -152,8 +176,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
ctx, cancel := context.WithTimeout(ctx, retrieveChunkTimeout) ctx, cancel := context.WithTimeout(ctx, retrieveChunkTimeout)
defer cancel() defer cancel()
peer, err = s.closestPeer(addr, sp.All(), allowUpstream)
peer, err = s.closestPeer(addr, skipPeers, allowUpstream)
if err != nil { if err != nil {
return nil, peer, fmt.Errorf("get closest for address %s, allow upstream %v: %w", addr.String(), allowUpstream, err) return nil, peer, fmt.Errorf("get closest for address %s, allow upstream %v: %w", addr.String(), allowUpstream, err)
} }
...@@ -172,6 +195,8 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -172,6 +195,8 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
Inc() Inc()
} }
sp.Add(peer)
// compute the price we pay for this chunk and reserve it for the rest of this function // compute the price we pay for this chunk and reserve it for the rest of this function
chunkPrice := s.pricer.PeerPrice(peer, addr) chunkPrice := s.pricer.PeerPrice(peer, addr)
err = s.accounting.Reserve(ctx, peer, chunkPrice) err = s.accounting.Reserve(ctx, peer, chunkPrice)
......
...@@ -9,12 +9,15 @@ import ( ...@@ -9,12 +9,15 @@ import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"io/ioutil"
"os" "os"
"testing" "testing"
"time" "time"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock" accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
...@@ -224,6 +227,250 @@ func TestRetrieveChunk(t *testing.T) { ...@@ -224,6 +227,250 @@ func TestRetrieveChunk(t *testing.T) {
}) })
} }
// TestRetrievePreemptiveRetry verifies that RetrieveChunk preemptively retries
// additional peers: when the first suggested peer errors, does not have the
// chunk, or responds slower than the retry interval, the chunk is still
// retrieved from another peer. The "one peer is slower" subtest additionally
// checks the client's accounting: only the peer that delivered is debited at
// first, and the slow peer is debited once its (late) response completes.
func TestRetrievePreemptiveRetry(t *testing.T) {
	logger := logging.New(ioutil.Discard, 0)

	chunk := testingc.FixtureChunk("0025")
	someOtherChunk := testingc.FixtureChunk("0033")

	price := uint64(1)
	pricerMock := accountingmock.NewPricer(price, price)

	clientAddress := swarm.MustParseHexAddress("1010")

	serverAddress1 := swarm.MustParseHexAddress("1000000000000000000000000000000000000000000000000000000000000000")
	serverAddress2 := swarm.MustParseHexAddress("0200000000000000000000000000000000000000000000000000000000000000")
	peers := []swarm.Address{
		serverAddress1,
		serverAddress2,
	}

	serverStorer1 := storemock.NewStorer()
	serverStorer2 := storemock.NewStorer()

	// we put some other chunk on server 1
	_, err := serverStorer1.Put(context.Background(), storage.ModePutUpload, someOtherChunk)
	if err != nil {
		t.Fatal(err)
	}
	// we put chunk we need on server 2
	_, err = serverStorer2.Put(context.Background(), storage.ModePutUpload, chunk)
	if err != nil {
		t.Fatal(err)
	}

	// noPeerSuggester yields only the zero address, so a node using it must
	// answer from its own store and never forwards upstream.
	noPeerSuggester := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
		_, _, _ = f(swarm.ZeroAddress, 0)
		return nil
	}}

	// peerSuggesterFn round-robins over the given peers, one suggestion per
	// topology iteration, so each successive retrieval attempt targets the
	// next peer in the list.
	peerSuggesterFn := func(peers ...swarm.Address) topology.EachPeerer {
		if len(peers) == 0 {
			return noPeerSuggester
		}

		peerID := 0
		peerSuggester := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
			_, _, _ = f(peers[peerID], 0)
			// circulate suggested peers
			peerID++
			if peerID >= len(peers) {
				peerID = 0
			}
			return nil
		}}
		return peerSuggester
	}

	server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)
	server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)

	t.Run("peer not reachable", func(t *testing.T) {
		recorder := streamtest.New(
			streamtest.WithProtocols(
				server1.Protocol(),
				server2.Protocol(),
			),
			streamtest.WithMiddlewares(
				func(h p2p.HandlerFunc) p2p.HandlerFunc {
					return func(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
						// NOTE: return error for peer1
						if serverAddress1.Equal(peer.Address) {
							return fmt.Errorf("peer not reachable: %s", peer.Address.String())
						}

						if serverAddress2.Equal(peer.Address) {
							return server2.Handler(ctx, peer, stream)
						}

						return fmt.Errorf("unknown peer: %s", peer.Address.String())
					}
				},
			),
		)

		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil)

		got, err := client.RetrieveChunk(context.Background(), chunk.Address())
		if err != nil {
			t.Fatal(err)
		}

		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}
	})

	t.Run("peer does not have chunk", func(t *testing.T) {
		recorder := streamtest.New(
			streamtest.WithProtocols(
				server1.Protocol(),
				server2.Protocol(),
			),
			streamtest.WithMiddlewares(
				func(h p2p.HandlerFunc) p2p.HandlerFunc {
					return func(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
						// server 1 responds but only has someOtherChunk, so
						// the retry against server 2 must deliver the chunk.
						if serverAddress1.Equal(peer.Address) {
							return server1.Handler(ctx, peer, stream)
						}

						if serverAddress2.Equal(peer.Address) {
							return server2.Handler(ctx, peer, stream)
						}

						return fmt.Errorf("unknown peer: %s", peer.Address.String())
					}
				},
			),
		)

		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, accountingmock.NewAccounting(), pricerMock, nil)

		got, err := client.RetrieveChunk(context.Background(), chunk.Address())
		if err != nil {
			t.Fatal(err)
		}

		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}
	})

	t.Run("one peer is slower", func(t *testing.T) {
		serverStorer1 := storemock.NewStorer()
		serverStorer2 := storemock.NewStorer()

		// both peers have required chunk
		_, err := serverStorer1.Put(context.Background(), storage.ModePutUpload, chunk)
		if err != nil {
			t.Fatal(err)
		}
		_, err = serverStorer2.Put(context.Background(), storage.ModePutUpload, chunk)
		if err != nil {
			t.Fatal(err)
		}

		server1MockAccounting := accountingmock.NewAccounting()
		server2MockAccounting := accountingmock.NewAccounting()

		server1 := retrieval.New(serverAddress1, serverStorer1, nil, noPeerSuggester, logger, server1MockAccounting, pricerMock, nil)
		server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, server2MockAccounting, pricerMock, nil)

		// NOTE: must be more than retry duration
		// (here one second more)
		server1ResponseDelayDuration := 6 * time.Second

		recorder := streamtest.New(
			streamtest.WithProtocols(
				server1.Protocol(),
				server2.Protocol(),
			),
			streamtest.WithMiddlewares(
				func(h p2p.HandlerFunc) p2p.HandlerFunc {
					return func(ctx context.Context, peer p2p.Peer, stream p2p.Stream) error {
						if serverAddress1.Equal(peer.Address) {
							// NOTE: sleep time must be more than retry duration
							time.Sleep(server1ResponseDelayDuration)
							return server1.Handler(ctx, peer, stream)
						}

						if serverAddress2.Equal(peer.Address) {
							return server2.Handler(ctx, peer, stream)
						}

						return fmt.Errorf("unknown peer: %s", peer.Address.String())
					}
				},
			),
		)

		clientMockAccounting := accountingmock.NewAccounting()

		client := retrieval.New(clientAddress, nil, recorder, peerSuggesterFn(peers...), logger, clientMockAccounting, pricerMock, nil)

		got, err := client.RetrieveChunk(context.Background(), chunk.Address())
		if err != nil {
			t.Fatal(err)
		}

		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}

		// the slow peer (server 1) has not answered yet, so the client must
		// not have been debited for it at this point.
		clientServer1Balance, _ := clientMockAccounting.Balance(serverAddress1)
		if clientServer1Balance != 0 {
			// BUGFIX: previously printed -price as the wanted value although
			// the assertion expects 0 (and -price on a uint64 wraps to 2^64-1).
			t.Fatalf("unexpected balance on client. want %d got %d", 0, clientServer1Balance)
		}

		clientServer2Balance, _ := clientMockAccounting.Balance(serverAddress2)
		if clientServer2Balance != -int64(price) {
			// BUGFIX: format the wanted value as -int64(price); -price on a
			// uint64 wraps around and printed 18446744073709551615.
			t.Fatalf("unexpected balance on client. want %d got %d", -int64(price), clientServer2Balance)
		}

		// wait and check balance again
		// (yet one second more than before, minus original duration)
		time.Sleep(2 * time.Second)

		clientServer1Balance, _ = clientMockAccounting.Balance(serverAddress1)
		if clientServer1Balance != -int64(price) {
			t.Fatalf("unexpected balance on client. want %d got %d", -int64(price), clientServer1Balance)
		}

		clientServer2Balance, _ = clientMockAccounting.Balance(serverAddress2)
		if clientServer2Balance != -int64(price) {
			t.Fatalf("unexpected balance on client. want %d got %d", -int64(price), clientServer2Balance)
		}
	})

	t.Run("peer forwards request", func(t *testing.T) {
		// server 2 has the chunk
		server2 := retrieval.New(serverAddress2, serverStorer2, nil, noPeerSuggester, logger, accountingmock.NewAccounting(), pricerMock, nil)

		server1Recorder := streamtest.New(
			streamtest.WithProtocols(server2.Protocol()),
		)

		// server 1 will forward request to server 2
		server1 := retrieval.New(serverAddress1, serverStorer1, server1Recorder, peerSuggesterFn(serverAddress2), logger, accountingmock.NewAccounting(), pricerMock, nil)

		clientRecorder := streamtest.New(
			streamtest.WithProtocols(server1.Protocol()),
		)

		// client only knows about server 1
		client := retrieval.New(clientAddress, nil, clientRecorder, peerSuggesterFn(serverAddress1), logger, accountingmock.NewAccounting(), pricerMock, nil)

		got, err := client.RetrieveChunk(context.Background(), chunk.Address())
		if err != nil {
			t.Fatal(err)
		}

		if !bytes.Equal(got.Data(), chunk.Data()) {
			t.Fatalf("got data %x, want %x", got.Data(), chunk.Data())
		}
	})
}
type mockPeerSuggester struct { type mockPeerSuggester struct {
eachPeerRevFunc func(f topology.EachPeerFunc) error eachPeerRevFunc func(f topology.EachPeerFunc) error
} }
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package retrieval
import (
"sync"
"github.com/ethersphere/bee/pkg/swarm"
)
// skipPeers is a concurrency-safe set of peer addresses that should be
// skipped by subsequent retrieval attempts for a chunk (peers already tried,
// or the peer the request was forwarded from).
type skipPeers struct {
	// addresses holds the unique addresses to skip; duplicates are
	// rejected by Add.
	addresses []swarm.Address
	// mu guards addresses, since the set is shared by the concurrent
	// retrieval attempts spawned for one chunk.
	mu sync.Mutex
}
func newSkipPeers() *skipPeers {
return &skipPeers{}
}
// All returns a snapshot copy of the skip list so callers can iterate it
// without holding the lock and without observing concurrent mutation.
func (s *skipPeers) All() []swarm.Address {
	s.mu.Lock()
	defer s.mu.Unlock()

	// appending to a nil slice copies the elements into fresh backing
	// storage (and stays nil when the list is empty, like the original).
	return append([]swarm.Address(nil), s.addresses...)
}
// Add inserts address into the skip list; an address already present is
// left alone, keeping the list duplicate-free.
func (s *skipPeers) Add(address swarm.Address) {
	s.mu.Lock()
	defer s.mu.Unlock()

	for i := range s.addresses {
		if s.addresses[i].Equal(address) {
			return
		}
	}

	s.addresses = append(s.addresses, address)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment