Commit f082bd04 authored by Adrian Sutton's avatar Adrian Sutton

challenger: Move DisputeGameFactory to use new approach to contract reads.

parent e5f8d680
......@@ -39,7 +39,7 @@ func NewFaultDisputeGameContract(addr common.Address, caller *batching.MultiCall
}
func (f *FaultDisputeGameContract) GetGameDuration(ctx context.Context) (uint64, error) {
result, err := f.multiCaller.SingleCallLatest(ctx, f.contract.Call(methodGameDuration))
result, err := f.multiCaller.SingleCall(ctx, batching.BlockLatest, f.contract.Call(methodGameDuration))
if err != nil {
return 0, fmt.Errorf("failed to fetch game duration: %w", err)
}
......@@ -47,7 +47,7 @@ func (f *FaultDisputeGameContract) GetGameDuration(ctx context.Context) (uint64,
}
func (f *FaultDisputeGameContract) GetMaxGameDepth(ctx context.Context) (uint64, error) {
result, err := f.multiCaller.SingleCallLatest(ctx, f.contract.Call(methodMaxGameDepth))
result, err := f.multiCaller.SingleCall(ctx, batching.BlockLatest, f.contract.Call(methodMaxGameDepth))
if err != nil {
return 0, fmt.Errorf("failed to fetch max game depth: %w", err)
}
......@@ -55,7 +55,7 @@ func (f *FaultDisputeGameContract) GetMaxGameDepth(ctx context.Context) (uint64,
}
func (f *FaultDisputeGameContract) GetAbsolutePrestateHash(ctx context.Context) (common.Hash, error) {
result, err := f.multiCaller.SingleCallLatest(ctx, f.contract.Call(methodAbsolutePrestate))
result, err := f.multiCaller.SingleCall(ctx, batching.BlockLatest, f.contract.Call(methodAbsolutePrestate))
if err != nil {
return common.Hash{}, fmt.Errorf("failed to fetch absolute prestate hash: %w", err)
}
......@@ -63,7 +63,7 @@ func (f *FaultDisputeGameContract) GetAbsolutePrestateHash(ctx context.Context)
}
func (f *FaultDisputeGameContract) GetStatus(ctx context.Context) (gameTypes.GameStatus, error) {
result, err := f.multiCaller.SingleCallLatest(ctx, f.contract.Call(methodStatus))
result, err := f.multiCaller.SingleCall(ctx, batching.BlockLatest, f.contract.Call(methodStatus))
if err != nil {
return 0, fmt.Errorf("failed to fetch status: %w", err)
}
......@@ -71,7 +71,7 @@ func (f *FaultDisputeGameContract) GetStatus(ctx context.Context) (gameTypes.Gam
}
func (f *FaultDisputeGameContract) GetClaimCount(ctx context.Context) (uint64, error) {
result, err := f.multiCaller.SingleCallLatest(ctx, f.contract.Call(methodClaimCount))
result, err := f.multiCaller.SingleCall(ctx, batching.BlockLatest, f.contract.Call(methodClaimCount))
if err != nil {
return 0, fmt.Errorf("failed to fetch claim count: %w", err)
}
......@@ -79,7 +79,7 @@ func (f *FaultDisputeGameContract) GetClaimCount(ctx context.Context) (uint64, e
}
func (f *FaultDisputeGameContract) GetClaim(ctx context.Context, idx uint64) (types.Claim, error) {
result, err := f.multiCaller.SingleCallLatest(ctx, f.contract.Call(methodClaim, new(big.Int).SetUint64(idx)))
result, err := f.multiCaller.SingleCall(ctx, batching.BlockLatest, f.contract.Call(methodClaim, new(big.Int).SetUint64(idx)))
if err != nil {
return types.Claim{}, fmt.Errorf("failed to fetch claim %v: %w", idx, err)
}
......@@ -97,7 +97,7 @@ func (f *FaultDisputeGameContract) GetAllClaims(ctx context.Context) ([]types.Cl
calls[i] = f.contract.Call(methodClaim, new(big.Int).SetUint64(i))
}
results, err := f.multiCaller.CallLatest(ctx, calls...)
results, err := f.multiCaller.Call(ctx, batching.BlockLatest, calls...)
if err != nil {
return nil, fmt.Errorf("failed to fetch claim data: %w", err)
}
......
......@@ -65,7 +65,7 @@ func TestSimpleGetters(t *testing.T) {
test := test
t.Run(test.method, func(t *testing.T) {
stubRpc, game := setup(t)
stubRpc.SetResponse(test.method, nil, []interface{}{test.result})
stubRpc.SetResponse(test.method, batching.BlockLatest, nil, []interface{}{test.result})
status, err := test.call(game)
require.NoError(t, err)
expected := test.expected
......@@ -85,7 +85,7 @@ func TestGetClaim(t *testing.T) {
value := common.Hash{0xab}
position := big.NewInt(2)
clock := big.NewInt(1234)
stubRpc.SetResponse(methodClaim, []interface{}{idx}, []interface{}{parentIndex, countered, value, position, clock})
stubRpc.SetResponse(methodClaim, batching.BlockLatest, []interface{}{idx}, []interface{}{parentIndex, countered, value, position, clock})
status, err := game.GetClaim(context.Background(), idx.Uint64())
require.NoError(t, err)
require.Equal(t, faultTypes.Claim{
......@@ -133,7 +133,7 @@ func TestGetAllClaims(t *testing.T) {
ParentContractIndex: 1,
}
expectedClaims := []faultTypes.Claim{claim0, claim1, claim2}
stubRpc.SetResponse(methodClaimCount, nil, []interface{}{big.NewInt(int64(len(expectedClaims)))})
stubRpc.SetResponse(methodClaimCount, batching.BlockLatest, nil, []interface{}{big.NewInt(int64(len(expectedClaims)))})
for _, claim := range expectedClaims {
expectGetClaim(stubRpc, claim)
}
......@@ -145,6 +145,7 @@ func TestGetAllClaims(t *testing.T) {
func expectGetClaim(stubRpc *batchingTest.AbiBasedRpc, claim faultTypes.Claim) {
stubRpc.SetResponse(
methodClaim,
batching.BlockLatest,
[]interface{}{big.NewInt(int64(claim.ContractIndex))},
[]interface{}{
uint32(claim.ParentContractIndex),
......
package contracts
import (
"context"
"fmt"
"math/big"
"github.com/ethereum-optimism/optimism/op-bindings/bindings"
"github.com/ethereum-optimism/optimism/op-challenger/game/types"
"github.com/ethereum-optimism/optimism/op-service/sources/batching"
"github.com/ethereum/go-ethereum/common"
)
const (
methodGameCount = "gameCount"
methodGameAtIndex = "gameAtIndex"
)
type DisputeGameFactoryContract struct {
multiCaller *batching.MultiCaller
contract *batching.BoundContract
}
func NewDisputeGameFactoryContract(addr common.Address, caller *batching.MultiCaller) (*DisputeGameFactoryContract, error) {
factoryAbi, err := bindings.DisputeGameFactoryMetaData.GetAbi()
if err != nil {
return nil, fmt.Errorf("failed to load dispute game factory ABI: %w", err)
}
return &DisputeGameFactoryContract{
multiCaller: caller,
contract: batching.NewBoundContract(factoryAbi, addr),
}, nil
}
func (f *DisputeGameFactoryContract) GetGameCount(ctx context.Context, blockNum uint64) (uint64, error) {
result, err := f.multiCaller.SingleCall(ctx, batching.BlockByNumber(blockNum), f.contract.Call(methodGameCount))
if err != nil {
return 0, fmt.Errorf("failed to load game count: %w", err)
}
return result.GetBigInt(0).Uint64(), nil
}
func (f *DisputeGameFactoryContract) GetGame(ctx context.Context, idx uint64, blockNum uint64) (types.GameMetadata, error) {
result, err := f.multiCaller.SingleCall(ctx, batching.BlockByNumber(blockNum), f.contract.Call(methodGameAtIndex, new(big.Int).SetUint64(idx)))
if err != nil {
return types.GameMetadata{}, fmt.Errorf("failed to load game %v: %w", idx, err)
}
return f.decodeGame(result), nil
}
func (f *DisputeGameFactoryContract) decodeGame(result *batching.CallResult) types.GameMetadata {
gameType := result.GetUint8(0)
timestamp := result.GetUint64(1)
proxy := result.GetAddress(2)
return types.GameMetadata{
GameType: gameType,
Timestamp: timestamp,
Proxy: proxy,
}
}
package contracts
import (
"context"
"math/big"
"testing"
"github.com/ethereum-optimism/optimism/op-bindings/bindings"
"github.com/ethereum-optimism/optimism/op-challenger/game/types"
"github.com/ethereum-optimism/optimism/op-service/sources/batching"
batchingTest "github.com/ethereum-optimism/optimism/op-service/sources/batching/test"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require"
)
func TestDisputeGameFactorySimpleGetters(t *testing.T) {
blockNum := uint64(23)
tests := []struct {
method string
args []interface{}
result interface{}
expected interface{} // Defaults to expecting the same as result
call func(game *DisputeGameFactoryContract) (any, error)
}{
{
method: methodGameCount,
result: big.NewInt(9876),
expected: uint64(9876),
call: func(game *DisputeGameFactoryContract) (any, error) {
return game.GetGameCount(context.Background(), blockNum)
},
},
}
for _, test := range tests {
test := test
t.Run(test.method, func(t *testing.T) {
stubRpc, factory := setupDisputeGameFactoryTest(t)
stubRpc.SetResponse(test.method, batching.BlockByNumber(blockNum), nil, []interface{}{test.result})
status, err := test.call(factory)
require.NoError(t, err)
expected := test.expected
if expected == nil {
expected = test.result
}
require.Equal(t, expected, status)
})
}
}
func TestLoadGame(t *testing.T) {
blockNum := uint64(23)
stubRpc, factory := setupDisputeGameFactoryTest(t)
game0 := types.GameMetadata{
GameType: 0,
Timestamp: 1234,
Proxy: common.Address{0xaa},
}
game1 := types.GameMetadata{
GameType: 1,
Timestamp: 5678,
Proxy: common.Address{0xbb},
}
game2 := types.GameMetadata{
GameType: 99,
Timestamp: 9988,
Proxy: common.Address{0xcc},
}
expectedGames := []types.GameMetadata{game0, game1, game2}
for idx, expected := range expectedGames {
expectGetGame(stubRpc, idx, blockNum, expected)
actual, err := factory.GetGame(context.Background(), uint64(idx), blockNum)
require.NoError(t, err)
require.Equal(t, expected, actual)
}
}
func expectGetGame(stubRpc *batchingTest.AbiBasedRpc, idx int, blockNum uint64, game types.GameMetadata) {
stubRpc.SetResponse(
methodGameAtIndex,
batching.BlockByNumber(blockNum),
[]interface{}{big.NewInt(int64(idx))},
[]interface{}{
game.GameType,
game.Timestamp,
game.Proxy,
})
}
func setupDisputeGameFactoryTest(t *testing.T) (*batchingTest.AbiBasedRpc, *DisputeGameFactoryContract) {
fdgAbi, err := bindings.DisputeGameFactoryMetaData.GetAbi()
require.NoError(t, err)
address := common.HexToAddress("0x24112842371dFC380576ebb09Ae16Cb6B6caD7CB")
stubRpc := batchingTest.NewAbiBasedRpc(t, fdgAbi, address)
caller := batching.NewMultiCaller(stubRpc, 100)
factory, err := NewDisputeGameFactoryContract(address, caller)
require.NoError(t, err)
return stubRpc, factory
}
......@@ -47,7 +47,7 @@ func NewGamePlayer(
creator resourceCreator,
) (*GamePlayer, error) {
logger = logger.New("game", addr)
loader, err := contracts.NewFaultDisputeGameContract(addr, batching.NewMultiCaller(client.Client(), 100))
loader, err := contracts.NewFaultDisputeGameContract(addr, batching.NewMultiCaller(client.Client(), batching.DefaultBatchSize))
if err != nil {
return nil, fmt.Errorf("failed to create fault dispute game contract wrapper: %w", err)
}
......
......@@ -4,11 +4,8 @@ import (
"context"
"errors"
"fmt"
"math/big"
"github.com/ethereum-optimism/optimism/op-challenger/game/types"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
)
var (
......@@ -18,12 +15,8 @@ var (
// MinimalDisputeGameFactoryCaller is a minimal interface around [bindings.DisputeGameFactoryCaller].
// This needs to be updated if the [bindings.DisputeGameFactoryCaller] interface changes.
type MinimalDisputeGameFactoryCaller interface {
GameCount(opts *bind.CallOpts) (*big.Int, error)
GameAtIndex(opts *bind.CallOpts, _index *big.Int) (struct {
GameType uint8
Timestamp uint64
Proxy common.Address
}, error)
GetGameCount(ctx context.Context, blockNum uint64) (uint64, error)
GetGame(ctx context.Context, idx uint64, blockNum uint64) (types.GameMetadata, error)
}
type GameLoader struct {
......@@ -38,27 +31,17 @@ func NewGameLoader(caller MinimalDisputeGameFactoryCaller) *GameLoader {
}
// FetchAllGamesAtBlock fetches all dispute games from the factory at a given block number.
func (l *GameLoader) FetchAllGamesAtBlock(ctx context.Context, earliestTimestamp uint64, blockNumber *big.Int) ([]types.GameMetadata, error) {
if blockNumber == nil {
return nil, ErrMissingBlockNumber
}
callOpts := &bind.CallOpts{
Context: ctx,
BlockNumber: blockNumber,
}
gameCount, err := l.caller.GameCount(callOpts)
func (l *GameLoader) FetchAllGamesAtBlock(ctx context.Context, earliestTimestamp uint64, blockNumber uint64) ([]types.GameMetadata, error) {
gameCount, err := l.caller.GetGameCount(ctx, blockNumber)
if err != nil {
return nil, fmt.Errorf("failed to fetch game count: %w", err)
}
games := make([]types.GameMetadata, 0)
if gameCount.Uint64() == 0 {
return games, nil
}
for i := gameCount.Uint64(); i > 0; i-- {
game, err := l.caller.GameAtIndex(callOpts, big.NewInt(int64(i-1)))
games := make([]types.GameMetadata, 0, gameCount)
for i := gameCount; i > 0; i-- {
game, err := l.caller.GetGame(ctx, i-1, blockNumber)
if err != nil {
return nil, fmt.Errorf("failed to fetch game at index %d: %w", i, err)
return nil, fmt.Errorf("failed to fetch game at index %d: %w", i-1, err)
}
if game.Timestamp < earliestTimestamp {
break
......
......@@ -7,7 +7,6 @@ import (
"testing"
"github.com/ethereum-optimism/optimism/op-challenger/game/types"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require"
)
......@@ -25,44 +24,39 @@ func TestGameLoader_FetchAllGames(t *testing.T) {
name string
caller *mockMinimalDisputeGameFactoryCaller
earliest uint64
blockNumber *big.Int
blockNumber uint64
expectedErr error
expectedLen int
}{
{
name: "success",
caller: newMockMinimalDisputeGameFactoryCaller(10, false, false),
blockNumber: big.NewInt(1),
blockNumber: 1,
expectedLen: 10,
},
{
name: "expired game ignored",
caller: newMockMinimalDisputeGameFactoryCaller(10, false, false),
earliest: 500,
blockNumber: big.NewInt(1),
blockNumber: 1,
expectedLen: 5,
},
{
name: "game count error",
caller: newMockMinimalDisputeGameFactoryCaller(10, true, false),
blockNumber: big.NewInt(1),
blockNumber: 1,
expectedErr: gameCountErr,
},
{
name: "game index error",
caller: newMockMinimalDisputeGameFactoryCaller(10, false, true),
blockNumber: big.NewInt(1),
blockNumber: 1,
expectedErr: gameIndexErr,
},
{
name: "no games",
caller: newMockMinimalDisputeGameFactoryCaller(0, false, false),
blockNumber: big.NewInt(1),
},
{
name: "missing block number",
caller: newMockMinimalDisputeGameFactoryCaller(0, false, false),
expectedErr: ErrMissingBlockNumber,
blockNumber: 1,
},
}
......@@ -144,20 +138,15 @@ func newMockMinimalDisputeGameFactoryCaller(count uint64, gameCountErr bool, ind
}
}
func (m *mockMinimalDisputeGameFactoryCaller) GameCount(opts *bind.CallOpts) (*big.Int, error) {
func (m *mockMinimalDisputeGameFactoryCaller) GetGameCount(_ context.Context, blockNum uint64) (uint64, error) {
if m.gameCountErr {
return nil, gameCountErr
return 0, gameCountErr
}
return big.NewInt(int64(m.gameCount)), nil
return m.gameCount, nil
}
func (m *mockMinimalDisputeGameFactoryCaller) GameAtIndex(opts *bind.CallOpts, _index *big.Int) (struct {
GameType uint8
Timestamp uint64
Proxy common.Address
}, error) {
index := _index.Uint64()
func (m *mockMinimalDisputeGameFactoryCaller) GetGame(_ context.Context, index uint64, blockNum uint64) (types.GameMetadata, error) {
if m.indexErrors[index] {
return struct {
GameType uint8
......@@ -166,13 +155,5 @@ func (m *mockMinimalDisputeGameFactoryCaller) GameAtIndex(opts *bind.CallOpts, _
}{}, gameIndexErr
}
return struct {
GameType uint8
Timestamp uint64
Proxy common.Address
}{
GameType: m.games[index].GameType,
Timestamp: m.games[index].Timestamp,
Proxy: m.games[index].Proxy,
}, nil
return m.games[index], nil
}
......@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"math/big"
"time"
"github.com/ethereum-optimism/optimism/op-challenger/game/scheduler"
......@@ -23,7 +22,7 @@ type blockNumberFetcher func(ctx context.Context) (uint64, error)
// gameSource loads information about the games available to play
type gameSource interface {
FetchAllGamesAtBlock(ctx context.Context, earliest uint64, blockNumber *big.Int) ([]types.GameMetadata, error)
FetchAllGamesAtBlock(ctx context.Context, earliest uint64, blockNumber uint64) ([]types.GameMetadata, error)
}
type gameScheduler interface {
......@@ -101,7 +100,7 @@ func (m *gameMonitor) minGameTimestamp() uint64 {
}
func (m *gameMonitor) progressGames(ctx context.Context, blockNum uint64) error {
games, err := m.source.FetchAllGamesAtBlock(ctx, m.minGameTimestamp(), new(big.Int).SetUint64(blockNum))
games, err := m.source.FetchAllGamesAtBlock(ctx, m.minGameTimestamp(), blockNum)
if err != nil {
return fmt.Errorf("failed to load games: %w", err)
}
......
......@@ -230,7 +230,7 @@ type stubGameSource struct {
func (s *stubGameSource) FetchAllGamesAtBlock(
ctx context.Context,
earliest uint64,
blockNumber *big.Int,
blockNumber uint64,
) ([]types.GameMetadata, error) {
return s.games, nil
}
......
......@@ -5,9 +5,9 @@ import (
"errors"
"fmt"
"github.com/ethereum-optimism/optimism/op-bindings/bindings"
"github.com/ethereum-optimism/optimism/op-challenger/config"
"github.com/ethereum-optimism/optimism/op-challenger/game/fault"
"github.com/ethereum-optimism/optimism/op-challenger/game/fault/contracts"
"github.com/ethereum-optimism/optimism/op-challenger/game/loader"
"github.com/ethereum-optimism/optimism/op-challenger/game/registry"
"github.com/ethereum-optimism/optimism/op-challenger/game/scheduler"
......@@ -18,6 +18,7 @@ import (
"github.com/ethereum-optimism/optimism/op-service/dial"
"github.com/ethereum-optimism/optimism/op-service/httputil"
oppprof "github.com/ethereum-optimism/optimism/op-service/pprof"
"github.com/ethereum-optimism/optimism/op-service/sources/batching"
"github.com/ethereum-optimism/optimism/op-service/txmgr"
"github.com/ethereum/go-ethereum/log"
)
......@@ -88,7 +89,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se
m.StartBalanceMetrics(ctx, logger, l1Client, txMgr.From())
}
factoryContract, err := bindings.NewDisputeGameFactory(cfg.GameFactoryAddress, l1Client)
factoryContract, err := contracts.NewDisputeGameFactoryContract(cfg.GameFactoryAddress, batching.NewMultiCaller(l1Client.Client(), batching.DefaultBatchSize))
if err != nil {
return nil, errors.Join(fmt.Errorf("failed to bind the fault dispute game factory contract: %w", err), s.Stop(ctx))
}
......
......@@ -110,6 +110,10 @@ func (c *CallResult) GetHash(i int) common.Hash {
return *abi.ConvertType(c.out[i], new([32]byte)).(*[32]byte)
}
func (c *CallResult) GetAddress(i int) common.Address {
return *abi.ConvertType(c.out[i], new([20]byte)).(*[20]byte)
}
func (c *CallResult) GetBigInt(i int) *big.Int {
return *abi.ConvertType(c.out[i], new(*big.Int)).(**big.Int)
}
......@@ -118,6 +118,13 @@ func TestCallResult_GetValues(t *testing.T) {
},
expected: true,
},
{
name: "GetAddress",
getter: func(result *CallResult, i int) interface{} {
return result.GetAddress(i)
},
expected: ([20]byte)(common.Address{0xaa, 0xbb, 0xcc}),
},
{
name: "GetHash",
getter: func(result *CallResult, i int) interface{} {
......
......@@ -5,10 +5,13 @@ import (
"fmt"
"io"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/rpc"
)
var DefaultBatchSize = 100
type EthRpc interface {
CallContext(ctx context.Context, out interface{}, method string, args ...interface{}) error
BatchCallContext(ctx context.Context, b []rpc.BatchElem) error
......@@ -26,15 +29,15 @@ func NewMultiCaller(rpc EthRpc, batchSize int) *MultiCaller {
}
}
func (m *MultiCaller) SingleCallLatest(ctx context.Context, call *ContractCall) (*CallResult, error) {
results, err := m.CallLatest(ctx, call)
func (m *MultiCaller) SingleCall(ctx context.Context, block Block, call *ContractCall) (*CallResult, error) {
results, err := m.Call(ctx, block, call)
if err != nil {
return nil, err
}
return results[0], nil
}
func (m *MultiCaller) CallLatest(ctx context.Context, calls ...*ContractCall) ([]*CallResult, error) {
func (m *MultiCaller) Call(ctx context.Context, block Block, calls ...*ContractCall) ([]*CallResult, error) {
keys := make([]interface{}, len(calls))
for i := 0; i < len(calls); i++ {
args, err := calls[i].ToCallArgs()
......@@ -49,7 +52,7 @@ func (m *MultiCaller) CallLatest(ctx context.Context, calls ...*ContractCall) ([
out := new(hexutil.Bytes)
return out, rpc.BatchElem{
Method: "eth_call",
Args: []interface{}{args, "latest"},
Args: []interface{}{args, block.value},
Result: &out,
}
},
......@@ -79,3 +82,30 @@ func (m *MultiCaller) CallLatest(ctx context.Context, calls ...*ContractCall) ([
}
return callResults, nil
}
// Block represents the block ref value in RPC calls.
// It can be either a label (e.g. latest), a block number or block hash.
type Block struct {
value any
}
func (b Block) ArgValue() any {
return b.value
}
var (
BlockPending = Block{"pending"}
BlockLatest = Block{"latest"}
BlockSafe = Block{"safe"}
BlockFinalized = Block{"finalized"}
)
// BlockByNumber references a canonical block by number.
func BlockByNumber(blockNum uint64) Block {
return Block{rpc.BlockNumber(blockNum)}
}
// BlockByHash references a block by hash. Canonical or non-canonical blocks may be referenced.
func BlockByHash(hash common.Hash) Block {
return Block{rpc.BlockNumberOrHashWithHash(hash, false)}
}
......@@ -7,22 +7,25 @@ import (
"fmt"
"testing"
"github.com/ethereum-optimism/optimism/op-service/sources/batching"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/rpc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
)
type expectedCall struct {
block batching.Block
args []interface{}
packedArgs []byte
outputs []interface{}
}
func (e *expectedCall) String() string {
return fmt.Sprintf("{args: %v, outputs: %v}", e.args, e.outputs)
return fmt.Sprintf("{block: %v, args: %v, outputs: %v}", e.block, e.args, e.outputs)
}
type AbiBasedRpc struct {
......@@ -42,7 +45,7 @@ func NewAbiBasedRpc(t *testing.T, contractAbi *abi.ABI, addr common.Address) *Ab
}
}
func (l *AbiBasedRpc) SetResponse(method string, expected []interface{}, output []interface{}) {
func (l *AbiBasedRpc) SetResponse(method string, block batching.Block, expected []interface{}, output []interface{}) {
if expected == nil {
expected = []interface{}{}
}
......@@ -54,6 +57,7 @@ func (l *AbiBasedRpc) SetResponse(method string, expected []interface{}, output
packedArgs, err := abiMethod.Inputs.Pack(expected...)
require.NoErrorf(l.t, err, "Invalid expected arguments for method %v: %v", method, expected)
l.expectedCalls[method] = append(l.expectedCalls[method], &expectedCall{
block: block,
args: expected,
packedArgs: packedArgs,
outputs: output,
......@@ -72,7 +76,7 @@ func (l *AbiBasedRpc) BatchCallContext(ctx context.Context, b []rpc.BatchElem) e
func (l *AbiBasedRpc) CallContext(_ context.Context, out interface{}, method string, args ...interface{}) error {
require.Equal(l.t, "eth_call", method)
require.Len(l.t, args, 2)
require.Equal(l.t, "latest", args[1])
actualBlockRef := args[1]
callOpts, ok := args[0].(map[string]any)
require.True(l.t, ok)
require.Equal(l.t, &l.addr, callOpts["to"])
......@@ -90,12 +94,12 @@ func (l *AbiBasedRpc) CallContext(_ context.Context, out interface{}, method str
require.Truef(l.t, ok, "Unexpected call to %v", abiMethod.Name)
var call *expectedCall
for _, candidate := range expectedCalls {
if slices.Equal(candidate.packedArgs, argData) {
if slices.Equal(candidate.packedArgs, argData) && assert.ObjectsAreEqualValues(candidate.block.ArgValue(), actualBlockRef) {
call = candidate
break
}
}
require.NotNilf(l.t, call, "No expected calls to %v with arguments: %v\nExpected calls: %v", abiMethod.Name, args, expectedCalls)
require.NotNilf(l.t, call, "No expected calls to %v at block %v with arguments: %v\nExpected calls: %v", abiMethod.Name, actualBlockRef, args, expectedCalls)
output, err := abiMethod.Outputs.Pack(call.outputs...)
require.NoErrorf(l.t, err, "Invalid outputs for method %v: %v", abiMethod.Name, call.outputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment