Commit 29dd98b0 authored by mbaxter's avatar mbaxter Committed by GitHub

cannon: Clean up fuzz test todos (#12009)

* cannon: Add memory assertions to FuzzStatePreimageRead

* cannon: Rework hint write fuzz test to assert hint expectations

* cannon: Update FuzzStatePreimageWrite to assert on expected preimageKey

* cannon: Remove validation skipping logic from test util

* cannon: Cleanup - simplify code

* cannon: Cleanup - dedupe code
parent 4806d83e
...@@ -2,7 +2,6 @@ package testutil ...@@ -2,7 +2,6 @@ package testutil
import ( import (
"math" "math"
"math/rand"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
...@@ -23,14 +22,14 @@ func NewStateMutatorMultiThreaded(state *multithreaded.State) testutil.StateMuta ...@@ -23,14 +22,14 @@ func NewStateMutatorMultiThreaded(state *multithreaded.State) testutil.StateMuta
} }
func (m *StateMutatorMultiThreaded) Randomize(randSeed int64) { func (m *StateMutatorMultiThreaded) Randomize(randSeed int64) {
r := rand.New(rand.NewSource(randSeed)) r := testutil.NewRandHelper(randSeed)
step := testutil.RandStep(r) step := r.RandStep()
m.state.PreimageKey = testutil.RandHash(r) m.state.PreimageKey = r.RandHash()
m.state.PreimageOffset = r.Uint32() m.state.PreimageOffset = r.Uint32()
m.state.Step = step m.state.Step = step
m.state.LastHint = testutil.RandHint(r) m.state.LastHint = r.RandHint()
m.state.StepsSinceLastContextSwitch = uint64(r.Intn(exec.SchedQuantum)) m.state.StepsSinceLastContextSwitch = uint64(r.Intn(exec.SchedQuantum))
// Randomize memory-related fields // Randomize memory-related fields
......
package testutil package testutil
import ( import (
"math/rand"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/multithreaded" "github.com/ethereum-optimism/optimism/cannon/mipsevm/multithreaded"
"github.com/ethereum-optimism/optimism/cannon/mipsevm/testutil" "github.com/ethereum-optimism/optimism/cannon/mipsevm/testutil"
) )
func RandomThread(randSeed int64) *multithreaded.ThreadState { func RandomThread(randSeed int64) *multithreaded.ThreadState {
r := rand.New(rand.NewSource(randSeed)) r := testutil.NewRandHelper(randSeed)
thread := multithreaded.CreateEmptyThread() thread := multithreaded.CreateEmptyThread()
pc := testutil.RandPC(r) pc := r.RandPC()
thread.Registers = *testutil.RandRegisters(r) thread.Registers = *r.RandRegisters()
thread.Cpu.PC = pc thread.Cpu.PC = pc
thread.Cpu.NextPC = pc + 4 thread.Cpu.NextPC = pc + 4
thread.Cpu.HI = r.Uint32() thread.Cpu.HI = r.Uint32()
......
package testutil package testutil
import ( import (
"math/rand"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
...@@ -15,12 +13,12 @@ type StateMutatorSingleThreaded struct { ...@@ -15,12 +13,12 @@ type StateMutatorSingleThreaded struct {
} }
func (m *StateMutatorSingleThreaded) Randomize(randSeed int64) { func (m *StateMutatorSingleThreaded) Randomize(randSeed int64) {
r := rand.New(rand.NewSource(randSeed)) r := testutil.NewRandHelper(randSeed)
pc := testutil.RandPC(r) pc := r.RandPC()
step := testutil.RandStep(r) step := r.RandStep()
m.state.PreimageKey = testutil.RandHash(r) m.state.PreimageKey = r.RandHash()
m.state.PreimageOffset = r.Uint32() m.state.PreimageOffset = r.Uint32()
m.state.Cpu.PC = pc m.state.Cpu.PC = pc
m.state.Cpu.NextPC = pc + 4 m.state.Cpu.NextPC = pc + 4
...@@ -28,8 +26,8 @@ func (m *StateMutatorSingleThreaded) Randomize(randSeed int64) { ...@@ -28,8 +26,8 @@ func (m *StateMutatorSingleThreaded) Randomize(randSeed int64) {
m.state.Cpu.LO = r.Uint32() m.state.Cpu.LO = r.Uint32()
m.state.Heap = r.Uint32() m.state.Heap = r.Uint32()
m.state.Step = step m.state.Step = step
m.state.LastHint = testutil.RandHint(r) m.state.LastHint = r.RandHint()
m.state.Registers = *testutil.RandRegisters(r) m.state.Registers = *r.RandRegisters()
} }
var _ testutil.StateMutator = (*StateMutatorSingleThreaded)(nil) var _ testutil.StateMutator = (*StateMutatorSingleThreaded)(nil)
......
...@@ -2,10 +2,11 @@ package tests ...@@ -2,10 +2,11 @@ package tests
import ( import (
"bytes" "bytes"
"encoding/binary"
"math"
"os" "os"
"testing" "testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -216,24 +217,29 @@ func FuzzStateHintRead(f *testing.F) { ...@@ -216,24 +217,29 @@ func FuzzStateHintRead(f *testing.F) {
func FuzzStatePreimageRead(f *testing.F) { func FuzzStatePreimageRead(f *testing.F) {
versions := GetMipsVersionTestCases(f) versions := GetMipsVersionTestCases(f)
f.Fuzz(func(t *testing.T, addr uint32, count uint32, preimageOffset uint32, seed int64) { f.Fuzz(func(t *testing.T, addr uint32, pc uint32, count uint32, preimageOffset uint32, seed int64) {
for _, v := range versions { for _, v := range versions {
t.Run(v.Name, func(t *testing.T) { t.Run(v.Name, func(t *testing.T) {
effAddr := addr & 0xFF_FF_FF_FC
pc = pc & 0xFF_FF_FF_FC
preexistingMemoryVal := [4]byte{0xFF, 0xFF, 0xFF, 0xFF}
preimageValue := []byte("hello world") preimageValue := []byte("hello world")
if preimageOffset >= uint32(len(preimageValue)) { preimageData := testutil.AddPreimageLengthPrefix(preimageValue)
if preimageOffset >= uint32(len(preimageData)) || pc == effAddr {
t.SkipNow() t.SkipNow()
} }
preimageKey := preimage.Keccak256Key(crypto.Keccak256Hash(preimageValue)).PreimageKey() preimageKey := preimage.Keccak256Key(crypto.Keccak256Hash(preimageValue)).PreimageKey()
oracle := testutil.StaticOracle(t, preimageValue) oracle := testutil.StaticOracle(t, preimageValue)
goVm := v.VMFactory(oracle, os.Stdout, os.Stderr, testutil.CreateLogger(), goVm := v.VMFactory(oracle, os.Stdout, os.Stderr, testutil.CreateLogger(),
testutil.WithRandomization(seed), testutil.WithPreimageKey(preimageKey), testutil.WithPreimageOffset(preimageOffset)) testutil.WithRandomization(seed), testutil.WithPreimageKey(preimageKey), testutil.WithPreimageOffset(preimageOffset), testutil.WithPCAndNextPC(pc))
state := goVm.GetState() state := goVm.GetState()
state.GetRegistersRef()[2] = exec.SysRead state.GetRegistersRef()[2] = exec.SysRead
state.GetRegistersRef()[4] = exec.FdPreimageRead state.GetRegistersRef()[4] = exec.FdPreimageRead
state.GetRegistersRef()[5] = addr state.GetRegistersRef()[5] = addr
state.GetRegistersRef()[6] = count state.GetRegistersRef()[6] = count
state.GetMemory().SetMemory(state.GetPC(), syscallInsn) state.GetMemory().SetMemory(state.GetPC(), syscallInsn)
state.GetMemory().SetMemory(effAddr, binary.BigEndian.Uint32(preexistingMemoryVal[:]))
step := state.GetStep() step := state.GetStep()
alignment := addr & 3 alignment := addr & 3
...@@ -242,7 +248,7 @@ func FuzzStatePreimageRead(f *testing.F) { ...@@ -242,7 +248,7 @@ func FuzzStatePreimageRead(f *testing.F) {
writeLen = count writeLen = count
} }
// Cap write length to remaining bytes of the preimage // Cap write length to remaining bytes of the preimage
preimageDataLen := uint32(len(preimageValue) + 8) // Data len includes a length prefix preimageDataLen := uint32(len(preimageData))
if preimageOffset+writeLen > preimageDataLen { if preimageOffset+writeLen > preimageDataLen {
writeLen = preimageDataLen - preimageOffset writeLen = preimageDataLen - preimageOffset
} }
...@@ -254,18 +260,18 @@ func FuzzStatePreimageRead(f *testing.F) { ...@@ -254,18 +260,18 @@ func FuzzStatePreimageRead(f *testing.F) {
expected.Registers[2] = writeLen expected.Registers[2] = writeLen
expected.Registers[7] = 0 // no error expected.Registers[7] = 0 // no error
expected.PreimageOffset += writeLen expected.PreimageOffset += writeLen
if writeLen > 0 {
// Expect a memory write
expectedMemory := preexistingMemoryVal
copy(expectedMemory[alignment:], preimageData[preimageOffset:preimageOffset+writeLen])
expected.ExpectMemoryWrite(effAddr, binary.BigEndian.Uint32(expectedMemory[:]))
}
stepWitness, err := goVm.Step(true) stepWitness, err := goVm.Step(true)
require.NoError(t, err) require.NoError(t, err)
require.True(t, stepWitness.HasPreimage()) require.True(t, stepWitness.HasPreimage())
// TODO(cp-983) - Do stricter validation of expected memory expected.Validate(t, state)
expected.Validate(t, state, testutil.SkipMemoryValidation)
if writeLen == 0 {
// Note: We are not asserting a memory root change when writeLen > 0 because we may not necessarily
// modify memory - it's possible we just write the leading zero bytes of the length prefix
require.Equal(t, expected.MemoryRoot, common.Hash(state.GetMemory().MerkleRoot()))
}
testutil.ValidateEVM(t, stepWitness, step, goVm, v.StateHashFn, v.Contracts, nil) testutil.ValidateEVM(t, stepWitness, step, goVm, v.StateHashFn, v.Contracts, nil)
}) })
} }
...@@ -274,43 +280,80 @@ func FuzzStatePreimageRead(f *testing.F) { ...@@ -274,43 +280,80 @@ func FuzzStatePreimageRead(f *testing.F) {
func FuzzStateHintWrite(f *testing.F) { func FuzzStateHintWrite(f *testing.F) {
versions := GetMipsVersionTestCases(f) versions := GetMipsVersionTestCases(f)
f.Fuzz(func(t *testing.T, addr uint32, count uint32, randSeed int64) { f.Fuzz(func(t *testing.T, addr uint32, count uint32, hint1, hint2, hint3 []byte, randSeed int64) {
for _, v := range versions { for _, v := range versions {
t.Run(v.Name, func(t *testing.T) { t.Run(v.Name, func(t *testing.T) {
preimageData := []byte("hello world") // Make sure pc does not overlap with hint data in memory
preimageKey := preimage.Keccak256Key(crypto.Keccak256Hash(preimageData)).PreimageKey() pc := uint32(0)
// TODO(cp-983) - use testutil.HintTrackingOracle, validate expected hints if addr <= 8 {
oracle := testutil.StaticOracle(t, preimageData) // only used for hinting addr += 8
}
// Set up hint data
r := testutil.NewRandHelper(randSeed)
hints := [][]byte{hint1, hint2, hint3}
hintData := make([]byte, 0)
for _, hint := range hints {
prefixedHint := testutil.AddHintLengthPrefix(hint)
hintData = append(hintData, prefixedHint...)
}
lastHintLen := math.Round(r.Fraction() * float64(len(hintData)))
lastHint := hintData[:int(lastHintLen)]
expectedBytesToProcess := int(count) + int(lastHintLen)
if expectedBytesToProcess > len(hintData) {
// Add an extra hint to span the rest of the hint data
randomHint := r.RandomBytes(t, expectedBytesToProcess)
prefixedHint := testutil.AddHintLengthPrefix(randomHint)
hintData = append(hintData, prefixedHint...)
hints = append(hints, randomHint)
}
// Set up state
oracle := &testutil.HintTrackingOracle{}
goVm := v.VMFactory(oracle, os.Stdout, os.Stderr, testutil.CreateLogger(), goVm := v.VMFactory(oracle, os.Stdout, os.Stderr, testutil.CreateLogger(),
testutil.WithRandomization(randSeed), testutil.WithPreimageKey(preimageKey)) testutil.WithRandomization(randSeed), testutil.WithLastHint(lastHint), testutil.WithPCAndNextPC(pc))
state := goVm.GetState() state := goVm.GetState()
state.GetRegistersRef()[2] = exec.SysWrite state.GetRegistersRef()[2] = exec.SysWrite
state.GetRegistersRef()[4] = exec.FdHintWrite state.GetRegistersRef()[4] = exec.FdHintWrite
state.GetRegistersRef()[5] = addr state.GetRegistersRef()[5] = addr
state.GetRegistersRef()[6] = count state.GetRegistersRef()[6] = count
step := state.GetStep() step := state.GetStep()
err := state.GetMemory().SetMemoryRange(addr, bytes.NewReader(hintData[int(lastHintLen):]))
// Set random data at the target memory range
randBytes := testutil.RandomBytes(t, randSeed, count)
err := state.GetMemory().SetMemoryRange(addr, bytes.NewReader(randBytes))
require.NoError(t, err) require.NoError(t, err)
// Set instruction
state.GetMemory().SetMemory(state.GetPC(), syscallInsn) state.GetMemory().SetMemory(state.GetPC(), syscallInsn)
// Set up expectations
expected := testutil.NewExpectedState(state) expected := testutil.NewExpectedState(state)
expected.Step += 1 expected.Step += 1
expected.PC = state.GetCpu().NextPC expected.PC = state.GetCpu().NextPC
expected.NextPC = state.GetCpu().NextPC + 4 expected.NextPC = state.GetCpu().NextPC + 4
expected.Registers[2] = count expected.Registers[2] = count
expected.Registers[7] = 0 // no error expected.Registers[7] = 0 // no error
// Figure out hint expectations
var expectedHints [][]byte
expectedLastHint := make([]byte, 0)
byteIndex := 0
for _, hint := range hints {
hintDataLength := len(hint) + 4 // Hint data + prefix
hintLastByteIndex := hintDataLength + byteIndex - 1
if hintLastByteIndex < expectedBytesToProcess {
expectedHints = append(expectedHints, hint)
} else {
expectedLastHint = hintData[byteIndex:expectedBytesToProcess]
break
}
byteIndex += hintDataLength
}
expected.LastHint = expectedLastHint
// Run state transition
stepWitness, err := goVm.Step(true) stepWitness, err := goVm.Step(true)
require.NoError(t, err) require.NoError(t, err)
require.False(t, stepWitness.HasPreimage()) require.False(t, stepWitness.HasPreimage())
// TODO(cp-983) - validate expected hints // Validate
expected.Validate(t, state, testutil.SkipHintValidation) require.Equal(t, expectedHints, oracle.Hints())
expected.Validate(t, state)
testutil.ValidateEVM(t, stepWitness, step, goVm, v.StateHashFn, v.Contracts, nil) testutil.ValidateEVM(t, stepWitness, step, goVm, v.StateHashFn, v.Contracts, nil)
}) })
} }
...@@ -322,23 +365,33 @@ func FuzzStatePreimageWrite(f *testing.F) { ...@@ -322,23 +365,33 @@ func FuzzStatePreimageWrite(f *testing.F) {
f.Fuzz(func(t *testing.T, addr uint32, count uint32, seed int64) { f.Fuzz(func(t *testing.T, addr uint32, count uint32, seed int64) {
for _, v := range versions { for _, v := range versions {
t.Run(v.Name, func(t *testing.T) { t.Run(v.Name, func(t *testing.T) {
// Make sure pc does not overlap with preimage data in memory
pc := uint32(0)
if addr <= 8 {
addr += 8
}
effAddr := addr & 0xFF_FF_FF_FC
preexistingMemoryVal := [4]byte{0x12, 0x34, 0x56, 0x78}
preimageData := []byte("hello world") preimageData := []byte("hello world")
preimageKey := preimage.Keccak256Key(crypto.Keccak256Hash(preimageData)).PreimageKey() preimageKey := preimage.Keccak256Key(crypto.Keccak256Hash(preimageData)).PreimageKey()
oracle := testutil.StaticOracle(t, preimageData) oracle := testutil.StaticOracle(t, preimageData)
goVm := v.VMFactory(oracle, os.Stdout, os.Stderr, testutil.CreateLogger(), goVm := v.VMFactory(oracle, os.Stdout, os.Stderr, testutil.CreateLogger(),
testutil.WithRandomization(seed), testutil.WithPreimageKey(preimageKey), testutil.WithPreimageOffset(128)) testutil.WithRandomization(seed), testutil.WithPreimageKey(preimageKey), testutil.WithPreimageOffset(128), testutil.WithPCAndNextPC(pc))
state := goVm.GetState() state := goVm.GetState()
state.GetRegistersRef()[2] = exec.SysWrite state.GetRegistersRef()[2] = exec.SysWrite
state.GetRegistersRef()[4] = exec.FdPreimageWrite state.GetRegistersRef()[4] = exec.FdPreimageWrite
state.GetRegistersRef()[5] = addr state.GetRegistersRef()[5] = addr
state.GetRegistersRef()[6] = count state.GetRegistersRef()[6] = count
state.GetMemory().SetMemory(state.GetPC(), syscallInsn) state.GetMemory().SetMemory(state.GetPC(), syscallInsn)
state.GetMemory().SetMemory(effAddr, binary.BigEndian.Uint32(preexistingMemoryVal[:]))
step := state.GetStep() step := state.GetStep()
sz := 4 - (addr & 0x3) expectBytesWritten := count
if sz < count { alignment := addr & 0x3
count = sz sz := 4 - alignment
if sz < expectBytesWritten {
expectBytesWritten = sz
} }
expected := testutil.NewExpectedState(state) expected := testutil.NewExpectedState(state)
...@@ -346,15 +399,21 @@ func FuzzStatePreimageWrite(f *testing.F) { ...@@ -346,15 +399,21 @@ func FuzzStatePreimageWrite(f *testing.F) {
expected.PC = state.GetCpu().NextPC expected.PC = state.GetCpu().NextPC
expected.NextPC = state.GetCpu().NextPC + 4 expected.NextPC = state.GetCpu().NextPC + 4
expected.PreimageOffset = 0 expected.PreimageOffset = 0
expected.Registers[2] = count expected.Registers[2] = expectBytesWritten
expected.Registers[7] = 0 // No error expected.Registers[7] = 0 // No error
expected.PreimageKey = preimageKey
if expectBytesWritten > 0 {
// Copy original preimage key, but shift it left by expectBytesWritten
copy(expected.PreimageKey[:], preimageKey[expectBytesWritten:])
// Copy memory data to rightmost expectedBytesWritten
copy(expected.PreimageKey[32-expectBytesWritten:], preexistingMemoryVal[alignment:])
}
stepWitness, err := goVm.Step(true) stepWitness, err := goVm.Step(true)
require.NoError(t, err) require.NoError(t, err)
require.False(t, stepWitness.HasPreimage()) require.False(t, stepWitness.HasPreimage())
// TODO(cp-983) - validate preimage key expected.Validate(t, state)
expected.Validate(t, state, testutil.SkipPreimageKeyValidation)
testutil.ValidateEVM(t, stepWitness, step, goVm, v.StateHashFn, v.Contracts, nil) testutil.ValidateEVM(t, stepWitness, step, goVm, v.StateHashFn, v.Contracts, nil)
}) })
} }
......
package testutil package testutil
import ( import (
"encoding/binary"
"math/rand" "math/rand"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func RandHash(r *rand.Rand) common.Hash { type RandHelper struct {
r *rand.Rand
}
func NewRandHelper(seed int64) *RandHelper {
r := rand.New(rand.NewSource(seed))
return &RandHelper{r: r}
}
func (h *RandHelper) Uint32() uint32 {
return h.r.Uint32()
}
func (h *RandHelper) Fraction() float64 {
return h.r.Float64()
}
func (h *RandHelper) Intn(n int) int {
return h.r.Intn(n)
}
func (h *RandHelper) RandHash() common.Hash {
var bytes [32]byte var bytes [32]byte
_, err := r.Read(bytes[:]) _, err := h.r.Read(bytes[:])
if err != nil { if err != nil {
panic(err) panic(err)
} }
return bytes return bytes
} }
func RandHint(r *rand.Rand) []byte { func (h *RandHelper) RandHint() []byte {
count := r.Intn(10)
bytes := make([]byte, count) bytesCount := h.r.Intn(24)
_, err := r.Read(bytes[:]) bytes := make([]byte, bytesCount)
if err != nil {
panic(err) if bytesCount >= 8 {
// Set up a reasonable length prefix
nextHintLen := uint64(h.r.Intn(30))
binary.BigEndian.PutUint64(bytes, nextHintLen)
_, err := h.r.Read(bytes[8:])
if err != nil {
panic(err)
}
} }
return bytes return bytes
} }
func RandRegisters(r *rand.Rand) *[32]uint32 { func (h *RandHelper) RandRegisters() *[32]uint32 {
registers := new([32]uint32) registers := new([32]uint32)
for i := 0; i < 32; i++ { for i := 0; i < 32; i++ {
registers[i] = r.Uint32() registers[i] = h.r.Uint32()
} }
return registers return registers
} }
func RandomBytes(t require.TestingT, seed int64, length uint32) []byte { func (h *RandHelper) RandomBytes(t require.TestingT, length int) []byte {
r := rand.New(rand.NewSource(seed))
randBytes := make([]byte, length) randBytes := make([]byte, length)
if _, err := r.Read(randBytes); err != nil { if _, err := h.r.Read(randBytes); err != nil {
require.NoError(t, err) require.NoError(t, err)
} }
return randBytes return randBytes
} }
func RandPC(r *rand.Rand) uint32 { func (h *RandHelper) RandPC() uint32 {
return AlignPC(r.Uint32()) return AlignPC(h.r.Uint32())
} }
func RandStep(r *rand.Rand) uint64 { func (h *RandHelper) RandStep() uint64 {
return BoundStep(r.Uint64()) return BoundStep(h.r.Uint64())
} }
package testutil package testutil
import ( import (
"encoding/binary"
"fmt" "fmt"
"slices"
"testing" "testing"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
...@@ -13,10 +13,22 @@ import ( ...@@ -13,10 +13,22 @@ import (
"github.com/ethereum-optimism/optimism/cannon/mipsevm/memory" "github.com/ethereum-optimism/optimism/cannon/mipsevm/memory"
) )
func CopyRegisters(state mipsevm.FPVMState) *[32]uint32 { func AddHintLengthPrefix(data []byte) []byte {
copy := new([32]uint32) dataLen := len(data)
*copy = *state.GetRegistersRef() prefixed := make([]byte, 0, dataLen+4)
return copy prefixed = binary.BigEndian.AppendUint32(prefixed, uint32(dataLen))
prefixed = append(prefixed, data...)
return prefixed
}
func AddPreimageLengthPrefix(data []byte) []byte {
dataLen := len(data)
prefixed := make([]byte, 0, dataLen+8)
prefixed = binary.BigEndian.AppendUint64(prefixed, uint64(dataLen))
prefixed = append(prefixed, data...)
return prefixed
} }
type StateMutator interface { type StateMutator interface {
...@@ -48,6 +60,13 @@ func WithNextPC(nextPC uint32) StateOption { ...@@ -48,6 +60,13 @@ func WithNextPC(nextPC uint32) StateOption {
} }
} }
func WithPCAndNextPC(pc uint32) StateOption {
return func(state StateMutator) {
state.SetPC(pc)
state.SetNextPC(pc + 4)
}
}
func WithHeap(addr uint32) StateOption { func WithHeap(addr uint32) StateOption {
return func(state StateMutator) { return func(state StateMutator) {
state.SetHeap(addr) state.SetHeap(addr)
...@@ -150,19 +169,8 @@ func (e *ExpectedState) ExpectMemoryWrite(addr uint32, val uint32) { ...@@ -150,19 +169,8 @@ func (e *ExpectedState) ExpectMemoryWrite(addr uint32, val uint32) {
e.MemoryRoot = e.expectedMemory.MerkleRoot() e.MemoryRoot = e.expectedMemory.MerkleRoot()
} }
type StateValidationFlags int func (e *ExpectedState) Validate(t testing.TB, actualState mipsevm.FPVMState) {
require.Equal(t, e.PreimageKey, actualState.GetPreimageKey(), fmt.Sprintf("Expect preimageKey = %v", e.PreimageKey))
// TODO(cp-983) - Remove these validation hacks
const (
SkipMemoryValidation StateValidationFlags = iota
SkipHintValidation
SkipPreimageKeyValidation
)
func (e *ExpectedState) Validate(t testing.TB, actualState mipsevm.FPVMState, flags ...StateValidationFlags) {
if !slices.Contains(flags, SkipPreimageKeyValidation) {
require.Equal(t, e.PreimageKey, actualState.GetPreimageKey(), fmt.Sprintf("Expect preimageKey = %v", e.PreimageKey))
}
require.Equal(t, e.PreimageOffset, actualState.GetPreimageOffset(), fmt.Sprintf("Expect preimageOffset = %v", e.PreimageOffset)) require.Equal(t, e.PreimageOffset, actualState.GetPreimageOffset(), fmt.Sprintf("Expect preimageOffset = %v", e.PreimageOffset))
require.Equal(t, e.PC, actualState.GetCpu().PC, fmt.Sprintf("Expect PC = 0x%x", e.PC)) require.Equal(t, e.PC, actualState.GetCpu().PC, fmt.Sprintf("Expect PC = 0x%x", e.PC))
require.Equal(t, e.NextPC, actualState.GetCpu().NextPC, fmt.Sprintf("Expect nextPC = 0x%x", e.NextPC)) require.Equal(t, e.NextPC, actualState.GetCpu().NextPC, fmt.Sprintf("Expect nextPC = 0x%x", e.NextPC))
...@@ -172,11 +180,7 @@ func (e *ExpectedState) Validate(t testing.TB, actualState mipsevm.FPVMState, fl ...@@ -172,11 +180,7 @@ func (e *ExpectedState) Validate(t testing.TB, actualState mipsevm.FPVMState, fl
require.Equal(t, e.ExitCode, actualState.GetExitCode(), fmt.Sprintf("Expect exitCode = 0x%x", e.ExitCode)) require.Equal(t, e.ExitCode, actualState.GetExitCode(), fmt.Sprintf("Expect exitCode = 0x%x", e.ExitCode))
require.Equal(t, e.Exited, actualState.GetExited(), fmt.Sprintf("Expect exited = %v", e.Exited)) require.Equal(t, e.Exited, actualState.GetExited(), fmt.Sprintf("Expect exited = %v", e.Exited))
require.Equal(t, e.Step, actualState.GetStep(), fmt.Sprintf("Expect step = %d", e.Step)) require.Equal(t, e.Step, actualState.GetStep(), fmt.Sprintf("Expect step = %d", e.Step))
if !slices.Contains(flags, SkipHintValidation) { require.Equal(t, e.LastHint, actualState.GetLastHint(), fmt.Sprintf("Expect lastHint = %v", e.LastHint))
require.Equal(t, e.LastHint, actualState.GetLastHint(), fmt.Sprintf("Expect lastHint = %v", e.LastHint))
}
require.Equal(t, e.Registers, *actualState.GetRegistersRef(), fmt.Sprintf("Expect registers = %v", e.Registers)) require.Equal(t, e.Registers, *actualState.GetRegistersRef(), fmt.Sprintf("Expect registers = %v", e.Registers))
if !slices.Contains(flags, SkipMemoryValidation) { require.Equal(t, e.MemoryRoot, common.Hash(actualState.GetMemory().MerkleRoot()), fmt.Sprintf("Expect memory root = %v", e.MemoryRoot))
require.Equal(t, e.MemoryRoot, common.Hash(actualState.GetMemory().MerkleRoot()), fmt.Sprintf("Expect memory root = %v", e.MemoryRoot))
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment