package batching

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"sync/atomic"

	"github.com/hashicorp/go-multierror"

	"github.com/ethereum/go-ethereum/rpc"
)

// IterativeBatchCall batches many RPC requests with safe and easy parallelization.
// Request errors are handled and re-tried, and the batch size is configurable.
// Executing IterativeBatchCall is as simple as calling Fetch repeatedly until it returns io.EOF.
type IterativeBatchCall[K any, V any] struct {
	completed uint32       // tracks how far to completing all requests we are
	resetLock sync.RWMutex // ensures we do not concurrently read (incl. fetch) / reset

	requestsKeys []K // one key per request; result order matches this slice
	batchSize    int // maximum number of elements fetched in a single batch call

	makeRequest func(K) (V, rpc.BatchElem) // builds the value destination and batch element for a key
	getBatch    BatchCallContextFn         // issues a batched RPC call
	getSingle   CallContextFn              // issues a plain RPC call; used by Fetch when batchSize == 1

	requestsValues []V                // result destinations, aligned index-for-index with requestsKeys
	scheduled      chan rpc.BatchElem // pending work; closed once every request has completed
}
33

34 35
// NewIterativeBatchCall constructs a batch call, fetching the values with the given keys,
// and transforms them into a verified final result.
36
func NewIterativeBatchCall[K any, V any](
37 38
	requestsKeys []K,
	makeRequest func(K) (V, rpc.BatchElem),
39
	getBatch BatchCallContextFn,
40
	getSingle CallContextFn,
41
	batchSize int) *IterativeBatchCall[K, V] {
42 43 44 45 46 47 48

	if len(requestsKeys) < batchSize {
		batchSize = len(requestsKeys)
	}
	if batchSize < 1 {
		batchSize = 1
	}
49

50
	out := &IterativeBatchCall[K, V]{
51 52
		completed:    0,
		getBatch:     getBatch,
53
		getSingle:    getSingle,
54 55 56
		requestsKeys: requestsKeys,
		batchSize:    batchSize,
		makeRequest:  makeRequest,
57
	}
58 59 60
	out.Reset()
	return out
}
61

62
// Reset will clear the batch call, to start fetching all contents from scratch.
63
func (ibc *IterativeBatchCall[K, V]) Reset() {
64 65 66 67 68 69 70 71 72
	ibc.resetLock.Lock()
	defer ibc.resetLock.Unlock()

	scheduled := make(chan rpc.BatchElem, len(ibc.requestsKeys))
	requestsValues := make([]V, len(ibc.requestsKeys))
	for i, k := range ibc.requestsKeys {
		v, r := ibc.makeRequest(k)
		requestsValues[i] = v
		scheduled <- r
73 74
	}

75
	atomic.StoreUint32(&ibc.completed, 0)
76 77 78 79 80 81
	ibc.requestsValues = requestsValues
	ibc.scheduled = scheduled
	if len(ibc.requestsKeys) == 0 {
		close(ibc.scheduled)
	}
}
// Fetch fetches more of the data, and returns io.EOF when all data has been fetched.
// This method is safe to call concurrently; it will parallelize the fetching work.
// If no work is available, but the fetching is not done yet,
// then Fetch will block until the next thing can be fetched, or until the context expires.
func (ibc *IterativeBatchCall[K, V]) Fetch(ctx context.Context) error {
	// Hold the read-lock for the whole fetch, so a concurrent Reset cannot
	// swap out the scheduled channel / values slice underneath us.
	ibc.resetLock.RLock()
	defer ibc.resetLock.RUnlock()

	// return early if context is Done
	if ctx.Err() != nil {
		return ctx.Err()
	}

	// collect a batch from the requests channel
	batch := make([]rpc.BatchElem, 0, ibc.batchSize)
	// wait for first element
	select {
	case reqElem, ok := <-ibc.scheduled:
		if !ok { // no more requests to do
			return io.EOF
		}
		batch = append(batch, reqElem)
	case <-ctx.Done():
		return ctx.Err()
	}

	// collect more elements, if there are any.
	for {
		if len(batch) >= ibc.batchSize {
			break
		}
		select {
		case reqElem, ok := <-ibc.scheduled:
			if !ok { // no more requests to do
				return io.EOF
			}
			batch = append(batch, reqElem)
			continue
		case <-ctx.Done():
			// Cancelled before fetching: put the collected-but-unfetched
			// elements back so another Fetch call can pick them up.
			for _, r := range batch {
				ibc.scheduled <- r
			}
			return ctx.Err()
		default:
		}
		break
	}

	if len(batch) == 0 {
		return nil
	}

	if ibc.batchSize == 1 {
		// A single request is issued as a plain call rather than a one-element batch.
		// NOTE(review): first is a copy of batch[0]; this assumes BatchElem.Result
		// references the shared destination produced by makeRequest — confirm.
		first := batch[0]
		if err := ibc.getSingle(ctx, &first.Result, first.Method, first.Args...); err != nil {
			// Re-schedule the element so it is retried on a later Fetch.
			ibc.scheduled <- first
			return err
		}
	} else {
		if err := ibc.getBatch(ctx, batch); err != nil {
			// The whole batch call failed: re-schedule every element for retry.
			for _, r := range batch {
				ibc.scheduled <- r
			}
			return fmt.Errorf("failed batch-retrieval: %w", err)
		}
	}
	// Per-element results: failed elements are collected into a multierror and
	// re-scheduled for retry; successful ones are counted towards completion.
	var result error
	for _, elem := range batch {
		if elem.Error != nil {
			result = multierror.Append(result, elem.Error)
			elem.Error = nil // reset, we'll try this element again
			ibc.scheduled <- elem
			continue
		} else {
			atomic.AddUint32(&ibc.completed, 1)
			if atomic.LoadUint32(&ibc.completed) >= uint32(len(ibc.requestsKeys)) {
				// All requests are done: close the channel so current and
				// future Fetch callers observe io.EOF.
				close(ibc.scheduled)
				return io.EOF
			}
		}
	}
	return result
}

// Complete indicates if the batch call is done.
168
func (ibc *IterativeBatchCall[K, V]) Complete() bool {
169 170 171 172 173 174 175
	ibc.resetLock.RLock()
	defer ibc.resetLock.RUnlock()
	return atomic.LoadUint32(&ibc.completed) >= uint32(len(ibc.requestsKeys))
}

// Result returns the fetched values, checked and transformed to the final output type, if available.
// If the check fails, the IterativeBatchCall will Reset itself, to be ready for a re-attempt in fetching new data.
176
func (ibc *IterativeBatchCall[K, V]) Result() ([]V, error) {
177 178 179
	ibc.resetLock.RLock()
	if atomic.LoadUint32(&ibc.completed) < uint32(len(ibc.requestsKeys)) {
		ibc.resetLock.RUnlock()
180
		return nil, errors.New("results not available yet, Fetch more first")
181 182
	}
	ibc.resetLock.RUnlock()
183
	return ibc.requestsValues, nil
184
}