generic_stub.go 2.71 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
package test

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math/big"
	"testing"

11
	"github.com/ethereum-optimism/optimism/op-service/sources/batching/rpcblock"
12 13 14 15 16 17 18 19 20 21
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/common/hexutil"
	"github.com/ethereum/go-ethereum/rpc"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// ExpectedRpcCall describes one RPC invocation that a stub knows how to
// answer. Implementations are also fmt.Stringers so they can be printed in
// failure messages.
type ExpectedRpcCall interface {
	fmt.Stringer

	// Matches checks an incoming call against this expectation. It returns nil
	// when the call matches and an error describing the difference otherwise.
	Matches(rpcMethod string, args ...interface{}) error

	// Execute produces the response for a matched call, writing it into out.
	Execute(t *testing.T, out interface{}) error
}

// RpcStub is a test double that answers RPC calls from a list of registered
// expected calls.
type RpcStub struct {
	// t is handed to expected calls when they execute and is used to fail the
	// test when an incoming call matches nothing.
	t             *testing.T
	// expectedCalls holds the registered expectations; lookups walk it in
	// slice order and use the first match.
	expectedCalls []ExpectedRpcCall
}

// NewRpcStub returns an RpcStub bound to t with no expected calls registered.
func NewRpcStub(t *testing.T) *RpcStub {
	stub := new(RpcStub)
	stub.t = t
	return stub
}

// ClearResponses drops every expected call registered so far, so subsequent
// calls match nothing until new expectations are added.
func (r *RpcStub) ClearResponses() {
	r.expectedCalls = nil
}

// AddExpectedCall registers call at the end of the list of expectations.
// Earlier registrations are consulted first when looking for a match.
func (r *RpcStub) AddExpectedCall(call ExpectedRpcCall) {
	r.expectedCalls = append(r.expectedCalls, call)
}

// BatchCallContext answers every element of b by delegating to CallContext.
// Each element's outcome is stored in its Error field, and the returned error
// joins all per-element errors (nil when every element succeeded).
func (r *RpcStub) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
	errs := make([]error, 0, len(b))
	// Index into b rather than ranging by value: a range value is a copy, so
	// writing its Error field would never reach the caller's slice.
	for i := range b {
		elem := &b[i]
		elem.Error = r.CallContext(ctx, elem.Result, elem.Method, elem.Args...)
		errs = append(errs, elem.Error)
	}
	return errors.Join(errs...)
}

func (r *RpcStub) CallContext(_ context.Context, out interface{}, method string, args ...interface{}) error {
	call := r.findExpectedCall(method, args...)
53
	return call.Execute(r.t, out)
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
}

// findExpectedCall returns the first registered expected call that matches
// rpcMethod and args. If none matches, it fails the test with a message that
// lists each candidate and the reason it was rejected.
func (r *RpcStub) findExpectedCall(rpcMethod string, args ...interface{}) ExpectedRpcCall {
	var matchResults string
	for _, call := range r.expectedCalls {
		err := call.Matches(rpcMethod, args...)
		if err == nil {
			return call
		}
		// One line per rejected candidate so the failure output stays readable.
		matchResults += fmt.Sprintf("%v: %v\n", call, err)
	}
	// The collected text is data, not a format string: it may contain '%'
	// produced while printing arguments, so pass it through "%s".
	require.Failf(r.t, "No matching expected calls.", "%s", matchResults)
	// require.Failf is expected to abort the test; this return only satisfies
	// the compiler.
	return nil
}

// GenericExpectedCall is an expected call defined by an RPC method name, the
// arguments it must be invoked with, and a canned result to respond with.
type GenericExpectedCall struct {
	// method is the RPC method name an incoming call must use.
	method string
	// args are the arguments an incoming call must supply.
	args   []interface{}
	// result is the value copied into the caller's output on execution.
	result interface{}
}

75
func NewGetBalanceCall(addr common.Address, block rpcblock.Block, balance *big.Int) ExpectedRpcCall {
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
	return &GenericExpectedCall{
		method: "eth_getBalance",
		args:   []interface{}{addr, block.ArgValue()},
		result: (*hexutil.Big)(balance),
	}
}

// Matches compares an incoming call with this expectation. The method name is
// checked first, then the arguments; the first difference found is reported as
// an error. A nil return means the call matches.
func (c *GenericExpectedCall) Matches(rpcMethod string, args ...interface{}) error {
	switch {
	case c.method != rpcMethod:
		return fmt.Errorf("expected method %v but was %v", c.method, rpcMethod)
	case !assert.ObjectsAreEqualValues(c.args, args):
		return fmt.Errorf("expected args %v but was %v", c.args, args)
	default:
		return nil
	}
}

93
func (c *GenericExpectedCall) Execute(t *testing.T, out interface{}) error {
94 95 96 97 98
	// I admit I do not understand Go reflection.
	// So leverage json.Unmarshal to set the out value correctly.
	j, err := json.Marshal(c.result)
	require.NoError(t, err)
	require.NoError(t, json.Unmarshal(j, out))
99
	return nil
100 101 102 103 104
}

// String renders the expectation as method(args)->result for use in
// diagnostics.
func (c *GenericExpectedCall) String() string {
	const layout = "%v(%v)->%v"
	return fmt.Sprintf(layout, c.method, c.args, c.result)
}