package sources

import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/rpc"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// simpleMockRPC is a minimal, hand-rolled RPC stub for tests whose responses
// must be computed from the request itself (something a static mock
// expectation cannot express). A test installs the behavior it needs by
// assigning the function fields; calling a method whose field was left nil
// panics.
type simpleMockRPC struct {
	// callFn implements CallContext.
	callFn func(ctx context.Context, result any, method string, args ...any) error
	// batchCallFn implements BatchCallContext.
	batchCallFn func(ctx context.Context, b []rpc.BatchElem) error
}

// CallContext forwards a single RPC call, unchanged, to the test-provided
// callFn and returns whatever it returns.
func (m *simpleMockRPC) CallContext(ctx context.Context, result any, method string, args ...any) error {
	handler := m.callFn
	return handler(ctx, result, method, args...)
}

// BatchCallContext forwards a batch of RPC calls, unchanged, to the
// test-provided batchCallFn and returns whatever it returns.
func (m *simpleMockRPC) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
	handler := m.batchCallFn
	return handler(ctx, b)
}

// TestBasicRPCReceiptsFetcher_Reuse checks that a FetchReceipts call which
// partially failed is resumed, not restarted, by the next FetchReceipts call
// for the same block. With 4 transactions and a batch size of 2, the first
// fetch issues 2 batch calls and fails. Once the RPC is "fixed", the second
// fetch must succeed while issuing exactly one more batch call (3 in total).
func TestBasicRPCReceiptsFetcher_Reuse(t *testing.T) {
	require := require.New(t)
	batchSize, txCount := 2, uint64(4)
	block, receipts := randomRpcBlockAndReceipts(rand.New(rand.NewSource(123)), txCount)
	txHashes := receiptTxHashes(receipts)
	recMap := make(map[common.Hash]*types.Receipt, len(receipts))
	for _, rec := range receipts {
		recMap[rec.TxHash] = rec
	}
	mrpc := new(simpleMockRPC)
	rp := NewBasicRPCReceiptsFetcher(mrpc, batchSize)

	// prepare mock
	ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
	defer done()
	// 1st fetching: receipts for txs 0 and 1 are served, txs 2 and 3 fail.
	// The mock reads this map on every call, so the test can flip entries
	// between fetches.
	response := map[common.Hash]bool{
		txHashes[0]: true,
		txHashes[1]: true,
		txHashes[2]: false,
		txHashes[3]: false,
	}
	var numCalls atomic.Int32
	mrpc.batchCallFn = func(_ context.Context, b []rpc.BatchElem) (err error) {
		numCalls.Add(1)
		for i, el := range b {
			if el.Method == "eth_getTransactionReceipt" {
				txHash := el.Args[0].(common.Hash)
				if response[txHash] {
					// The IterativeBatchCall expects that the values are written
					// to the fields of the allocated *types.Receipt.
					**(el.Result.(**types.Receipt)) = *recMap[txHash]
				} else {
					err = errors.Join(err, fmt.Errorf("receipt[%d] error, hash %x", i, txHash))
				}
			} else {
				err = errors.Join(err, fmt.Errorf("unknown method %s", el.Method))
			}
		}
		return err
	}

	bInfo, _, _ := block.Info(true, true)

	// 1st fetching should result in errors
	recs, err := rp.FetchReceipts(ctx, bInfo, txHashes)
	require.Error(err)
	require.Nil(recs)
	require.EqualValues(2, numCalls.Load())

	// prepare 2nd fetching - all should succeed now
	response[txHashes[2]] = true
	response[txHashes[3]] = true
	recs, err = rp.FetchReceipts(ctx, bInfo, txHashes)
	require.NoError(err)
	// Pin the length first: otherwise a short result would pass the loop
	// below silently and a long one would panic on receipts[i].
	require.Len(recs, len(receipts))
	for i, rec := range recs {
		requireEqualReceipt(t, receipts[i], rec)
	}
	// The second fetch adds exactly one batch call on top of the first two.
	require.EqualValues(3, numCalls.Load())
}

// TestBasicRPCReceiptsFetcher_Concurrency runs numFetchers concurrent
// FetchReceipts calls for the same block against a mock RPC that always
// succeeds. Every fetcher must get the correct receipts, and the total number
// of BatchCallContext invocations must be below numFetchers. That bound shows
// fetchers share in-flight batch calls instead of each issuing its own.
func TestBasicRPCReceiptsFetcher_Concurrency(t *testing.T) {
	require := require.New(t)
	const numFetchers = 32
	batchSize, txCount := 4, uint64(18) // 4.5 * 4: the last batch is only half full
	block, receipts := randomRpcBlockAndReceipts(rand.New(rand.NewSource(123)), txCount)
	recMap := make(map[common.Hash]*types.Receipt, len(receipts))
	for _, rec := range receipts {
		recMap[rec.TxHash] = rec
	}
	mrpc := new(mockRPC)
	rp := NewBasicRPCReceiptsFetcher(mrpc, batchSize)

	// prepare mock: answer every eth_getTransactionReceipt element from recMap
	// and count how many batch calls reach the RPC.
	var numCalls atomic.Int32
	mrpc.On("BatchCallContext", mock.Anything, mock.AnythingOfType("[]rpc.BatchElem")).
		Run(func(args mock.Arguments) {
			numCalls.Add(1)
			els := args.Get(1).([]rpc.BatchElem)
			for _, el := range els {
				if el.Method == "eth_getTransactionReceipt" {
					txHash := el.Args[0].(common.Hash)
					// The IterativeBatchCall expects that the values are written
					// to the fields of the allocated *types.Receipt.
					**(el.Result.(**types.Receipt)) = *recMap[txHash]
				}
			}
		}).
		Return([]error{nil})

	runConcurrentFetchingTest(t, rp, numFetchers, receipts, block)

	mrpc.AssertExpectations(t)
	finalNumCalls := int(numCalls.Load())
	require.NotZero(finalNumCalls, "BatchCallContext should have been called.")
	require.Less(finalNumCalls, numFetchers, "Some IterativeBatchCalls should have been shared.")
}

134
// runConcurrentFetchingTest starts numFetchers goroutines that each call
// rp.FetchReceipts for the transactions of receipts in block. It releases
// them together through a shared barrier channel, then requires every call to
// succeed and return exactly receipts, in order. The whole run is bounded by a
// 10-second timeout, after which the test fails.
func runConcurrentFetchingTest(t *testing.T, rp ReceiptsProvider, numFetchers int, receipts types.Receipts, block *RPCBlock) {
	t.Helper()
	require := require.New(t)
	txHashes := receiptTxHashes(receipts)

	// start n fetchers
	type fetchResult struct {
		rs  types.Receipts
		err error
	}
	// Buffered to numFetchers so no fetcher ever blocks on send, even if the
	// assertion loop below stops the test early.
	fetchResults := make(chan fetchResult, numFetchers)
	barrier := make(chan struct{})
	bInfo, _, _ := block.Info(true, true)
	ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
	defer done()
	for i := 0; i < numFetchers; i++ {
		go func() {
			// Wait for the barrier so that all fetchers start together and
			// their FetchReceipts calls overlap as much as possible.
			<-barrier
			recs, err := rp.FetchReceipts(ctx, bInfo, txHashes)
			fetchResults <- fetchResult{rs: recs, err: err}
		}()
	}
	close(barrier) // Go!

	// assert results
	for i := 0; i < numFetchers; i++ {
		select {
		case f := <-fetchResults:
			require.NoError(f.err)
			require.Len(f.rs, len(receipts))
			for j, r := range receipts {
				requireEqualReceipt(t, r, f.rs[j])
			}
		case <-ctx.Done():
			t.Fatal("Test timeout")
		}
	}
}

// receiptTxHashes returns the TxHash of each receipt, in the same order as
// receipts. The result is always a non-nil slice, empty for empty input.
func receiptTxHashes(receipts types.Receipts) []common.Hash {
	hashes := make([]common.Hash, len(receipts))
	for idx := range receipts {
		hashes[idx] = receipts[idx].TxHash
	}
	return hashes
}