package sources

import (
	"context"
	"errors"
	"fmt"
	"io"
	"testing"
	"time"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"

	"github.com/ethereum/go-ethereum/rpc"
)

// elemCall describes a single element expected inside a mocked batch request.
type elemCall struct {
	// id is the key the element must carry as its first RPC argument.
	id int
	// err makes the mock fail this element instead of filling in a result;
	// the test cases then expect it to show up again in a later batch.
	err bool
}

type batchCall struct {
23 24 25 26 27 28
	elems  []elemCall
	rpcErr error
	err    string
	// Artificial delay to add before returning the call
	duration time.Duration
	makeCtx  func() context.Context
29 30 31 32 33 34
}

type batchTestCase struct {
	name  string
	items int

35
	batchSize int
36

37
	batchCalls []batchCall
38 39 40 41

	mock.Mock
}

42 43 44 45
func makeTestRequest(i int) (*string, rpc.BatchElem) {
	out := new(string)
	return out, rpc.BatchElem{
		Method: "testing_foobar",
46
		Args:   []any{i},
47 48 49 50 51
		Result: out,
		Error:  nil,
	}
}

52
func (tc *batchTestCase) GetBatch(ctx context.Context, b []rpc.BatchElem) error {
53 54 55
	if ctx.Err() != nil {
		return ctx.Err()
	}
56 57 58
	return tc.Mock.MethodCalled("get", b).Get(0).([]error)[0]
}

// mockErr is the per-element error the mocked getter assigns to elements that
// a test case marks as failing.
var mockErr = errors.New("mockErr")

61
func (tc *batchTestCase) Run(t *testing.T) {
62 63 64 65
	keys := make([]int, tc.items)
	for i := 0; i < tc.items; i++ {
		keys[i] = i
	}
66

67 68
	makeMock := func(bci int, bc batchCall) func(args mock.Arguments) {
		return func(args mock.Arguments) {
69
			batch := args[0].([]rpc.BatchElem)
70 71 72 73 74 75 76
			for i, elem := range batch {
				id := elem.Args[0].(int)
				expectedID := bc.elems[i].id
				require.Equal(t, expectedID, id, "batch element should match expected batch element")
				if bc.elems[i].err {
					batch[i].Error = mockErr
					*batch[i].Result.(*string) = ""
77 78
				} else {
					batch[i].Error = nil
79
					*batch[i].Result.(*string) = fmt.Sprintf("mock result id %d", id)
80 81
				}
			}
82 83
			time.Sleep(bc.duration)
		}
84
	}
85 86 87 88 89 90
	// mock all the results of the batch calls
	for bci, bc := range tc.batchCalls {
		var batch []rpc.BatchElem
		for _, elem := range bc.elems {
			batch = append(batch, rpc.BatchElem{
				Method: "testing_foobar",
91
				Args:   []any{elem.id},
92 93 94 95 96 97 98 99
				Result: new(string),
				Error:  nil,
			})
		}
		if len(bc.elems) > 0 {
			tc.On("get", batch).Once().Run(makeMock(bci, bc)).Return([]error{bc.rpcErr}) // wrap to preserve nil as type of error
		}
	}
100
	iter := NewIterativeBatchCall[int, *string](keys, makeTestRequest, tc.GetBatch, tc.batchSize)
101 102 103 104 105
	for i, bc := range tc.batchCalls {
		ctx := context.Background()
		if bc.makeCtx != nil {
			ctx = bc.makeCtx()
		}
106

107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
		err := iter.Fetch(ctx)
		if err == io.EOF {
			require.Equal(t, i, len(tc.batchCalls)-1, "EOF only on last call")
		} else {
			require.False(t, iter.Complete())
			if bc.err == "" {
				require.NoError(t, err)
			} else {
				require.ErrorContains(t, err, bc.err)
			}
		}
	}
	require.True(t, iter.Complete(), "batch iter should be complete after the expected calls")
	out, err := iter.Result()
	require.NoError(t, err)
	for i, v := range out {
		require.NotNil(t, v)
		require.Equal(t, fmt.Sprintf("mock result id %d", i), *v)
	}
	out2, err := iter.Result()
	require.NoError(t, err)
	require.Equal(t, out, out2, "cached result should match")
	require.Equal(t, io.EOF, iter.Fetch(context.Background()), "fetch after completion should EOF")
130 131 132 133 134 135 136

	tc.AssertExpectations(t)
}

func TestFetchBatched(t *testing.T) {
	testCases := []*batchTestCase{
		{
137 138 139
			name:       "empty",
			items:      0,
			batchCalls: []batchCall{},
140 141
		},
		{
142 143 144
			name:      "simple",
			items:     4,
			batchSize: 4,
145 146 147 148 149 150 151 152
			batchCalls: []batchCall{
				{
					elems: []elemCall{
						{id: 0, err: false},
						{id: 1, err: false},
						{id: 2, err: false},
						{id: 3, err: false},
					},
153
					err: "",
154 155 156 157
				},
			},
		},
		{
158 159 160
			name:      "split",
			items:     5,
			batchSize: 3,
161 162 163 164 165 166 167
			batchCalls: []batchCall{
				{
					elems: []elemCall{
						{id: 0, err: false},
						{id: 1, err: false},
						{id: 2, err: false},
					},
168
					err: "",
169 170 171 172 173 174
				},
				{
					elems: []elemCall{
						{id: 3, err: false},
						{id: 4, err: false},
					},
175
					err: "",
176 177 178 179
				},
			},
		},
		{
180 181 182
			name:      "efficient retry",
			items:     7,
			batchSize: 2,
183 184 185 186
			batchCalls: []batchCall{
				{
					elems: []elemCall{
						{id: 0, err: false},
187
						{id: 1, err: true},
188
					},
189
					err: "1 error occurred:",
190 191 192 193
				},
				{
					elems: []elemCall{
						{id: 2, err: false},
194
						{id: 3, err: false},
195
					},
196
					err: "",
197 198
				},
				{
199 200 201
					elems: []elemCall{ // in-process before retry even happens
						{id: 4, err: false},
						{id: 5, err: false},
202
					},
203
					err: "",
204 205 206
				},
				{
					elems: []elemCall{
207 208
						{id: 6, err: false},
						{id: 1, err: false}, // includes the element to retry
209
					},
210
					err: "",
211 212 213 214
				},
			},
		},
		{
215 216 217
			name:      "repeated sequential retries",
			items:     2,
			batchSize: 2,
218 219 220
			batchCalls: []batchCall{
				{
					elems: []elemCall{
221
						{id: 0, err: true},
222 223
						{id: 1, err: true},
					},
224
					err: "2 errors occurred:",
225 226 227
				},
				{
					elems: []elemCall{
228
						{id: 0, err: false},
229 230
						{id: 1, err: true},
					},
231
					err: "1 error occurred:",
232 233 234 235 236
				},
				{
					elems: []elemCall{
						{id: 1, err: false},
					},
237
					err: "",
238 239 240 241
				},
			},
		},
		{
242 243 244
			name:      "context timeout",
			items:     1,
			batchSize: 3,
245 246
			batchCalls: []batchCall{
				{
247 248 249 250 251 252
					elems: nil,
					err:   context.Canceled.Error(),
					makeCtx: func() context.Context {
						ctx, cancel := context.WithCancel(context.Background())
						cancel()
						return ctx
253 254 255 256
					},
				},
				{
					elems: []elemCall{
257
						{id: 0, err: false},
258
					},
259
					err: "",
260 261 262 263 264 265 266 267
				},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, tc.Run)
	}
}