Commit cfb17e33 authored by acud, committed by GitHub

pullsync: add explicit termination (#371)

* pullsync: add explicit termination of a request for intervals
parent 23f47d4d
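
Before the diff itself, a quick orientation: SyncInterval now also returns a random request id (ruid), and a new CancelRuid call lets the client tell the serving peer to abort the request tied to that id when a sync fails midway. A minimal sketch of the calling pattern, assuming the extended pullsync.Interface from this commit (the syncOnce helper and its log lines are illustrative, not part of the change):

package example

import (
	"context"

	"github.com/ethersphere/bee/pkg/logging"
	"github.com/ethersphere/bee/pkg/pullsync"
	"github.com/ethersphere/bee/pkg/swarm"
)

// syncOnce sketches the caller-side pattern the puller workers adopt below:
// if SyncInterval fails after a request id was issued, explicitly cancel
// that ruid on the remote peer so its handler can release the request.
func syncOnce(ctx context.Context, s pullsync.Interface, logger logging.Logger, peer swarm.Address, bin uint8, from, to uint64) {
	top, ruid, err := s.SyncInterval(ctx, peer, bin, from, to)
	if err != nil {
		logger.Debugf("sync interval: peer %s bin %d err %v", peer, bin, err)
		// A zero ruid means the request id was never sent (e.g. stream
		// setup failed), so there is nothing to cancel remotely.
		if ruid == 0 {
			return
		}
		if err := s.CancelRuid(peer, ruid); err != nil {
			logger.Debugf("cancel ruid %d: %v", ruid, err)
		}
		return
	}
	logger.Debugf("synced peer %s bin %d up to %d", peer, bin, top)
}
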
......@@ -365,11 +365,17 @@ func (p *Puller) histSyncWorker(ctx context.Context, peer swarm.Address, bin uin
}
return
}
top, err := p.syncer.SyncInterval(ctx, peer, bin, s, cur)
top, ruid, err := p.syncer.SyncInterval(ctx, peer, bin, s, cur)
if err != nil {
if logMore {
p.logger.Debugf("histSyncWorker error syncing interval. peer %s, bin %d, cursor %d, err %v", peer.String(), bin, cur, err)
}
if ruid == 0 {
return
}
if err := p.syncer.CancelRuid(peer, ruid); err != nil && logMore {
p.logger.Debugf("histSyncWorker cancel ruid: %v", err)
}
return
}
err = p.addPeerInterval(peer, bin, s, top)
......@@ -399,11 +405,17 @@ func (p *Puller) liveSyncWorker(ctx context.Context, peer swarm.Address, bin uin
return
default:
}
top, err := p.syncer.SyncInterval(ctx, peer, bin, from, math.MaxUint64)
top, ruid, err := p.syncer.SyncInterval(ctx, peer, bin, from, math.MaxUint64)
if err != nil {
if logMore {
p.logger.Debugf("liveSyncWorker exit on sync error. peer %s bin %d from %d err %v", peer, bin, from, err)
}
if ruid == 0 {
return
}
if err := p.syncer.CancelRuid(peer, ruid); err != nil && logMore {
p.logger.Debugf("histSyncWorker cancel ruid: %v", err)
}
return
}
if top == 0 {
......
......@@ -106,7 +106,7 @@ func NewPullSync(opts ...Option) *PullSyncMock {
return s
}
func (p *PullSyncMock) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, err error) {
func (p *PullSyncMock) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, ruid uint32, err error) {
isLive := to == math.MaxUint64
call := SyncCall{
......@@ -129,9 +129,9 @@ func (p *PullSyncMock) SyncInterval(ctx context.Context, peer swarm.Address, bin
select {
case <-p.quit:
return 0, context.Canceled
return 0, 1, context.Canceled
case <-ctx.Done():
return 0, ctx.Err()
return 0, 1, ctx.Err()
default:
}
......@@ -150,12 +150,12 @@ func (p *PullSyncMock) SyncInterval(ctx context.Context, peer swarm.Address, bin
if sr.block {
select {
case <-p.quit:
return 0, context.Canceled
return 0, 1, context.Canceled
case <-ctx.Done():
return 0, ctx.Err()
return 0, 1, ctx.Err()
}
}
return sr.topmost, nil
return sr.topmost, 0, nil
}
panic("not found")
}
......@@ -163,7 +163,7 @@ func (p *PullSyncMock) SyncInterval(ctx context.Context, peer swarm.Address, bin
if isLive && p.blockLiveSync {
// don't respond, wait for quit
<-p.quit
return 0, io.EOF
return 0, 1, io.EOF
}
if isLive && len(p.liveSyncReplies) > 0 {
......@@ -175,7 +175,7 @@ func (p *PullSyncMock) SyncInterval(ctx context.Context, peer swarm.Address, bin
v := p.liveSyncReplies[p.liveSyncCalls]
p.liveSyncCalls++
p.mtx.Unlock()
return v, nil
return v, 1, nil
}
if p.autoReply {
......@@ -184,9 +184,9 @@ func (p *PullSyncMock) SyncInterval(ctx context.Context, peer swarm.Address, bin
if t > to {
t = to
}
return t, nil
return t, 1, nil
}
return to, nil
return to, 1, nil
}
func (p *PullSyncMock) GetCursors(_ context.Context, peer swarm.Address) ([]uint64, error) {
......@@ -208,6 +208,10 @@ func (p *PullSyncMock) SyncCalls(peer swarm.Address) (res []SyncCall) {
return res
}
func (p *PullSyncMock) CancelRuid(peer swarm.Address, ruid uint32) error {
return nil
}
func (p *PullSyncMock) LiveSyncCalls(peer swarm.Address) (res []SyncCall) {
p.mtx.Lock()
defer p.mtx.Unlock()
......
This diff is collapsed.
......@@ -14,6 +14,14 @@ message Ack {
repeated uint64 Cursors = 1;
}
message Ruid {
uint32 Ruid = 1;
}
message Cancel {
uint32 Ruid = 1;
}
message GetRange {
int32 Bin = 1;
uint64 From = 2;
......
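
The two new messages implement an explicit-termination handshake: the client announces a random Ruid ahead of each GetRange, and may later send a Cancel carrying the same id on a dedicated cancel stream. A minimal client-side sketch of the cancel write, assuming a stream already opened on that stream (the sendCancel helper is hypothetical; CancelRuid in the diff below also performs the stream setup):

package example

import (
	"fmt"
	"time"

	"github.com/ethersphere/bee/pkg/p2p"
	"github.com/ethersphere/bee/pkg/p2p/protobuf"
	"github.com/ethersphere/bee/pkg/pullsync/pb"
)

// sendCancel writes a single Cancel message carrying the ruid of an
// in-flight range request, then closes the stream.
func sendCancel(stream p2p.Stream, ruid uint32) error {
	w := protobuf.NewWriter(stream)
	defer stream.Close()

	c := pb.Cancel{Ruid: ruid}
	// Bound the write so a stuck peer cannot block the canceller forever.
	if err := w.WriteMsgWithTimeout(5*time.Second, &c); err != nil {
		return fmt.Errorf("send cancellation: %w", err)
	}
	return nil
}
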
......@@ -6,9 +6,12 @@ package pullsync
import (
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/ethersphere/bee/pkg/bitvector"
......@@ -27,6 +30,7 @@ const (
protocolVersion = "1.0.0"
streamName = "pullsync"
cursorStreamName = "cursors"
cancelStreamName = "cancel"
)
var (
......@@ -37,8 +41,9 @@ var (
var maxPage = 50
type Interface interface {
SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, err error)
SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, ruid uint32, err error)
GetCursors(ctx context.Context, peer swarm.Address) ([]uint64, error)
CancelRuid(peer swarm.Address, ruid uint32) error
}
type Syncer struct {
......@@ -46,6 +51,9 @@ type Syncer struct {
logger logging.Logger
storage pullstorage.Storer
ruidMtx sync.Mutex
ruidCtx map[uint32]func()
Interface
io.Closer
}
......@@ -62,6 +70,7 @@ func New(o Options) *Syncer {
streamer: o.Streamer,
storage: o.Storage,
logger: o.Logger,
ruidCtx: make(map[uint32]func()),
}
}
......@@ -78,6 +87,10 @@ func (s *Syncer) Protocol() p2p.ProtocolSpec {
Name: cursorStreamName,
Handler: s.cursorHandler,
},
{
Name: cancelStreamName,
Handler: s.cancelHandler,
},
},
}
}
......@@ -86,11 +99,20 @@ func (s *Syncer) Protocol() p2p.ProtocolSpec {
// It returns the BinID of the highest chunk that was synced from the given interval.
// If the requested interval is too large, the downstream peer has the liberty to
// provide fewer chunks than requested.
func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, err error) {
func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, ruid uint32, err error) {
stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
if err != nil {
return 0, fmt.Errorf("new stream: %w", err)
return 0, 0, fmt.Errorf("new stream: %w", err)
}
var ru pb.Ruid
b := make([]byte, 4)
_, err = rand.Read(b)
if err != nil {
return 0, 0, fmt.Errorf("crypto rand: %w", err)
}
ru.Ruid = binary.BigEndian.Uint32(b)
defer func() {
if err != nil {
_ = stream.FullClose()
......@@ -100,25 +122,28 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8
}()
w, r := protobuf.NewWriterAndReader(stream)
if err = w.WriteMsgWithContext(ctx, &ru); err != nil {
return 0, 0, fmt.Errorf("write ruid: %w", err)
}
rangeMsg := &pb.GetRange{Bin: int32(bin), From: from, To: to}
if err = w.WriteMsgWithContext(ctx, rangeMsg); err != nil {
return 0, fmt.Errorf("write get range: %w", err)
return 0, ru.Ruid, fmt.Errorf("write get range: %w", err)
}
var offer pb.Offer
if err = r.ReadMsgWithContext(ctx, &offer); err != nil {
return 0, fmt.Errorf("read offer: %w", err)
return 0, ru.Ruid, fmt.Errorf("read offer: %w", err)
}
if len(offer.Hashes)%swarm.HashSize != 0 {
return 0, fmt.Errorf("inconsistent hash length")
return 0, ru.Ruid, fmt.Errorf("inconsistent hash length")
}
// empty interval (no chunks present in interval).
// return the end of the requested range as topmost.
if len(offer.Hashes) == 0 {
return offer.Topmost, nil
return offer.Topmost, ru.Ruid, nil
}
var (
......@@ -129,7 +154,7 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8
bv, err := bitvector.New(bvLen)
if err != nil {
return 0, fmt.Errorf("new bitvector: %w", err)
return 0, ru.Ruid, fmt.Errorf("new bitvector: %w", err)
}
for i := 0; i < len(offer.Hashes); i += swarm.HashSize {
......@@ -137,11 +162,11 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8
if a.Equal(swarm.ZeroAddress) {
// keep this around so we can verify that zero address hashes never appear in offers
s.logger.Errorf("syncer got a zero address hash on offer")
return 0, fmt.Errorf("zero address on offer")
return 0, ru.Ruid, fmt.Errorf("zero address on offer")
}
have, err := s.storage.Has(ctx, a)
if err != nil {
return 0, fmt.Errorf("storage has: %w", err)
return 0, ru.Ruid, fmt.Errorf("storage has: %w", err)
}
if !have {
wantChunks[a.String()] = struct{}{}
......@@ -152,7 +177,7 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8
wantMsg := &pb.Want{BitVector: bv.Bytes()}
if err = w.WriteMsgWithContext(ctx, wantMsg); err != nil {
return 0, fmt.Errorf("write want: %w", err)
return 0, ru.Ruid, fmt.Errorf("write want: %w", err)
}
// if ctr is zero, it means we don't want any chunk in the batch
......@@ -163,21 +188,21 @@ func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8
for ; ctr > 0; ctr-- {
var delivery pb.Delivery
if err = r.ReadMsgWithContext(ctx, &delivery); err != nil {
return 0, fmt.Errorf("read delivery: %w", err)
return 0, ru.Ruid, fmt.Errorf("read delivery: %w", err)
}
addr := swarm.NewAddress(delivery.Address)
if _, ok := wantChunks[addr.String()]; !ok {
return 0, ErrUnsolicitedChunk
return 0, ru.Ruid, ErrUnsolicitedChunk
}
delete(wantChunks, addr.String())
if err = s.storage.Put(ctx, storage.ModePutSync, swarm.NewChunk(addr, delivery.Data)); err != nil {
return 0, fmt.Errorf("delivery put: %w", err)
return 0, ru.Ruid, fmt.Errorf("delivery put: %w", err)
}
}
return offer.Topmost, nil
return offer.Topmost, ru.Ruid, nil
}
// handler handles an incoming request to sync an interval
......@@ -185,6 +210,22 @@ func (s *Syncer) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) err
w, r := protobuf.NewWriterAndReader(stream)
defer stream.Close()
var ru pb.Ruid
if err := r.ReadMsgWithContext(ctx, &ru); err != nil {
return fmt.Errorf("send ruid: %w", err)
}
ctx, cancel := context.WithCancel(ctx)
s.ruidMtx.Lock()
s.ruidCtx[ru.Ruid] = cancel
s.ruidMtx.Unlock()
defer func() {
s.ruidMtx.Lock()
delete(s.ruidCtx, ru.Ruid)
s.ruidMtx.Unlock()
}()
defer cancel()
var rn pb.GetRange
if err := r.ReadMsgWithContext(ctx, &rn); err != nil {
return fmt.Errorf("read get range: %w", err)
......@@ -320,6 +361,42 @@ func (s *Syncer) cursorHandler(ctx context.Context, p p2p.Peer, stream p2p.Strea
return nil
}
func (s *Syncer) CancelRuid(peer swarm.Address, ruid uint32) error {
stream, err := s.streamer.NewStream(context.Background(), peer, nil, protocolName, protocolVersion, cancelStreamName)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
w := protobuf.NewWriter(stream)
defer stream.Close()
var c pb.Cancel
c.Ruid = ruid
if err := w.WriteMsgWithTimeout(5*time.Second, &c); err != nil {
return fmt.Errorf("send cancellation: %w", err)
}
return nil
}
// handler handles an incoming request to explicitly cancel a ruid
func (s *Syncer) cancelHandler(ctx context.Context, p p2p.Peer, stream p2p.Stream) error {
r := protobuf.NewReader(stream)
defer stream.Close()
var c pb.Cancel
if err := r.ReadMsgWithContext(ctx, &c); err != nil {
return fmt.Errorf("read cancel: %w", err)
}
s.ruidMtx.Lock()
if cancel, ok := s.ruidCtx[c.Ruid]; ok {
cancel()
delete(s.ruidCtx, c.Ruid)
}
s.ruidMtx.Unlock()
return nil
}
func (s *Syncer) Close() error {
return nil
}
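
On the serving side, each incoming range request parks its context's cancel func in a mutex-guarded map keyed by ruid, and the cancel handler fires it on demand. A self-contained sketch of that registration pattern (the ruidRegistry type is hypothetical; in the commit the map and mutex live directly on Syncer):

package example

import (
	"context"
	"sync"
)

// ruidRegistry pairs each request id with the cancel func of its context.
type ruidRegistry struct {
	mtx sync.Mutex
	ctx map[uint32]context.CancelFunc
}

func newRuidRegistry() *ruidRegistry {
	return &ruidRegistry{ctx: make(map[uint32]context.CancelFunc)}
}

// register derives a cancellable context for a request and remembers its
// cancel func under ruid. The returned cleanup mirrors the handler's
// deferred delete and must run when the request ends.
func (r *ruidRegistry) register(ctx context.Context, ruid uint32) (context.Context, func()) {
	ctx, cancel := context.WithCancel(ctx)
	r.mtx.Lock()
	r.ctx[ruid] = cancel
	r.mtx.Unlock()
	return ctx, func() {
		r.mtx.Lock()
		delete(r.ctx, ruid)
		r.mtx.Unlock()
		cancel()
	}
}

// cancel fires and removes the cancel func for ruid, if present. Lookup and
// delete happen under one lock so a concurrent handler cannot race them.
func (r *ruidRegistry) cancel(ruid uint32) {
	r.mtx.Lock()
	if c, ok := r.ctx[ruid]; ok {
		c()
		delete(r.ctx, ruid)
	}
	r.mtx.Unlock()
}
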
......@@ -61,7 +61,7 @@ func TestIncoming_WantEmptyInterval(t *testing.T) {
psClient, clientDb = newPullSync(recorder)
)
topmost, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 1, 0, 5)
topmost, _, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 1, 0, 5)
if err != nil {
t.Fatal(err)
}
......@@ -84,7 +84,7 @@ func TestIncoming_WantNone(t *testing.T) {
psClient, clientDb = newPullSync(recorder, mock.WithChunks(chunks...))
)
topmost, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
topmost, _, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
if err != nil {
t.Fatal(err)
}
......@@ -107,7 +107,7 @@ func TestIncoming_WantOne(t *testing.T) {
psClient, clientDb = newPullSync(recorder, mock.WithChunks(someChunks(1, 2, 3, 4)...))
)
topmost, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
topmost, _, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
if err != nil {
t.Fatal(err)
}
......@@ -132,7 +132,7 @@ func TestIncoming_WantAll(t *testing.T) {
psClient, clientDb = newPullSync(recorder)
)
topmost, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
topmost, _, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
if err != nil {
t.Fatal(err)
}
......@@ -161,7 +161,7 @@ func TestIncoming_UnsolicitedChunk(t *testing.T) {
psClient, _ = newPullSync(recorder)
)
_, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
_, _, err := psClient.SyncInterval(context.Background(), swarm.ZeroAddress, 0, 0, 5)
if !errors.Is(err, pullsync.ErrUnsolicitedChunk) {
t.Fatalf("expected ErrUnsolicitedChunk but got %v", err)
}
......