package sources

import (
	"context"
	"errors"
	"fmt"
	"io"
	"testing"
	"time"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"

	"github.com/ethereum/go-ethereum/rpc"
)

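// elemCall describes a single element within a mocked batch call: the key that
// is expected to be requested, and whether the mocked backend should fail it.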
type elemCall struct {
	id  int
	err bool
}

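// batchCall describes one expected call to the batch backend: the elements it
// should contain, the RPC-level error to return (if any), the error substring
// the corresponding Fetch should produce, and an optional delay and context.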
type batchCall struct {
	elems  []elemCall
	rpcErr error
	err    string
	// Artificial delay to add before returning the call
	duration time.Duration
	makeCtx  func() context.Context
}

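// batchTestCase drives NewIterativeBatchCall against a mocked RPC backend and
// checks that elements are batched, retried, and completed as expected.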
type batchTestCase struct {
	name  string
	items int

	batchSize int

	batchCalls []batchCall

	mock.Mock
}

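// makeTestRequest builds the rpc.BatchElem for the given key and returns the
// pointer that will receive its result.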
func makeTestRequest(i int) (*string, rpc.BatchElem) {
	out := new(string)
	return out, rpc.BatchElem{
		Method: "testing_foobar",
		Args:   []interface{}{i},
		Result: out,
		Error:  nil,
	}
}

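// makeTestResults returns a results function that passes the collected values
// through unchanged.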
func makeTestResults() func(keys []int, values []*string) ([]*string, error) {
	return func(keys []int, values []*string) ([]*string, error) {
		return values, nil
	}
}

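// GetBatch is the mocked batch backend: it fails fast if the context is done
// and otherwise returns whatever error the mock expectation was set up with.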
func (tc *batchTestCase) GetBatch(ctx context.Context, b []rpc.BatchElem) error {
	if ctx.Err() != nil {
		return ctx.Err()
	}
	return tc.Mock.MethodCalled("get", b).Get(0).([]error)[0]
}

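// mockErr is the error injected into batch elements marked as failing.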
var mockErr = errors.New("mockErr")

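// Run registers the expected batch calls on the mock, then drives the iterative
// batch call through repeated Fetch calls and verifies the final results.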
func (tc *batchTestCase) Run(t *testing.T) {
	keys := make([]int, tc.items)
	for i := 0; i < tc.items; i++ {
		keys[i] = i
	}

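	// makeMock builds the mock side effect for one expected batch call: it checks
	// that the requested IDs match and fills in each element's result or error.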
	makeMock := func(bci int, bc batchCall) func(args mock.Arguments) {
		return func(args mock.Arguments) {
			batch := args[0].([]rpc.BatchElem)
			for i, elem := range batch {
				id := elem.Args[0].(int)
				expectedID := bc.elems[i].id
				require.Equal(t, expectedID, id, "batch element should match expected batch element")
				if bc.elems[i].err {
					batch[i].Error = mockErr
					*batch[i].Result.(*string) = ""
				} else {
					batch[i].Error = nil
					*batch[i].Result.(*string) = fmt.Sprintf("mock result id %d", id)
				}
			}
			time.Sleep(bc.duration)
		}
	}
	// mock all the results of the batch calls
	for bci, bc := range tc.batchCalls {
		var batch []rpc.BatchElem
		for _, elem := range bc.elems {
			batch = append(batch, rpc.BatchElem{
				Method: "testing_foobar",
				Args:   []interface{}{elem.id},
				Result: new(string),
				Error:  nil,
			})
		}
		if len(bc.elems) > 0 {
			tc.On("get", batch).Once().Run(makeMock(bci, bc)).Return([]error{bc.rpcErr}) // wrap in a slice so a nil error keeps its type through the mock
		}
	}
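	// create the iterative batch call and run one Fetch per expected batch call,
	// checking the outcome of each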
	iter := NewIterativeBatchCall[int, *string, []*string](keys, makeTestRequest, makeTestResults(), tc.GetBatch, tc.batchSize)
	for i, bc := range tc.batchCalls {
		ctx := context.Background()
		if bc.makeCtx != nil {
			ctx = bc.makeCtx()
		}

		err := iter.Fetch(ctx)
		if err == io.EOF {
			require.Equal(t, i, len(tc.batchCalls)-1, "EOF only on last call")
		} else {
			require.False(t, iter.Complete())
			if bc.err == "" {
				require.NoError(t, err)
			} else {
				require.ErrorContains(t, err, bc.err)
			}
		}
	}
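	// after the expected calls the iterator should be complete, return stable
	// cached results, and report io.EOF on further fetches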
	require.True(t, iter.Complete(), "batch iter should be complete after the expected calls")
	out, err := iter.Result()
	require.NoError(t, err)
	for i, v := range out {
		require.NotNil(t, v)
		require.Equal(t, fmt.Sprintf("mock result id %d", i), *v)
	}
	out2, err := iter.Result()
	require.NoError(t, err)
	require.Equal(t, out, out2, "cached result should match")
	require.Equal(t, io.EOF, iter.Fetch(context.Background()), "fetch after completion should EOF")

	tc.AssertExpectations(t)
}

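// TestFetchBatched exercises the iterative batch call with various batch sizes,
// per-element failures, and a canceled context.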
func TestFetchBatched(t *testing.T) {
	testCases := []*batchTestCase{
		{
			name:       "empty",
			items:      0,
			batchCalls: []batchCall{},
		},
		{
			name:      "simple",
			items:     4,
			batchSize: 4,
			batchCalls: []batchCall{
				{
					elems: []elemCall{
						{id: 0, err: false},
						{id: 1, err: false},
						{id: 2, err: false},
						{id: 3, err: false},
					},
					err: "",
				},
			},
		},
		{
			name:      "split",
			items:     5,
			batchSize: 3,
			batchCalls: []batchCall{
				{
					elems: []elemCall{
						{id: 0, err: false},
						{id: 1, err: false},
						{id: 2, err: false},
					},
					err: "",
				},
				{
					elems: []elemCall{
						{id: 3, err: false},
						{id: 4, err: false},
					},
					err: "",
				},
			},
		},
		{
			name:      "efficient retry",
			items:     7,
			batchSize: 2,
			batchCalls: []batchCall{
				{
					elems: []elemCall{
						{id: 0, err: false},
						{id: 1, err: true},
					},
					err: "1 error occurred:",
				},
				{
					elems: []elemCall{
						{id: 2, err: false},
						{id: 3, err: false},
					},
					err: "",
				},
				{
					elems: []elemCall{ // fresh elements keep being fetched before the earlier failure is retried
						{id: 4, err: false},
						{id: 5, err: false},
					},
					err: "",
				},
				{
					elems: []elemCall{
						{id: 6, err: false},
						{id: 1, err: false}, // includes the element to retry
					},
					err: "",
				},
			},
		},
		{
			name:      "repeated sequential retries",
			items:     2,
			batchSize: 2,
			batchCalls: []batchCall{
				{
					elems: []elemCall{
						{id: 0, err: true},
						{id: 1, err: true},
					},
					err: "2 errors occurred:",
				},
				{
					elems: []elemCall{
						{id: 0, err: false},
						{id: 1, err: true},
					},
					err: "1 error occurred:",
				},
				{
					elems: []elemCall{
						{id: 1, err: false},
					},
					err: "",
				},
			},
		},
		{
			name:      "context timeout",
			items:     1,
			batchSize: 3,
			batchCalls: []batchCall{
				{
					elems: nil,
					err:   context.Canceled.Error(),
					makeCtx: func() context.Context {
						ctx, cancel := context.WithCancel(context.Background())
						cancel()
						return ctx
					},
				},
				{
					elems: []elemCall{
						{id: 0, err: false},
					},
					err: "",
				},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, tc.Run)
	}
}