Commit c9bb6556 authored by Janoš Guljaš's avatar Janoš Guljaš Committed by GitHub

blocklist peer on retrieval timeout (#685)

parent d0cd9b4a
...@@ -238,7 +238,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -238,7 +238,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
chunkvalidator := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator()) chunkvalidator := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator())
retrieve := retrieval.New(p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator) retrieve := retrieval.New(p2ps, storer, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator)
tagg := tags.NewTags(stateStore, logger) tagg := tags.NewTags(stateStore, logger)
b.tagsCloser = tagg b.tagsCloser = tagg
...@@ -257,7 +257,6 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -257,7 +257,6 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
} else { } else {
ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator) ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator)
} }
retrieve.SetStorer(ns)
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10)) pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10))
......
...@@ -20,13 +20,17 @@ type Service interface { ...@@ -20,13 +20,17 @@ type Service interface {
AddProtocol(ProtocolSpec) error AddProtocol(ProtocolSpec) error
// Connect to a peer but do not notify topology about the established connection. // Connect to a peer but do not notify topology about the established connection.
Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.Address, err error) Connect(ctx context.Context, addr ma.Multiaddr) (address *bzz.Address, err error)
Disconnecter
Peers() []Peer
AddNotifier(topology.Notifier)
Addresses() ([]ma.Multiaddr, error)
}
type Disconnecter interface {
Disconnect(overlay swarm.Address) error Disconnect(overlay swarm.Address) error
// Blocklist will disconnect a peer and put it on a blocklist (blocking in & out connections) for provided duration // Blocklist will disconnect a peer and put it on a blocklist (blocking in & out connections) for provided duration
// duration 0 is treated as an infinite duration // duration 0 is treated as an infinite duration
Blocklist(overlay swarm.Address, duration time.Duration) error Blocklist(overlay swarm.Address, duration time.Duration) error
Peers() []Peer
AddNotifier(topology.Notifier)
Addresses() ([]ma.Multiaddr, error)
} }
// DebugService extends the Service with method used for debugging. // DebugService extends the Service with method used for debugging.
...@@ -41,6 +45,11 @@ type Streamer interface { ...@@ -41,6 +45,11 @@ type Streamer interface {
NewStream(ctx context.Context, address swarm.Address, h Headers, protocol, version, stream string) (Stream, error) NewStream(ctx context.Context, address swarm.Address, h Headers, protocol, version, stream string) (Stream, error)
} }
type StreamerDisconnecter interface {
Streamer
Disconnecter
}
// Stream represent a bidirectional data Stream. // Stream represent a bidirectional data Stream.
type Stream interface { type Stream interface {
io.ReadWriter io.ReadWriter
......
...@@ -319,3 +319,52 @@ type Option interface { ...@@ -319,3 +319,52 @@ type Option interface {
type optionFunc func(*Recorder) type optionFunc func(*Recorder)
func (f optionFunc) apply(r *Recorder) { f(r) } func (f optionFunc) apply(r *Recorder) { f(r) }
var _ p2p.StreamerDisconnecter = (*RecorderDisconnecter)(nil)
type RecorderDisconnecter struct {
*Recorder
disconnected map[string]struct{}
blocklisted map[string]time.Duration
mu sync.RWMutex
}
func NewRecorderDisconnecter(r *Recorder) *RecorderDisconnecter {
return &RecorderDisconnecter{
Recorder: r,
disconnected: make(map[string]struct{}),
blocklisted: make(map[string]time.Duration),
}
}
func (r *RecorderDisconnecter) Disconnect(overlay swarm.Address) error {
r.mu.Lock()
defer r.mu.Unlock()
r.disconnected[overlay.String()] = struct{}{}
return nil
}
func (r *RecorderDisconnecter) Blocklist(overlay swarm.Address, d time.Duration) error {
r.mu.Lock()
defer r.mu.Unlock()
r.blocklisted[overlay.String()] = d
return nil
}
func (r *RecorderDisconnecter) IsDisconnected(overlay swarm.Address) bool {
r.mu.RLock()
defer r.mu.RUnlock()
_, yes := r.disconnected[overlay.String()]
return yes
}
func (r *RecorderDisconnecter) IsBlocklisted(overlay swarm.Address) (bool, time.Duration) {
r.mu.RLock()
defer r.mu.RUnlock()
d, yes := r.blocklisted[overlay.String()]
return yes, d
}
...@@ -269,13 +269,11 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.S ...@@ -269,13 +269,11 @@ func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.S
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
return nil return nil
}} }}
server := retrieval.New(nil, nil, logger, serverMockAccounting, nil, nil) server := retrieval.New(nil, mockStorer, ps, logger, serverMockAccounting, nil, nil)
server.SetStorer(mockStorer) recorder := streamtest.NewRecorderDisconnecter(streamtest.New(
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
) ))
retrieve := retrieval.New(recorder, ps, logger, serverMockAccounting, pricerMock, nil) retrieve := retrieval.New(recorder, mockStorer, ps, logger, serverMockAccounting, pricerMock, nil)
retrieve.SetStorer(mockStorer)
ns := netstore.New(storer, recoveryFunc, retrieve, logger, nil) ns := netstore.New(storer, recoveryFunc, retrieve, logger, nil)
return ns return ns
} }
......
...@@ -6,6 +6,7 @@ package retrieval ...@@ -6,6 +6,7 @@ package retrieval
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
...@@ -35,7 +36,7 @@ type Interface interface { ...@@ -35,7 +36,7 @@ type Interface interface {
} }
type Service struct { type Service struct {
streamer p2p.Streamer streamer p2p.StreamerDisconnecter
peerSuggester topology.EachPeerer peerSuggester topology.EachPeerer
storer storage.Storer storer storage.Storer
singleflight singleflight.Group singleflight singleflight.Group
...@@ -45,10 +46,11 @@ type Service struct { ...@@ -45,10 +46,11 @@ type Service struct {
validator swarm.Validator validator swarm.Validator
} }
func New(streamer p2p.Streamer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, validator swarm.Validator) *Service { func New(streamer p2p.StreamerDisconnecter, storer storage.Storer, chunkPeerer topology.EachPeerer, logger logging.Logger, accounting accounting.Interface, pricer accounting.Pricer, validator swarm.Validator) *Service {
return &Service{ return &Service{
streamer: streamer, streamer: streamer,
peerSuggester: chunkPeerer, peerSuggester: chunkPeerer,
storer: storer,
logger: logger, logger: logger,
accounting: accounting, accounting: accounting,
pricer: pricer, pricer: pricer,
...@@ -72,11 +74,10 @@ func (s *Service) Protocol() p2p.ProtocolSpec { ...@@ -72,11 +74,10 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
const ( const (
maxPeers = 5 maxPeers = 5
retrieveChunkTimeout = 10 * time.Second retrieveChunkTimeout = 10 * time.Second
blocklistDuration = time.Minute
) )
func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) { func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.Chunk, error) {
ctx, cancel := context.WithTimeout(ctx, maxPeers*retrieveChunkTimeout)
defer cancel()
v, err, _ := s.singleflight.Do(addr.String(), func() (interface{}, error) { v, err, _ := s.singleflight.Do(addr.String(), func() (interface{}, error) {
var skipPeers []swarm.Address var skipPeers []swarm.Address
...@@ -89,6 +90,14 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm. ...@@ -89,6 +90,14 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (swarm.
} }
s.logger.Debugf("retrieval: failed to get chunk %s from peer %s: %v", addr, peer, err) s.logger.Debugf("retrieval: failed to get chunk %s from peer %s: %v", addr, peer, err)
skipPeers = append(skipPeers, peer) skipPeers = append(skipPeers, peer)
if errors.Is(err, context.DeadlineExceeded) {
if err := s.streamer.Blocklist(peer, blocklistDuration); err != nil {
s.logger.Errorf("retrieval: unable to block peer %s", peer)
s.logger.Debugf("retrieval: blocking peer %s: %v", peer, err)
} else {
s.logger.Warningf("retrieval: peer %s blocked as unresponsive", peer)
}
}
continue continue
} }
s.logger.Tracef("retrieval: got chunk %s from peer %s", addr, peer) s.logger.Tracef("retrieval: got chunk %s from peer %s", addr, peer)
...@@ -124,7 +133,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -124,7 +133,7 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
chunkPrice := s.pricer.PeerPrice(peer, addr) chunkPrice := s.pricer.PeerPrice(peer, addr)
err = s.accounting.Reserve(peer, chunkPrice) err = s.accounting.Reserve(peer, chunkPrice)
if err != nil { if err != nil {
return nil, peer, err return nil, peer, fmt.Errorf("accounting retrieve: %w", err)
} }
defer s.accounting.Release(peer, chunkPrice) defer s.accounting.Release(peer, chunkPrice)
...@@ -146,26 +155,26 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -146,26 +155,26 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
if err := w.WriteMsgWithContext(ctx, &pb.Request{ if err := w.WriteMsgWithContext(ctx, &pb.Request{
Addr: addr.Bytes(), Addr: addr.Bytes(),
}); err != nil { }); err != nil {
return nil, peer, fmt.Errorf("write request: %w peer %s", err, peer.String()) return nil, peer, fmt.Errorf("write request: %w", err)
} }
var d pb.Delivery var d pb.Delivery
if err := r.ReadMsgWithContext(ctx, &d); err != nil { if err := r.ReadMsgWithContext(ctx, &d); err != nil {
return nil, peer, fmt.Errorf("read delivery: %w peer %s", err, peer.String()) return nil, peer, fmt.Errorf("read delivery: %w", err)
} }
// credit the peer after successful delivery // credit the peer after successful delivery
chunk = swarm.NewChunk(addr, d.Data) chunk = swarm.NewChunk(addr, d.Data)
if !s.validator.Validate(chunk) { if !s.validator.Validate(chunk) {
return nil, peer, err return nil, peer, fmt.Errorf("new chunk: %w", err)
} }
err = s.accounting.Credit(peer, chunkPrice) err = s.accounting.Credit(peer, chunkPrice)
if err != nil { if err != nil {
return nil, peer, err return nil, peer, fmt.Errorf("accounting credit: %w", err)
} }
return chunk, peer, err return chunk, peer, nil
} }
func (s *Service) closestPeer(addr swarm.Address, skipPeers []swarm.Address) (swarm.Address, error) { func (s *Service) closestPeer(addr swarm.Address, skipPeers []swarm.Address) (swarm.Address, error) {
...@@ -219,18 +228,27 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -219,18 +228,27 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
}() }()
var req pb.Request var req pb.Request
if err := r.ReadMsg(&req); err != nil { if err := r.ReadMsg(&req); err != nil {
return fmt.Errorf("read request: %w peer %s", err, p.Address.String()) return fmt.Errorf("read request: %w", err)
} }
ctx = context.WithValue(ctx, requestSourceContextKey{}, p.Address.String()) ctx = context.WithValue(ctx, requestSourceContextKey{}, p.Address.String())
chunk, err := s.storer.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(req.Addr)) addr := swarm.NewAddress(req.Addr)
chunk, err := s.storer.Get(ctx, storage.ModeGetRequest, addr)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
// forward the request
chunk, err = s.RetrieveChunk(ctx, addr)
if err != nil { if err != nil {
return fmt.Errorf("get from store: %w peer %s", err, p.Address.String()) return fmt.Errorf("retrieve chunk: %w", err)
}
} else {
return fmt.Errorf("get from store: %w", err)
}
} }
if err := w.WriteMsgWithContext(ctx, &pb.Delivery{ if err := w.WriteMsgWithContext(ctx, &pb.Delivery{
Data: chunk.Data(), Data: chunk.Data(),
}); err != nil { }); err != nil {
return fmt.Errorf("write delivery: %w peer %s", err, p.Address.String()) return fmt.Errorf("write delivery: %w", err)
} }
// compute the price we charge for this chunk and debit it from p's balance // compute the price we charge for this chunk and debit it from p's balance
...@@ -242,8 +260,3 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e ...@@ -242,8 +260,3 @@ func (s *Service) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (e
return nil return nil
} }
// SetStorer sets the storer. This call is not goroutine safe.
func (s *Service) SetStorer(storer storage.Storer) {
s.storer = storer
}
...@@ -51,11 +51,10 @@ func TestDelivery(t *testing.T) { ...@@ -51,11 +51,10 @@ func TestDelivery(t *testing.T) {
pricerMock := accountingmock.NewPricer(price, price) pricerMock := accountingmock.NewPricer(price, price)
// create the server that will handle the request and will serve the response // create the server that will handle the request and will serve the response
server := retrieval.New(nil, nil, logger, serverMockAccounting, pricerMock, mockValidator) server := retrieval.New(nil, mockStorer, nil, logger, serverMockAccounting, pricerMock, mockValidator)
server.SetStorer(mockStorer) recorder := streamtest.NewRecorderDisconnecter(streamtest.New(
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
) ))
clientMockAccounting := accountingmock.NewAccounting() clientMockAccounting := accountingmock.NewAccounting()
...@@ -70,8 +69,7 @@ func TestDelivery(t *testing.T) { ...@@ -70,8 +69,7 @@ func TestDelivery(t *testing.T) {
_, _, _ = f(peerID, 0) _, _, _ = f(peerID, 0)
return nil return nil
}} }}
client := retrieval.New(recorder, ps, logger, clientMockAccounting, pricerMock, mockValidator) client := retrieval.New(recorder, clientMockStorer, ps, logger, clientMockAccounting, pricerMock, mockValidator)
client.SetStorer(clientMockStorer)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel() defer cancel()
v, err := client.RetrieveChunk(ctx, reqAddr) v, err := client.RetrieveChunk(ctx, reqAddr)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment