Commit 03b526d3 authored by John Chase's avatar John Chase Committed by GitHub

MTCannon: Fix `AssertEVMReverts` to correctly construct data (#12200)

* 1. Added `WithRegisters` for initializing state's `Registers`.
2. Added `ThreadProofEncoder` helper function to generate `ThreadProof`.
3. Updated `AssertEVMReverts` to include `statePCs` for multiple memory requirements, added `ProofData` parameter for passing `ThreadProof`, and introduced `expectedReason` parameter for more precise testing.
4. Revised `TestEVMFault` and `TestEVM_UnsupportedSyscall` test functions to be compatible with `AssertEVMReverts`.

* lint fix and comment fix

* avoid false negatives

* delete WithRegisters && create ProofGenerator

* assert multi outputs

* fix naming issue

* Update cannon/mipsevm/tests/evm_multithreaded_test.go

Nice catch!
Co-authored-by: default avatarmbaxter <meredith.a.baxter@gmail.com>

* Small fix

* link check

* make expectedReason a value & delete nil check

---------
Co-authored-by: default avatarmbaxter <meredith.a.baxter@gmail.com>
parent d37f753e
...@@ -234,7 +234,6 @@ func (s *State) EncodeThreadProof() []byte { ...@@ -234,7 +234,6 @@ func (s *State) EncodeThreadProof() []byte {
out := make([]byte, 0, THREAD_WITNESS_SIZE) out := make([]byte, 0, THREAD_WITNESS_SIZE)
out = append(out, threadBytes[:]...) out = append(out, threadBytes[:]...)
out = append(out, otherThreadsWitness[:]...) out = append(out, otherThreadsWitness[:]...)
return out return out
} }
......
...@@ -583,10 +583,12 @@ func TestEVMFault(t *testing.T) { ...@@ -583,10 +583,12 @@ func TestEVMFault(t *testing.T) {
name string name string
nextPC arch.Word nextPC arch.Word
insn uint32 insn uint32
errMsg string
memoryProofAddresses []uint32
}{ }{
{"illegal instruction", 0, 0xFF_FF_FF_FF}, {"illegal instruction", 0, 0xFF_FF_FF_FF, "invalid instruction", []uint32{0xa7ef00cc}},
{"branch in delay-slot", 8, 0x11_02_00_03}, {"branch in delay-slot", 8, 0x11_02_00_03, "branch in delay slot", []uint32{0}},
{"jump in delay-slot", 8, 0x0c_00_00_0c}, {"jump in delay-slot", 8, 0x0c_00_00_0c, "jump in delay slot", []uint32{0}},
} }
for _, v := range versions { for _, v := range versions {
...@@ -599,8 +601,9 @@ func TestEVMFault(t *testing.T) { ...@@ -599,8 +601,9 @@ func TestEVMFault(t *testing.T) {
// set the return address ($ra) to jump into when test completes // set the return address ($ra) to jump into when test completes
state.GetRegistersRef()[31] = testutil.EndAddr state.GetRegistersRef()[31] = testutil.EndAddr
proofData := v.ProofGenerator(t, goVm.GetState(), tt.memoryProofAddresses...)
require.Panics(t, func() { _, _ = goVm.Step(true) }) require.Panics(t, func() { _, _ = goVm.Step(true) })
testutil.AssertEVMReverts(t, state, v.Contracts, tracer) testutil.AssertEVMReverts(t, state, v.Contracts, tracer, proofData, tt.errMsg)
}) })
} }
} }
......
...@@ -1183,10 +1183,12 @@ func TestEVM_UnsupportedSyscall(t *testing.T) { ...@@ -1183,10 +1183,12 @@ func TestEVM_UnsupportedSyscall(t *testing.T) {
// Setup basic getThreadId syscall instruction // Setup basic getThreadId syscall instruction
state.Memory.SetUint32(state.GetPC(), syscallInsn) state.Memory.SetUint32(state.GetPC(), syscallInsn)
state.GetRegistersRef()[2] = Word(syscallNum) state.GetRegistersRef()[2] = Word(syscallNum)
proofData := multiThreadedProofGenerator(t, state)
// Set up post-state expectations // Set up post-state expectations
require.Panics(t, func() { _, _ = goVm.Step(true) }) require.Panics(t, func() { _, _ = goVm.Step(true) })
testutil.AssertEVMReverts(t, state, contracts, tracer)
errorMessage := "MIPS2: unimplemented syscall"
testutil.AssertEVMReverts(t, state, contracts, tracer, proofData, errorMessage)
}) })
} }
} }
......
...@@ -50,12 +50,47 @@ func multiThreadElfVmFactory(t require.TestingT, elfFile string, po mipsevm.Prei ...@@ -50,12 +50,47 @@ func multiThreadElfVmFactory(t require.TestingT, elfFile string, po mipsevm.Prei
return fpvm return fpvm
} }
type ProofGenerator func(t require.TestingT, state mipsevm.FPVMState, memoryProofAddresses ...uint32) []byte
func singalThreadedProofGenerator(t require.TestingT, state mipsevm.FPVMState, memoryProofAddresses ...uint32) []byte {
var proofData []byte
insnProof := state.GetMemory().MerkleProof(state.GetPC())
proofData = append(proofData, insnProof[:]...)
for _, addr := range memoryProofAddresses {
memProof := state.GetMemory().MerkleProof(addr)
proofData = append(proofData, memProof[:]...)
}
return proofData
}
func multiThreadedProofGenerator(t require.TestingT, state mipsevm.FPVMState, memoryProofAddresses ...uint32) []byte {
mtState, ok := state.(*multithreaded.State)
if !ok {
require.Fail(t, "Failed to cast FPVMState to multithreaded State type")
}
proofData := mtState.EncodeThreadProof()
insnProof := mtState.GetMemory().MerkleProof(mtState.GetPC())
proofData = append(proofData, insnProof[:]...)
for _, addr := range memoryProofAddresses {
memProof := mtState.GetMemory().MerkleProof(addr)
proofData = append(proofData, memProof[:]...)
}
return proofData
}
type VersionedVMTestCase struct { type VersionedVMTestCase struct {
Name string Name string
Contracts *testutil.ContractMetadata Contracts *testutil.ContractMetadata
StateHashFn mipsevm.HashFn StateHashFn mipsevm.HashFn
VMFactory VMFactory VMFactory VMFactory
ElfVMFactory ElfVMFactory ElfVMFactory ElfVMFactory
ProofGenerator ProofGenerator
} }
func GetSingleThreadedTestCase(t require.TestingT) VersionedVMTestCase { func GetSingleThreadedTestCase(t require.TestingT) VersionedVMTestCase {
...@@ -65,6 +100,7 @@ func GetSingleThreadedTestCase(t require.TestingT) VersionedVMTestCase { ...@@ -65,6 +100,7 @@ func GetSingleThreadedTestCase(t require.TestingT) VersionedVMTestCase {
StateHashFn: singlethreaded.GetStateHashFn(), StateHashFn: singlethreaded.GetStateHashFn(),
VMFactory: singleThreadedVmFactory, VMFactory: singleThreadedVmFactory,
ElfVMFactory: singleThreadElfVmFactory, ElfVMFactory: singleThreadElfVmFactory,
ProofGenerator: singalThreadedProofGenerator,
} }
} }
...@@ -75,6 +111,7 @@ func GetMultiThreadedTestCase(t require.TestingT) VersionedVMTestCase { ...@@ -75,6 +111,7 @@ func GetMultiThreadedTestCase(t require.TestingT) VersionedVMTestCase {
StateHashFn: multithreaded.GetStateHashFn(), StateHashFn: multithreaded.GetStateHashFn(),
VMFactory: multiThreadedVmFactory, VMFactory: multiThreadedVmFactory,
ElfVMFactory: multiThreadElfVmFactory, ElfVMFactory: multiThreadElfVmFactory,
ProofGenerator: multiThreadedProofGenerator,
} }
} }
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/state"
...@@ -182,12 +183,11 @@ func ValidateEVM(t *testing.T, stepWitness *mipsevm.StepWitness, step uint64, go ...@@ -182,12 +183,11 @@ func ValidateEVM(t *testing.T, stepWitness *mipsevm.StepWitness, step uint64, go
} }
// AssertEVMReverts runs a single evm step from an FPVM prestate and asserts that the VM panics // AssertEVMReverts runs a single evm step from an FPVM prestate and asserts that the VM panics
func AssertEVMReverts(t *testing.T, state mipsevm.FPVMState, contracts *ContractMetadata, tracer *tracing.Hooks) { func AssertEVMReverts(t *testing.T, state mipsevm.FPVMState, contracts *ContractMetadata, tracer *tracing.Hooks, ProofData []byte, expectedReason string) {
insnProof := state.GetMemory().MerkleProof(state.GetPC())
encodedWitness, _ := state.EncodeWitness() encodedWitness, _ := state.EncodeWitness()
stepWitness := &mipsevm.StepWitness{ stepWitness := &mipsevm.StepWitness{
State: encodedWitness, State: encodedWitness,
ProofData: insnProof[:], ProofData: ProofData,
} }
input := EncodeStepInput(t, stepWitness, mipsevm.LocalContext{}, contracts.Artifacts.MIPS) input := EncodeStepInput(t, stepWitness, mipsevm.LocalContext{}, contracts.Artifacts.MIPS)
startingGas := uint64(30_000_000) startingGas := uint64(30_000_000)
...@@ -195,8 +195,16 @@ func AssertEVMReverts(t *testing.T, state mipsevm.FPVMState, contracts *Contract ...@@ -195,8 +195,16 @@ func AssertEVMReverts(t *testing.T, state mipsevm.FPVMState, contracts *Contract
env, evmState := NewEVMEnv(contracts) env, evmState := NewEVMEnv(contracts)
env.Config.Tracer = tracer env.Config.Tracer = tracer
sender := common.Address{0x13, 0x37} sender := common.Address{0x13, 0x37}
_, _, err := env.Call(vm.AccountRef(sender), contracts.Addresses.MIPS, input, startingGas, common.U2560) ret, _, err := env.Call(vm.AccountRef(sender), contracts.Addresses.MIPS, input, startingGas, common.U2560)
require.EqualValues(t, err, vm.ErrExecutionReverted) require.EqualValues(t, err, vm.ErrExecutionReverted)
require.Greater(t, len(ret), 4, "Return data length should be greater than 4 bytes")
unpacked, decodeErr := abi.UnpackRevert(ret)
require.NoError(t, decodeErr, "Failed to unpack revert reason")
require.Equal(t, expectedReason, unpacked, "Revert reason mismatch")
logs := evmState.Logs() logs := evmState.Logs()
require.Equal(t, 0, len(logs)) require.Equal(t, 0, len(logs))
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment