Commit 219ebe00 authored by Matthew Slipper's avatar Matthew Slipper Committed by GitHub

Add broadcast API to Go forge scripts (#11826)

* Add broadcast API to Go forge scripts

Adds a hooks-based API to collect transactions broadcasted via `vm.broadcast(*)` in the Go-based Forge scripts. Users pass an `OnBroadcast` hook to the host, which will be called with a `Broadcast` struct with the following fields whenever a transaction needs to be emitted:

```go
type Broadcast struct {
	From     common.Address
	To       common.Address
	Calldata []byte
	Value    *big.Int
}
```

This API lets us layer on custom transaction management in the future which will be helpful for `op-deployer`.

As part of this PR, I changed the internal `callStack` data structure to contain pointers to `CallFrame`s rather than passing by value. I discovered a bug where the pranked sender was not being cleared in subsequent calls due to an ineffectual assignment error. I took a look at the implementation and there are many places where assignments to call frames within the stack happen after converting the value to a reference, so converting the stack to store pointers in the first place both simplified the code and eliminated a class of errors in the future. I updated the public API methods to return copies of the internal structs to prevent accidental mutation.

* Code review updates

* moar review updates

* fix bug with staticcall
parent e2356c35
package script
import (
"bytes"
"errors"
"math/big"
......@@ -69,7 +70,7 @@ func (h *Host) Prank(msgSender *common.Address, txOrigin *common.Address, repeat
h.log.Warn("no call stack")
return nil // cannot prank while not in a call.
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank != nil {
if cf.Prank.Broadcast && !broadcast {
return errors.New("you have an active broadcast; broadcasting and pranks are not compatible")
......@@ -98,7 +99,7 @@ func (h *Host) StopPrank(broadcast bool) error {
if len(h.callStack) == 0 {
return nil
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank == nil {
if broadcast {
return errors.New("no broadcast in progress to stop")
......@@ -127,7 +128,7 @@ func (h *Host) CallerMode() CallerMode {
if len(h.callStack) == 0 {
return CallerModeNone
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if cf.Prank != nil {
if cf.Prank.Broadcast {
if cf.Prank.Repeat {
......@@ -157,3 +158,36 @@ const (
CallerModePrank
CallerModeRecurrentPrank
)
// Broadcast captures a transaction that was selected to be broadcasted
// via vm.broadcast(). Actually submitting the transaction is left up
// to other tools.
type Broadcast struct {
From common.Address
To common.Address
Calldata []byte
Value *big.Int
}
// NewBroadcastFromCtx creates a Broadcast from a VM context. This method
// is preferred to manually creating the struct since it correctly handles
// data that must be copied prior to being returned to prevent accidental
// mutation.
func NewBroadcastFromCtx(ctx *vm.ScopeContext) Broadcast {
// Consistently return nil for zero values in order
// for tests to have a deterministic value to compare
// against.
value := ctx.CallValue().ToBig()
if value.Cmp(common.Big0) == 0 {
value = nil
}
// Need to clone CallInput() below since it's used within
// the VM itself elsewhere.
return Broadcast{
From: ctx.Caller(),
To: ctx.Address(),
Calldata: bytes.Clone(ctx.CallInput()),
Value: value,
}
}
......@@ -3,6 +3,7 @@ package script
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math/big"
......@@ -69,7 +70,7 @@ type Host struct {
precompiles map[common.Address]vm.PrecompiledContract
callStack []CallFrame
callStack []*CallFrame
// serializerStates are in-progress JSON payloads by name,
// for the serializeX family of cheat codes, see:
......@@ -86,12 +87,34 @@ type Host struct {
srcMaps map[common.Address]*srcmap.SourceMap
onLabel []func(name string, addr common.Address)
hooks *Hooks
}
type HostOption func(h *Host)
type BroadcastHook func(broadcast Broadcast)
type Hooks struct {
OnBroadcast BroadcastHook
}
func WithBroadcastHook(hook BroadcastHook) HostOption {
return func(h *Host) {
h.hooks.OnBroadcast = hook
}
}
// NewHost creates a Host that can load contracts from the given Artifacts FS,
// and with an EVM initialized to the given executionContext.
// Optionally src-map loading may be enabled, by providing a non-nil srcFS to read sources from.
func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMapFS, executionContext Context) *Host {
func NewHost(
logger log.Logger,
fs *foundry.ArtifactsFS,
srcFS *foundry.SourceMapFS,
executionContext Context,
options ...HostOption,
) *Host {
h := &Host{
log: logger,
af: fs,
......@@ -101,6 +124,13 @@ func NewHost(logger log.Logger, fs *foundry.ArtifactsFS, srcFS *foundry.SourceMa
precompiles: make(map[common.Address]vm.PrecompiledContract),
srcFS: srcFS,
srcMaps: make(map[common.Address]*srcmap.SourceMap),
hooks: &Hooks{
OnBroadcast: func(broadcast Broadcast) {},
},
}
for _, opt := range options {
opt(h)
}
// Init a default chain config, with all the mainnet L1 forks activated
......@@ -361,6 +391,19 @@ func (h *Host) unwindCallstack(depth int) {
if len(h.callStack) > 1 {
parentCallFrame := h.callStack[len(h.callStack)-2]
if parentCallFrame.Prank != nil {
if parentCallFrame.Prank.Broadcast && parentCallFrame.LastOp != vm.STATICCALL {
currentFrame := h.callStack[len(h.callStack)-1]
bcast := NewBroadcastFromCtx(currentFrame.Ctx)
h.hooks.OnBroadcast(bcast)
h.log.Debug(
"called broadcast hook",
"from", bcast.From,
"to", bcast.To,
"calldata", hex.EncodeToString(bcast.Calldata),
"value", bcast.Value,
)
}
// While going back to the parent, restore the tx.origin.
// It will later be re-applied on sub-calls if the prank persists (if Repeat == true).
if parentCallFrame.Prank.Origin != nil {
......@@ -372,7 +415,7 @@ func (h *Host) unwindCallstack(depth int) {
}
}
// Now pop the call-frame
h.callStack[len(h.callStack)-1] = CallFrame{} // don't hold on to the underlying call-frame resources
h.callStack[len(h.callStack)-1] = nil // don't hold on to the underlying call-frame resources
h.callStack = h.callStack[:len(h.callStack)-1]
}
}
......@@ -384,7 +427,7 @@ func (h *Host) onOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpCo
// Check if we are entering a new depth, add it to the call-stack if so.
// We do this here, instead of onEnter, to capture an initialized scope.
if len(h.callStack) == 0 || h.callStack[len(h.callStack)-1].Depth < depth {
h.callStack = append(h.callStack, CallFrame{
h.callStack = append(h.callStack, &CallFrame{
Depth: depth,
LastOp: vm.OpCode(op),
LastPC: pc,
......@@ -395,7 +438,7 @@ func (h *Host) onOpcode(pc uint64, op byte, gas, cost uint64, scope tracing.OpCo
if len(h.callStack) == 0 || h.callStack[len(h.callStack)-1].Ctx != scopeCtx {
panic("scope context changed without call-frame pop/push")
}
cf := &h.callStack[len(h.callStack)-1]
cf := h.callStack[len(h.callStack)-1]
if vm.OpCode(op) == vm.JUMPDEST { // remember the last PC before successful jump
cf.LastJumps = append(cf.LastJumps, cf.LastPC)
if len(cf.LastJumps) > jumpHistory {
......@@ -429,7 +472,7 @@ func (h *Host) CurrentCall() CallFrame {
if len(h.callStack) == 0 {
return CallFrame{}
}
return h.callStack[len(h.callStack)-1]
return *h.callStack[len(h.callStack)-1]
}
// MsgSender returns the msg.sender of the current active EVM call-frame,
......
package script
import (
"fmt"
"strings"
"testing"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
"github.com/stretchr/testify/require"
......@@ -35,3 +40,57 @@ func TestScript(t *testing.T) {
// and a second time, to see if we can revisit the host state.
require.NoError(t, h.cheatcodes.Precompile.DumpState("noop"))
}
func TestScriptBroadcast(t *testing.T) {
logger := testlog.Logger(t, log.LevelDebug)
af := foundry.OpenArtifactsDir("./testdata/test-artifacts")
mustEncodeCalldata := func(method, input string) []byte {
packer, err := abi.JSON(strings.NewReader(fmt.Sprintf(`[{"type":"function","name":"%s","inputs":[{"type":"string","name":"input"}]}]`, method)))
require.NoError(t, err)
data, err := packer.Pack(method, input)
require.NoError(t, err)
return data
}
senderAddr := common.HexToAddress("0x5b73C5498c1E3b4dbA84de0F1833c4a029d90519")
expBroadcasts := []Broadcast{
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call1", "single_call1"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call1", "startstop_call1"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("call2", "startstop_call2"),
},
{
From: senderAddr,
To: senderAddr,
Calldata: mustEncodeCalldata("nested1", "nested"),
},
}
scriptContext := DefaultContext
var broadcasts []Broadcast
hook := func(broadcast Broadcast) {
broadcasts = append(broadcasts, broadcast)
}
h := NewHost(logger, af, nil, scriptContext, WithBroadcastHook(hook))
addr, err := h.LoadContract("ScriptExample.s.sol", "ScriptExample")
require.NoError(t, err)
require.NoError(t, h.EnableCheats())
input := bytes4("runBroadcast()")
returnData, _, err := h.Call(scriptContext.Sender, addr, input[:], DefaultFoundryGasLimit, uint256.NewInt(0))
require.NoError(t, err, "call failed: %x", string(returnData))
require.EqualValues(t, expBroadcasts, broadcasts)
}
......@@ -8,6 +8,9 @@ interface Vm {
function parseJsonKeys(string calldata json, string calldata key) external pure returns (string[] memory keys);
function startPrank(address msgSender) external;
function stopPrank() external;
function broadcast() external;
function startBroadcast() external;
function stopBroadcast() external;
}
// console is a minimal version of the console2 lib.
......@@ -64,6 +67,9 @@ contract ScriptExample {
address internal constant VM_ADDRESS = address(uint160(uint256(keccak256("hevm cheat code"))));
Vm internal constant vm = Vm(VM_ADDRESS);
// @notice counter variable to force non-pure calls.
uint256 public counter;
/// @notice example function, runs through basic cheat-codes and console logs.
function run() public {
bool x = vm.envOr("EXAMPLE_BOOL", false);
......@@ -90,9 +96,54 @@ contract ScriptExample {
console.log("done!");
}
/// @notice example function, to test vm.broadcast with.
function runBroadcast() public {
console.log("testing single");
vm.broadcast();
this.call1("single_call1");
this.call2("single_call2");
console.log("testing start/stop");
vm.startBroadcast();
this.call1("startstop_call1");
this.call2("startstop_call2");
this.callPure("startstop_pure");
vm.stopBroadcast();
this.call1("startstop_call3");
console.log("testing nested");
vm.startBroadcast();
this.nested1("nested");
vm.stopBroadcast();
}
/// @notice example external function, to force a CALL, and test vm.startPrank with.
function hello(string calldata _v) external view {
console.log(_v);
console.log("hello msg.sender", address(msg.sender));
}
function call1(string calldata _v) external {
counter++;
console.log(_v);
}
function call2(string calldata _v) external {
counter++;
console.log(_v);
}
function nested1(string calldata _v) external {
counter++;
this.nested2(_v);
}
function nested2(string calldata _v) external {
counter++;
console.log(_v);
}
function callPure(string calldata _v) external pure {
console.log(_v);
}
}
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment