package batching

import (
	"errors"
	"fmt"

	"github.com/ethereum/go-ethereum/accounts/abi"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core/types"
)

// Sentinel errors returned by BoundContract's decoding methods. Most call
// sites wrap them with extra context via fmt.Errorf("%w ..."), so compare
// with errors.Is rather than ==.
var (
	// ErrUnknownMethod is returned by DecodeCall when the call data is too
	// short to hold a method selector or the selector matches no ABI method.
	ErrUnknownMethod = errors.New("unknown method")

	// ErrInvalidCall is returned by DecodeCall when the method was found but
	// its arguments could not be unpacked from the call data.
	ErrInvalidCall = errors.New("invalid call")

	// ErrUnknownEvent is returned by DecodeEvent when a log has no topics,
	// its first topic matches no ABI event, or a decoded argument is missing.
	ErrUnknownEvent = errors.New("unknown event")

	// ErrInvalidEvent is returned by DecodeEvent when the event was found but
	// its indexed topics or non-indexed data could not be decoded.
	ErrInvalidEvent = errors.New("invalid event")
)

// BoundContract pairs a contract ABI with a contract address. Its methods use
// the ABI to build calls targeting that address and to decode call data and
// event logs.
type BoundContract struct {
	abi  *abi.ABI       // ABI used to encode calls and decode call data and logs
	addr common.Address // address of the contract the ABI is bound to
}

// NewBoundContract returns a BoundContract that uses contractABI to build
// calls and decode call data and logs for the contract at addr.
//
// contractABI is stored as-is (not copied); no nil check is performed.
func NewBoundContract(contractABI *abi.ABI, addr common.Address) *BoundContract {
	// The parameter is deliberately not named abi: that would shadow the
	// imported abi package for the whole function body.
	return &BoundContract{
		abi:  contractABI,
		addr: addr,
	}
}

// Addr reports the address of the contract this BoundContract is bound to.
func (b *BoundContract) Addr() common.Address {
	address := b.addr
	return address
}

// Call builds a ContractCall for the named method with the given arguments,
// bound to this contract's ABI and address (delegates to NewContractCall).
func (b *BoundContract) Call(method string, args ...interface{}) *ContractCall {
	contractABI, target := b.abi, b.addr
	return NewContractCall(contractABI, target, method, args...)
}

func (b *BoundContract) DecodeCall(data []byte) (string, *CallResult, error) {
	if len(data) < 4 {
		return "", nil, ErrUnknownMethod
	}
	method, err := b.abi.MethodById(data[:4])
	if err != nil {
		// ABI doesn't return a nicely typed error so treat any failure to find the method as unknown
46
		return "", nil, fmt.Errorf("%w: %w", ErrUnknownMethod, err)
47 48 49
	}
	args, err := method.Inputs.Unpack(data[4:])
	if err != nil {
50
		return "", nil, fmt.Errorf("%w: %w", ErrInvalidCall, err)
51 52 53
	}
	return method.Name, &CallResult{args}, nil
}

func (b *BoundContract) DecodeEvent(log *types.Log) (string, *CallResult, error) {
	if len(log.Topics) == 0 {
		return "", nil, ErrUnknownEvent
	}
	event, err := b.abi.EventByID(log.Topics[0])
	if err != nil {
61
		return "", nil, fmt.Errorf("%w: %w", ErrUnknownEvent, err)
62 63 64 65 66 67 68 69 70 71
	}

	argsMap := make(map[string]interface{})
	var indexed abi.Arguments
	for _, arg := range event.Inputs {
		if arg.Indexed {
			indexed = append(indexed, arg)
		}
	}
	if err := abi.ParseTopicsIntoMap(argsMap, indexed, log.Topics[1:]); err != nil {
72
		return "", nil, fmt.Errorf("%w indexed topics: %w", ErrInvalidEvent, err)
73 74 75 76 77
	}

	nonIndexed := event.Inputs.NonIndexed()
	if len(nonIndexed) > 0 {
		if err := nonIndexed.UnpackIntoMap(argsMap, log.Data); err != nil {
78
			return "", nil, fmt.Errorf("%w non-indexed topics: %w", ErrInvalidEvent, err)
79 80 81 82 83 84 85 86 87 88 89 90
		}
	}
	args := make([]interface{}, 0, len(event.Inputs))
	for _, input := range event.Inputs {
		val, ok := argsMap[input.Name]
		if !ok {
			return "", nil, fmt.Errorf("%w missing argument: %v", ErrUnknownEvent, input.Name)
		}
		args = append(args, val)
	}
	return event.Name, &CallResult{args}, nil
}