// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package pullsync provides the pullsync protocol
// implementation.
package pullsync

import (
	"context"
	"crypto/rand"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/ethersphere/bee/pkg/bitvector"
	"github.com/ethersphere/bee/pkg/cac"
	"github.com/ethersphere/bee/pkg/logging"
	"github.com/ethersphere/bee/pkg/p2p"
	"github.com/ethersphere/bee/pkg/p2p/protobuf"
	"github.com/ethersphere/bee/pkg/postage"
	"github.com/ethersphere/bee/pkg/pullsync/pb"
	"github.com/ethersphere/bee/pkg/pullsync/pullstorage"
	"github.com/ethersphere/bee/pkg/soc"
	"github.com/ethersphere/bee/pkg/storage"
	"github.com/ethersphere/bee/pkg/swarm"
)

const (
	protocolName     = "pullsync"
	protocolVersion  = "1.0.0"
	streamName       = "pullsync"
	cursorStreamName = "cursors"
	cancelStreamName = "cancel"
)

var (
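	// ErrUnsolicitedChunk is returned when a peer delivers a chunk that was
	// not asked for in the preceding Want message.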
	ErrUnsolicitedChunk = errors.New("peer sent unsolicited chunk")

	cancellationTimeout = 5 * time.Second // explicit ruid cancellation message timeout
)

// maxPage is the maximum number of chunks in a batch
var maxPage = 50

// Interface is the PullSync interface.
type Interface interface {
	// SyncInterval syncs a requested interval from the given peer.
	// It returns the BinID of the highest chunk that was synced from the
	// given interval. If the requested interval is too large, the downstream
	// peer has the liberty to provide fewer chunks than requested.
	SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, ruid uint32, err error)
	// GetCursors retrieves all cursors from a downstream peer.
	GetCursors(ctx context.Context, peer swarm.Address) ([]uint64, error)
	// CancelRuid cancels an active pullsync operation identified by ruid on
	// a downstream peer.
	CancelRuid(ctx context.Context, peer swarm.Address, ruid uint32) error
}

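// Syncer implements the pullsync protocol: it offers and delivers chunks from
// the local store to requesting peers and syncs requested intervals from
// downstream peers into the local store.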
type Syncer struct {
	streamer   p2p.Streamer
	metrics    metrics
	logger     logging.Logger
	storage    pullstorage.Storer
	quit       chan struct{}
	wg         sync.WaitGroup
	unwrap     func(swarm.Chunk)
	validStamp postage.ValidStampFn

	ruidMtx sync.Mutex
	ruidCtx map[uint32]func()

	Interface
	io.Closer
}

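// New creates a new pullsync Syncer.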
func New(streamer p2p.Streamer, storage pullstorage.Storer, unwrap func(swarm.Chunk), validStamp postage.ValidStampFn, logger logging.Logger) *Syncer {
	return &Syncer{
		streamer:   streamer,
		storage:    storage,
		metrics:    newMetrics(),
		unwrap:     unwrap,
		validStamp: validStamp,
		logger:     logger,
		ruidCtx:    make(map[uint32]func()),
		wg:         sync.WaitGroup{},
		quit:       make(chan struct{}),
	}
}

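// Protocol returns the pullsync protocol specification, wiring up the sync,
// cursors and cancellation stream handlers.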
func (s *Syncer) Protocol() p2p.ProtocolSpec {
	return p2p.ProtocolSpec{
		Name:    protocolName,
		Version: protocolVersion,
		StreamSpecs: []p2p.StreamSpec{
			{
				Name:    streamName,
				Handler: s.handler,
			},
			{
				Name:    cursorStreamName,
				Handler: s.cursorHandler,
			},
			{
				Name:    cancelStreamName,
				Handler: s.cancelHandler,
			},
		},
	}
}

// SyncInterval syncs a requested interval from the given peer.
// It returns the BinID of the highest chunk that was synced from the given interval.
// If the requested interval is too large, the downstream peer has the liberty to
// provide fewer chunks than requested.
func (s *Syncer) SyncInterval(ctx context.Context, peer swarm.Address, bin uint8, from, to uint64) (topmost uint64, ruid uint32, err error) {
	stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, streamName)
	if err != nil {
		return 0, 0, fmt.Errorf("new stream: %w", err)
	}
	defer func() {
		if err != nil {
			_ = stream.Reset()
		} else {
			go stream.FullClose()
		}
	}()

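	// generate a random ruid (request id) for this operation; the downstream
	// peer uses it to identify the operation if it is explicitly cancelled later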
	var ru pb.Ruid
	b := make([]byte, 4)
	_, err = rand.Read(b)
	if err != nil {
		return 0, 0, fmt.Errorf("crypto rand: %w", err)
	}

	ru.Ruid = binary.BigEndian.Uint32(b)

	w, r := protobuf.NewWriterAndReader(stream)

	if err = w.WriteMsgWithContext(ctx, &ru); err != nil {
		return 0, 0, fmt.Errorf("write ruid: %w", err)
	}

	rangeMsg := &pb.GetRange{Bin: int32(bin), From: from, To: to}
	if err = w.WriteMsgWithContext(ctx, rangeMsg); err != nil {
		return 0, ru.Ruid, fmt.Errorf("write get range: %w", err)
	}

	var offer pb.Offer
	if err = r.ReadMsgWithContext(ctx, &offer); err != nil {
		return 0, ru.Ruid, fmt.Errorf("read offer: %w", err)
	}

	if len(offer.Hashes)%swarm.HashSize != 0 {
		return 0, ru.Ruid, fmt.Errorf("inconsistent hash length")
	}

	// empty interval (no chunks present in interval).
	// return the end of the requested range as topmost.
	if len(offer.Hashes) == 0 {
		return offer.Topmost, ru.Ruid, nil
	}

	var (
		bvLen      = len(offer.Hashes) / swarm.HashSize
		wantChunks = make(map[string]struct{})
		ctr        = 0
	)

	bv, err := bitvector.New(bvLen)
	if err != nil {
		return 0, ru.Ruid, fmt.Errorf("new bitvector: %w", err)
	}

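	// walk the offered hashes: chunks that are already present locally are
	// skipped, missing ones are recorded in wantChunks and marked in the
	// bitvector that is sent back to the peer as the Want message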
	for i := 0; i < len(offer.Hashes); i += swarm.HashSize {
		a := swarm.NewAddress(offer.Hashes[i : i+swarm.HashSize])
		if a.Equal(swarm.ZeroAddress) {
			// keep this log around so we can verify that zero address hashes never show up in offers
			s.logger.Errorf("syncer got a zero address hash on offer")
			return 0, ru.Ruid, fmt.Errorf("zero address on offer")
		}
		s.metrics.OfferCounter.Inc()
		s.metrics.DbOpsCounter.Inc()
		have, err := s.storage.Has(ctx, a)
		if err != nil {
			return 0, ru.Ruid, fmt.Errorf("storage has: %w", err)
		}
		if !have {
			wantChunks[a.String()] = struct{}{}
			ctr++
			s.metrics.WantCounter.Inc()
			bv.Set(i / swarm.HashSize)
		}
	}

	wantMsg := &pb.Want{BitVector: bv.Bytes()}
	if err = w.WriteMsgWithContext(ctx, wantMsg); err != nil {
		return 0, ru.Ruid, fmt.Errorf("write want: %w", err)
	}

	// if ctr is zero we don't want any chunk from the batch;
	// the following loop is not executed and the method returns
	// immediately with the topmost value of the offer, which
	// seals the interval so that the next one can be requested
	err = nil
	var chunksToPut []swarm.Chunk

	for ; ctr > 0; ctr-- {
		var delivery pb.Delivery
		if err = r.ReadMsgWithContext(ctx, &delivery); err != nil {
			// this is not a fatal error and we should write
			// a partial batch if some chunks have been received.
			err = fmt.Errorf("read delivery: %w", err)
			break
		}

		addr := swarm.NewAddress(delivery.Address)
		if _, ok := wantChunks[addr.String()]; !ok {
			// this is fatal for the entire batch, return the
			// error and don't write the partial batch.
			return 0, ru.Ruid, ErrUnsolicitedChunk
		}

		delete(wantChunks, addr.String())
		s.metrics.DeliveryCounter.Inc()

		chunk := swarm.NewChunk(addr, delivery.Data)
		if chunk, err = s.validStamp(chunk, delivery.Stamp); err != nil {
			s.logger.Debugf("unverified chunk: %v", err)
			continue
		}

		if cac.Valid(chunk) {
			go s.unwrap(chunk)
		} else if !soc.Valid(chunk) {
			// this is fatal for the entire batch, return the
			// error and don't write the partial batch.
			return 0, ru.Ruid, swarm.ErrInvalidChunk
		}
		chunksToPut = append(chunksToPut, chunk)
	}
	if len(chunksToPut) > 0 {
		s.metrics.DbOpsCounter.Inc()
		if ierr := s.storage.Put(ctx, storage.ModePutSync, chunksToPut...); ierr != nil {
			if err != nil {
				// include the put error; keep the earlier read error, if any
				ierr = fmt.Errorf("%s, sync err: %w", ierr, err)
			}
			return 0, ru.Ruid, fmt.Errorf("delivery put: %w", ierr)
		}
	}
	// there might have been an error in the for loop above,
	// return it if it indeed happened
	if err != nil {
		return 0, ru.Ruid, err
	}

	return offer.Topmost, ru.Ruid, nil
}

// handler handles an incoming request to sync an interval
func (s *Syncer) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (err error) {
	r := protobuf.NewReader(stream)
	defer func() {
		if err != nil {
			_ = stream.Reset()
		} else {
			_ = stream.FullClose()
		}
	}()
	var ru pb.Ruid
	if err := r.ReadMsgWithContext(ctx, &ru); err != nil {
		return fmt.Errorf("read ruid: %w", err)
	}

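	// register the cancel function under this ruid so that an explicit cancel
	// message from the peer (or shutdown) can abort the operation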
	ctx, cancel := context.WithCancel(ctx)
	s.ruidMtx.Lock()
	s.ruidCtx[ru.Ruid] = cancel
	s.ruidMtx.Unlock()
	cc := make(chan struct{})
	defer close(cc)
	go func() {
		select {
		case <-s.quit:
		case <-ctx.Done():
		case <-cc:
		}
		cancel()
		s.ruidMtx.Lock()
		delete(s.ruidCtx, ru.Ruid)
		s.ruidMtx.Unlock()
	}()

	select {
	case <-s.quit:
		return nil
	default:
	}

	s.wg.Add(1)
	defer s.wg.Done()

	var rn pb.GetRange
	if err := r.ReadMsgWithContext(ctx, &rn); err != nil {
		return fmt.Errorf("read get range: %w", err)
	}

	// make an offer to the upstream peer in return for the requested range
	offer, _, err := s.makeOffer(ctx, rn)
	if err != nil {
		return fmt.Errorf("make offer: %w", err)
	}

	// recreate the reader here so that the first one can be garbage collected
	// while makeOffer is executing (waiting for the new chunks), reducing the
	// total memory allocated during that wait
	w, r := protobuf.NewWriterAndReader(stream)

	if err := w.WriteMsgWithContext(ctx, offer); err != nil {
		return fmt.Errorf("write offer: %w", err)
	}

	// we don't have any hashes to offer in this range (the
	// interval is empty). nothing more to do
	if len(offer.Hashes) == 0 {
		return nil
	}

	var want pb.Want
	if err := r.ReadMsgWithContext(ctx, &want); err != nil {
		return fmt.Errorf("read want: %w", err)
	}

	chs, err := s.processWant(ctx, offer, &want)
	if err != nil {
		return fmt.Errorf("process want: %w", err)
	}

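	// deliver each wanted chunk to the peer together with its postage stamp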
	for _, v := range chs {
		stamp, err := v.Stamp().MarshalBinary()
		if err != nil {
			return fmt.Errorf("serialise stamp: %w", err)
		}
		deliver := pb.Delivery{Address: v.Address().Bytes(), Data: v.Data(), Stamp: stamp}
		if err := w.WriteMsgWithContext(ctx, &deliver); err != nil {
			return fmt.Errorf("write delivery: %w", err)
		}
	}

	time.Sleep(50 * time.Millisecond) // without this sleep the tests see an EOF
	return nil
}

// makeOffer tries to assemble an offer for a given requested interval.
func (s *Syncer) makeOffer(ctx context.Context, rn pb.GetRange) (o *pb.Offer, addrs []swarm.Address, err error) {
	chs, top, err := s.storage.IntervalChunks(ctx, uint8(rn.Bin), rn.From, rn.To, maxPage)
	if err != nil {
		return o, nil, err
	}
	o = new(pb.Offer)
	o.Topmost = top
	o.Hashes = make([]byte, 0)
	for _, v := range chs {
		o.Hashes = append(o.Hashes, v.Bytes()...)
	}
	return o, chs, nil
}

// processWant compares a received Want to a sent Offer and returns
// the appropriate chunks from the local store.
func (s *Syncer) processWant(ctx context.Context, o *pb.Offer, w *pb.Want) ([]swarm.Chunk, error) {
	l := len(o.Hashes) / swarm.HashSize
	bv, err := bitvector.NewFromBytes(w.BitVector, l)
	if err != nil {
		return nil, err
	}

	var addrs []swarm.Address
	for i := 0; i < len(o.Hashes); i += swarm.HashSize {
		if bv.Get(i / swarm.HashSize) {
			a := swarm.NewAddress(o.Hashes[i : i+swarm.HashSize])
			addrs = append(addrs, a)
		}
	}
	s.metrics.DbOpsCounter.Inc()
	return s.storage.Get(ctx, storage.ModeGetSync, addrs...)
}

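// GetCursors retrieves all cursors from a downstream peer.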
func (s *Syncer) GetCursors(ctx context.Context, peer swarm.Address) (retr []uint64, err error) {
	stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, cursorStreamName)
	if err != nil {
		return nil, fmt.Errorf("new stream: %w", err)
	}
	defer func() {
		if err != nil {
			_ = stream.Reset()
		} else {
			go stream.FullClose()
		}
	}()

	w, r := protobuf.NewWriterAndReader(stream)
	syn := &pb.Syn{}
	if err = w.WriteMsgWithContext(ctx, syn); err != nil {
		return nil, fmt.Errorf("write syn: %w", err)
	}

	var ack pb.Ack
	if err = r.ReadMsgWithContext(ctx, &ack); err != nil {
		return nil, fmt.Errorf("read ack: %w", err)
	}

	retr = ack.Cursors

	return retr, nil
}

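// cursorHandler handles an incoming request for this node's cursors.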
func (s *Syncer) cursorHandler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (err error) {
	w, r := protobuf.NewWriterAndReader(stream)
	defer func() {
		if err != nil {
			_ = stream.Reset()
		} else {
			_ = stream.FullClose()
		}
	}()

	var syn pb.Syn
	if err := r.ReadMsgWithContext(ctx, &syn); err != nil {
		return fmt.Errorf("read syn: %w", err)
	}

	var ack pb.Ack
	s.metrics.DbOpsCounter.Inc()
	ints, err := s.storage.Cursors(ctx)
	if err != nil {
		return err
	}
	ack.Cursors = ints
	if err = w.WriteMsgWithContext(ctx, &ack); err != nil {
		return fmt.Errorf("write ack: %w", err)
	}

	return nil
}

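// CancelRuid cancels an active pullsync operation identified by ruid on
// a downstream peer.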
func (s *Syncer) CancelRuid(ctx context.Context, peer swarm.Address, ruid uint32) (err error) {
	stream, err := s.streamer.NewStream(ctx, peer, nil, protocolName, protocolVersion, cancelStreamName)
	if err != nil {
		return fmt.Errorf("new stream: %w", err)
	}

	w := protobuf.NewWriter(stream)
	defer func() {
		if err != nil {
			_ = stream.Reset()
		} else {
			go stream.FullClose()
		}
	}()

	ctx, cancel := context.WithTimeout(ctx, cancellationTimeout)
	defer cancel()

	var c pb.Cancel
	c.Ruid = ruid
	if err := w.WriteMsgWithContext(ctx, &c); err != nil {
		return fmt.Errorf("send cancellation: %w", err)
	}
	return nil
}

// cancelHandler handles an incoming request to explicitly cancel a ruid
func (s *Syncer) cancelHandler(ctx context.Context, p p2p.Peer, stream p2p.Stream) (err error) {
	r := protobuf.NewReader(stream)
	defer func() {
		if err != nil {
			_ = stream.Reset()
		} else {
			_ = stream.FullClose()
		}
	}()

	var c pb.Cancel
	if err := r.ReadMsgWithContext(ctx, &c); err != nil {
		return fmt.Errorf("read cancel: %w", err)
	}

	s.ruidMtx.Lock()
	defer s.ruidMtx.Unlock()

	if cancel, ok := s.ruidCtx[c.Ruid]; ok {
		cancel()
	}
	delete(s.ruidCtx, c.Ruid)
	return nil
}

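// Close shuts the pull syncer down: it cancels all open sync contexts and
// waits, up to a timeout, for running handlers to finish.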
func (s *Syncer) Close() error {
	s.logger.Info("pull syncer shutting down")
	close(s.quit)
	cc := make(chan struct{})
	go func() {
		defer close(cc)
		s.wg.Wait()
	}()

	// cancel all contexts
	s.ruidMtx.Lock()
	for _, c := range s.ruidCtx {
		c()
	}
	s.ruidMtx.Unlock()

	select {
	case <-cc:
	case <-time.After(5 * time.Second):
		s.logger.Warning("pull syncer shutting down with running goroutines")
	}
	return nil
}