Commit d1ea098c authored by Inphi's avatar Inphi Committed by GitHub

cannon: Fix post-state hash in proof (#11005)

* Reapply "cannon: Generic FPVM interface (#10993)" (#11004)

This reverts commit 124b1a9e.

* cannon: Fix post-state hash in proof

* remove PostStateHash from StepWitness
parent 98038608
...@@ -23,6 +23,13 @@ import ( ...@@ -23,6 +23,13 @@ import (
) )
var ( var (
RunType = &cli.StringFlag{
Name: "type",
Usage: "VM type to run. Options are 'cannon' (default)",
Value: "cannon",
// TODO(client-pod#903): This should be required once we have additional vm types
Required: false,
}
RunInputFlag = &cli.PathFlag{ RunInputFlag = &cli.PathFlag{
Name: "input", Name: "input",
Usage: "path of input JSON state. Stdin if left empty.", Usage: "path of input JSON state. Stdin if left empty.",
...@@ -250,14 +257,20 @@ func Guard(proc *os.ProcessState, fn StepFn) StepFn { ...@@ -250,14 +257,20 @@ func Guard(proc *os.ProcessState, fn StepFn) StepFn {
var _ mipsevm.PreimageOracle = (*ProcessPreimageOracle)(nil) var _ mipsevm.PreimageOracle = (*ProcessPreimageOracle)(nil)
type VMType string
var cannonVMType VMType = "cannon"
func Run(ctx *cli.Context) error { func Run(ctx *cli.Context) error {
if ctx.Bool(RunPProfCPU.Name) { if ctx.Bool(RunPProfCPU.Name) {
defer profile.Start(profile.NoShutdownHook, profile.ProfilePath("."), profile.CPUProfile).Stop() defer profile.Start(profile.NoShutdownHook, profile.ProfilePath("."), profile.CPUProfile).Stop()
} }
state, err := jsonutil.LoadJSON[mipsevm.State](ctx.Path(RunInputFlag.Name)) var vmType VMType
if err != nil { if vmTypeStr := ctx.String(RunType.Name); vmTypeStr == string(cannonVMType) {
return err vmType = cannonVMType
} else {
return fmt.Errorf("unknown VM type %q", vmType)
} }
guestLogger := Logger(os.Stderr, log.LevelInfo) guestLogger := Logger(os.Stderr, log.LevelInfo)
...@@ -349,53 +362,65 @@ func Run(ctx *cli.Context) error { ...@@ -349,53 +362,65 @@ func Run(ctx *cli.Context) error {
} }
} }
us := mipsevm.NewInstrumentedState(state, po, outLog, errLog) var vm mipsevm.FPVM
debugProgram := ctx.Bool(RunDebugFlag.Name) var debugProgram bool
if vmType == cannonVMType {
cannon, err := mipsevm.NewInstrumentedStateFromFile(ctx.Path(RunInputFlag.Name), po, outLog, errLog)
if err != nil {
return err
}
debugProgram = ctx.Bool(RunDebugFlag.Name)
if debugProgram { if debugProgram {
if metaPath := ctx.Path(RunMetaFlag.Name); metaPath == "" { if metaPath := ctx.Path(RunMetaFlag.Name); metaPath == "" {
return fmt.Errorf("cannot enable debug mode without a metadata file") return fmt.Errorf("cannot enable debug mode without a metadata file")
} }
if err := us.InitDebug(meta); err != nil { if err := cannon.InitDebug(meta); err != nil {
return fmt.Errorf("failed to initialize debug mode: %w", err) return fmt.Errorf("failed to initialize debug mode: %w", err)
} }
} }
vm = cannon
} else {
return fmt.Errorf("unknown VM type %q", vmType)
}
proofFmt := ctx.String(RunProofFmtFlag.Name) proofFmt := ctx.String(RunProofFmtFlag.Name)
snapshotFmt := ctx.String(RunSnapshotFmtFlag.Name) snapshotFmt := ctx.String(RunSnapshotFmtFlag.Name)
stepFn := us.Step stepFn := vm.Step
if po.cmd != nil { if po.cmd != nil {
stepFn = Guard(po.cmd.ProcessState, stepFn) stepFn = Guard(po.cmd.ProcessState, stepFn)
} }
start := time.Now() start := time.Now()
startStep := state.Step
state := vm.GetState()
startStep := state.GetStep()
// avoid symbol lookups every instruction by preparing a matcher func // avoid symbol lookups every instruction by preparing a matcher func
sleepCheck := meta.SymbolMatcher("runtime.notesleep") sleepCheck := meta.SymbolMatcher("runtime.notesleep")
for !state.Exited { for !state.GetExited() {
if state.Step%100 == 0 { // don't do the ctx err check (includes lock) too often step := state.GetStep()
if step%100 == 0 { // don't do the ctx err check (includes lock) too often
if err := ctx.Context.Err(); err != nil { if err := ctx.Context.Err(); err != nil {
return err return err
} }
} }
step := state.Step
if infoAt(state) { if infoAt(state) {
delta := time.Since(start) delta := time.Since(start)
l.Info("processing", l.Info("processing",
"step", step, "step", step,
"pc", mipsevm.HexU32(state.Cpu.PC), "pc", mipsevm.HexU32(state.GetPC()),
"insn", mipsevm.HexU32(state.Memory.GetMemory(state.Cpu.PC)), "insn", mipsevm.HexU32(state.GetMemory().GetMemory(state.GetPC())),
"ips", float64(step-startStep)/(float64(delta)/float64(time.Second)), "ips", float64(step-startStep)/(float64(delta)/float64(time.Second)),
"pages", state.Memory.PageCount(), "pages", state.GetMemory().PageCount(),
"mem", state.Memory.Usage(), "mem", state.GetMemory().Usage(),
"name", meta.LookupSymbol(state.Cpu.PC), "name", meta.LookupSymbol(state.GetPC()),
) )
} }
if sleepCheck(state.Cpu.PC) { // don't loop forever when we get stuck because of an unexpected bad program if sleepCheck(state.GetPC()) { // don't loop forever when we get stuck because of an unexpected bad program
return fmt.Errorf("got stuck in Go sleep at step %d", step) return fmt.Errorf("got stuck in Go sleep at step %d", step)
} }
...@@ -411,21 +436,14 @@ func Run(ctx *cli.Context) error { ...@@ -411,21 +436,14 @@ func Run(ctx *cli.Context) error {
} }
if proofAt(state) { if proofAt(state) {
preStateHash, err := state.EncodeWitness().StateHash()
if err != nil {
return fmt.Errorf("failed to hash prestate witness: %w", err)
}
witness, err := stepFn(true) witness, err := stepFn(true)
if err != nil { if err != nil {
return fmt.Errorf("failed at proof-gen step %d (PC: %08x): %w", step, state.Cpu.PC, err) return fmt.Errorf("failed at proof-gen step %d (PC: %08x): %w", step, state.GetPC(), err)
}
postStateHash, err := state.EncodeWitness().StateHash()
if err != nil {
return fmt.Errorf("failed to hash poststate witness: %w", err)
} }
_, postStateHash := state.EncodeWitness()
proof := &Proof{ proof := &Proof{
Step: step, Step: step,
Pre: preStateHash, Pre: witness.StateHash,
Post: postStateHash, Post: postStateHash,
StateData: witness.State, StateData: witness.State,
ProofData: witness.MemProof, ProofData: witness.MemProof,
...@@ -441,11 +459,11 @@ func Run(ctx *cli.Context) error { ...@@ -441,11 +459,11 @@ func Run(ctx *cli.Context) error {
} else { } else {
_, err = stepFn(false) _, err = stepFn(false)
if err != nil { if err != nil {
return fmt.Errorf("failed at step %d (PC: %08x): %w", step, state.Cpu.PC, err) return fmt.Errorf("failed at step %d (PC: %08x): %w", step, state.GetPC(), err)
} }
} }
lastPreimageKey, lastPreimageValue, lastPreimageOffset := us.LastPreimage() lastPreimageKey, lastPreimageValue, lastPreimageOffset := vm.LastPreimage()
if lastPreimageOffset != ^uint32(0) { if lastPreimageOffset != ^uint32(0) {
if stopAtAnyPreimage { if stopAtAnyPreimage {
l.Info("Stopping at preimage read") l.Info("Stopping at preimage read")
...@@ -464,16 +482,16 @@ func Run(ctx *cli.Context) error { ...@@ -464,16 +482,16 @@ func Run(ctx *cli.Context) error {
} }
} }
} }
l.Info("Execution stopped", "exited", state.Exited, "code", state.ExitCode) l.Info("Execution stopped", "exited", state.GetExited(), "code", state.GetExitCode())
if debugProgram { if debugProgram {
us.Traceback() vm.Traceback()
} }
if err := jsonutil.WriteJSON(ctx.Path(RunOutputFlag.Name), state, OutFilePerm); err != nil { if err := jsonutil.WriteJSON(ctx.Path(RunOutputFlag.Name), state, OutFilePerm); err != nil {
return fmt.Errorf("failed to write state output: %w", err) return fmt.Errorf("failed to write state output: %w", err)
} }
if debugInfoFile := ctx.Path(RunDebugInfoFlag.Name); debugInfoFile != "" { if debugInfoFile := ctx.Path(RunDebugInfoFlag.Name); debugInfoFile != "" {
if err := jsonutil.WriteJSON(debugInfoFile, us.GetDebugInfo(), OutFilePerm); err != nil { if err := jsonutil.WriteJSON(debugInfoFile, vm.GetDebugInfo(), OutFilePerm); err != nil {
return fmt.Errorf("failed to write benchmark data: %w", err) return fmt.Errorf("failed to write benchmark data: %w", err)
} }
} }
...@@ -486,6 +504,7 @@ var RunCommand = &cli.Command{ ...@@ -486,6 +504,7 @@ var RunCommand = &cli.Command{
Description: "Run VM step(s) and generate proof data to replicate onchain. See flags to match when to output a proof, a snapshot, or to stop early.", Description: "Run VM step(s) and generate proof data to replicate onchain. See flags to match when to output a proof, a snapshot, or to stop early.",
Action: Run, Action: Run,
Flags: []cli.Flag{ Flags: []cli.Flag{
RunType,
RunInputFlag, RunInputFlag,
RunOutputFlag, RunOutputFlag,
RunProofAtFlag, RunProofAtFlag,
......
...@@ -30,11 +30,7 @@ func Witness(ctx *cli.Context) error { ...@@ -30,11 +30,7 @@ func Witness(ctx *cli.Context) error {
if err != nil { if err != nil {
return fmt.Errorf("invalid input state (%v): %w", input, err) return fmt.Errorf("invalid input state (%v): %w", input, err)
} }
witness := state.EncodeWitness() witness, h := state.EncodeWitness()
h, err := witness.StateHash()
if err != nil {
return fmt.Errorf("failed to compute witness hash: %w", err)
}
if output != "" { if output != "" {
if err := os.WriteFile(output, witness, 0755); err != nil { if err := os.WriteFile(output, witness, 0755); err != nil {
return fmt.Errorf("writing output to %v: %w", output, err) return fmt.Errorf("writing output to %v: %w", output, err)
......
...@@ -196,7 +196,7 @@ func TestEVM(t *testing.T) { ...@@ -196,7 +196,7 @@ func TestEVM(t *testing.T) {
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
// verify the post-state matches. // verify the post-state matches.
// TODO: maybe more readable to decode the evmPost state, and do attribute-wise comparison. // TODO: maybe more readable to decode the evmPost state, and do attribute-wise comparison.
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equalf(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equalf(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM at step %d", state.Step) "mipsevm produced different state than EVM at step %d", state.Step)
} }
...@@ -243,7 +243,7 @@ func TestEVMSingleStep(t *testing.T) { ...@@ -243,7 +243,7 @@ func TestEVMSingleStep(t *testing.T) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evm.SetTracer(tracer) evm.SetTracer(tracer)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := us.state.EncodeWitness() goPost, _ := us.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -421,7 +421,7 @@ func TestEVMSysWriteHint(t *testing.T) { ...@@ -421,7 +421,7 @@ func TestEVMSysWriteHint(t *testing.T) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evm.SetTracer(tracer) evm.SetTracer(tracer)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := us.state.EncodeWitness() goPost, _ := us.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -459,8 +459,9 @@ func TestEVMFault(t *testing.T) { ...@@ -459,8 +459,9 @@ func TestEVMFault(t *testing.T) {
require.Panics(t, func() { _, _ = us.Step(true) }) require.Panics(t, func() { _, _ = us.Step(true) })
insnProof := initialState.Memory.MerkleProof(0) insnProof := initialState.Memory.MerkleProof(0)
encodedWitness, _ := initialState.EncodeWitness()
stepWitness := &StepWitness{ stepWitness := &StepWitness{
State: initialState.EncodeWitness(), State: encodedWitness,
MemProof: insnProof[:], MemProof: insnProof[:],
} }
input := encodeStepInput(t, stepWitness, LocalContext{}, contracts.MIPS) input := encodeStepInput(t, stepWitness, LocalContext{}, contracts.MIPS)
...@@ -509,7 +510,7 @@ func TestHelloEVM(t *testing.T) { ...@@ -509,7 +510,7 @@ func TestHelloEVM(t *testing.T) {
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
// verify the post-state matches. // verify the post-state matches.
// TODO: maybe more readable to decode the evmPost state, and do attribute-wise comparison. // TODO: maybe more readable to decode the evmPost state, and do attribute-wise comparison.
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
} }
...@@ -560,7 +561,7 @@ func TestClaimEVM(t *testing.T) { ...@@ -560,7 +561,7 @@ func TestClaimEVM(t *testing.T) {
evm.SetTracer(tracer) evm.SetTracer(tracer)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
} }
......
...@@ -61,7 +61,7 @@ func FuzzStateSyscallBrk(f *testing.F) { ...@@ -61,7 +61,7 @@ func FuzzStateSyscallBrk(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -112,7 +112,7 @@ func FuzzStateSyscallClone(f *testing.F) { ...@@ -112,7 +112,7 @@ func FuzzStateSyscallClone(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -173,7 +173,7 @@ func FuzzStateSyscallMmap(f *testing.F) { ...@@ -173,7 +173,7 @@ func FuzzStateSyscallMmap(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -223,7 +223,7 @@ func FuzzStateSyscallExitGroup(f *testing.F) { ...@@ -223,7 +223,7 @@ func FuzzStateSyscallExitGroup(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -288,7 +288,7 @@ func FuzzStateSyscallFcntl(f *testing.F) { ...@@ -288,7 +288,7 @@ func FuzzStateSyscallFcntl(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -340,7 +340,7 @@ func FuzzStateHintRead(f *testing.F) { ...@@ -340,7 +340,7 @@ func FuzzStateHintRead(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -406,7 +406,7 @@ func FuzzStatePreimageRead(f *testing.F) { ...@@ -406,7 +406,7 @@ func FuzzStatePreimageRead(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -466,7 +466,7 @@ func FuzzStateHintWrite(f *testing.F) { ...@@ -466,7 +466,7 @@ func FuzzStateHintWrite(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
...@@ -521,7 +521,7 @@ func FuzzStatePreimageWrite(f *testing.F) { ...@@ -521,7 +521,7 @@ func FuzzStatePreimageWrite(f *testing.F) {
evm := NewMIPSEVM(contracts, addrs) evm := NewMIPSEVM(contracts, addrs)
evmPost := evm.Step(t, stepWitness) evmPost := evm.Step(t, stepWitness)
goPost := goState.state.EncodeWitness() goPost, _ := goState.state.EncodeWitness()
require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(), require.Equal(t, hexutil.Bytes(goPost).String(), hexutil.Bytes(evmPost).String(),
"mipsevm produced different state than EVM") "mipsevm produced different state than EVM")
}) })
......
package mipsevm
import "github.com/ethereum/go-ethereum/common"
type FPVMState interface {
GetMemory() *Memory
// GetPC returns the currently executing program counter
GetPC() uint32
// GetStep returns the current VM step
GetStep() uint64
// GetExited returns whether the state exited bit is set
GetExited() bool
// GetExitCode returns the exit code
GetExitCode() uint8
// EncodeWitness returns the witness for the current state and the state hash
EncodeWitness() (witness []byte, hash common.Hash)
}
type FPVM interface {
// GetState returns the current state of the VM. The FPVMState is updated by successive calls to Step
GetState() FPVMState
// Step executes a single instruction and returns the witness for the step
Step(includeProof bool) (*StepWitness, error)
// LastPreimage returns the last preimage accessed by the VM
LastPreimage() (preimageKey [32]byte, preimage []byte, preimageOffset uint32)
// Traceback prints a traceback of the program to the console
Traceback()
// GetDebugInfo returns debug information about the VM
GetDebugInfo() *DebugInfo
}
...@@ -3,6 +3,8 @@ package mipsevm ...@@ -3,6 +3,8 @@ package mipsevm
import ( import (
"errors" "errors"
"io" "io"
"github.com/ethereum-optimism/optimism/op-service/jsonutil"
) )
type PreimageOracle interface { type PreimageOracle interface {
...@@ -48,6 +50,19 @@ func NewInstrumentedState(state *State, po PreimageOracle, stdOut, stdErr io.Wri ...@@ -48,6 +50,19 @@ func NewInstrumentedState(state *State, po PreimageOracle, stdOut, stdErr io.Wri
} }
} }
func NewInstrumentedStateFromFile(stateFile string, po PreimageOracle, stdOut, stdErr io.Writer) (*InstrumentedState, error) {
state, err := jsonutil.LoadJSON[State](stateFile)
if err != nil {
return nil, err
}
return &InstrumentedState{
state: state,
stdOut: stdOut,
stdErr: stdErr,
preimageOracle: &trackingOracle{po: po},
}, nil
}
func (m *InstrumentedState) InitDebug(meta *Metadata) error { func (m *InstrumentedState) InitDebug(meta *Metadata) error {
if meta == nil { if meta == nil {
return errors.New("metadata is nil") return errors.New("metadata is nil")
...@@ -64,8 +79,10 @@ func (m *InstrumentedState) Step(proof bool) (wit *StepWitness, err error) { ...@@ -64,8 +79,10 @@ func (m *InstrumentedState) Step(proof bool) (wit *StepWitness, err error) {
if proof { if proof {
insnProof := m.state.Memory.MerkleProof(m.state.Cpu.PC) insnProof := m.state.Memory.MerkleProof(m.state.Cpu.PC)
encodedWitness, stateHash := m.state.EncodeWitness()
wit = &StepWitness{ wit = &StepWitness{
State: m.state.EncodeWitness(), State: encodedWitness,
StateHash: stateHash,
MemProof: insnProof[:], MemProof: insnProof[:],
} }
} }
...@@ -89,11 +106,15 @@ func (m *InstrumentedState) LastPreimage() ([32]byte, []byte, uint32) { ...@@ -89,11 +106,15 @@ func (m *InstrumentedState) LastPreimage() ([32]byte, []byte, uint32) {
return m.lastPreimageKey, m.lastPreimage, m.lastPreimageOffset return m.lastPreimageKey, m.lastPreimage, m.lastPreimageOffset
} }
func (d *InstrumentedState) GetDebugInfo() *DebugInfo { func (m *InstrumentedState) GetState() FPVMState {
return m.state
}
func (m *InstrumentedState) GetDebugInfo() *DebugInfo {
return &DebugInfo{ return &DebugInfo{
Pages: d.state.Memory.PageCount(), Pages: m.state.Memory.PageCount(),
NumPreimageRequests: d.preimageOracle.numPreimageRequests, NumPreimageRequests: m.preimageOracle.numPreimageRequests,
TotalPreimageSize: d.preimageOracle.totalPreimageSize, TotalPreimageSize: m.preimageOracle.totalPreimageSize,
} }
} }
......
...@@ -104,13 +104,23 @@ func (s *State) UnmarshalJSON(data []byte) error { ...@@ -104,13 +104,23 @@ func (s *State) UnmarshalJSON(data []byte) error {
return nil return nil
} }
func (s *State) GetPC() uint32 { return s.Cpu.PC }
func (s *State) GetExitCode() uint8 { return s.ExitCode }
func (s *State) GetExited() bool { return s.Exited }
func (s *State) GetStep() uint64 { return s.Step } func (s *State) GetStep() uint64 { return s.Step }
func (s *State) VMStatus() uint8 { func (s *State) VMStatus() uint8 {
return vmStatus(s.Exited, s.ExitCode) return vmStatus(s.Exited, s.ExitCode)
} }
func (s *State) EncodeWitness() StateWitness { func (s *State) GetMemory() *Memory {
return s.Memory
}
func (s *State) EncodeWitness() ([]byte, common.Hash) {
out := make([]byte, 0) out := make([]byte, 0)
memRoot := s.Memory.MerkleRoot() memRoot := s.Memory.MerkleRoot()
out = append(out, memRoot[:]...) out = append(out, memRoot[:]...)
...@@ -131,7 +141,7 @@ func (s *State) EncodeWitness() StateWitness { ...@@ -131,7 +141,7 @@ func (s *State) EncodeWitness() StateWitness {
for _, r := range s.Registers { for _, r := range s.Registers {
out = binary.BigEndian.AppendUint32(out, r) out = binary.BigEndian.AppendUint32(out, r)
} }
return out return out, stateHashFromWitness(out)
} }
type StateWitness []byte type StateWitness []byte
...@@ -147,14 +157,20 @@ func (sw StateWitness) StateHash() (common.Hash, error) { ...@@ -147,14 +157,20 @@ func (sw StateWitness) StateHash() (common.Hash, error) {
if len(sw) != 226 { if len(sw) != 226 {
return common.Hash{}, fmt.Errorf("Invalid witness length. Got %d, expected 226", len(sw)) return common.Hash{}, fmt.Errorf("Invalid witness length. Got %d, expected 226", len(sw))
} }
return stateHashFromWitness(sw), nil
}
func stateHashFromWitness(sw []byte) common.Hash {
if len(sw) != 226 {
panic("Invalid witness length")
}
hash := crypto.Keccak256Hash(sw) hash := crypto.Keccak256Hash(sw)
offset := 32*2 + 4*6 offset := 32*2 + 4*6
exitCode := sw[offset] exitCode := sw[offset]
exited := sw[offset+1] exited := sw[offset+1]
status := vmStatus(exited == 1, exitCode) status := vmStatus(exited == 1, exitCode)
hash[0] = status hash[0] = status
return hash, nil return hash
} }
func vmStatus(exited bool, exitCode uint8) uint8 { func vmStatus(exited bool, exitCode uint8) uint8 {
......
...@@ -104,9 +104,7 @@ func TestStateHash(t *testing.T) { ...@@ -104,9 +104,7 @@ func TestStateHash(t *testing.T) {
ExitCode: c.exitCode, ExitCode: c.exitCode,
} }
actualWitness := state.EncodeWitness() actualWitness, actualStateHash := state.EncodeWitness()
actualStateHash, err := StateWitness(actualWitness).StateHash()
require.NoError(t, err, "Error hashing witness")
require.Equal(t, len(actualWitness), StateWitnessSize, "Incorrect witness size") require.Equal(t, len(actualWitness), StateWitnessSize, "Incorrect witness size")
expectedWitness := make(StateWitness, 226) expectedWitness := make(StateWitness, 226)
...@@ -118,7 +116,7 @@ func TestStateHash(t *testing.T) { ...@@ -118,7 +116,7 @@ func TestStateHash(t *testing.T) {
exited = 1 exited = 1
} }
expectedWitness[exitedOffset+1] = uint8(exited) expectedWitness[exitedOffset+1] = uint8(exited)
require.Equal(t, expectedWitness[:], actualWitness[:], "Incorrect witness") require.EqualValues(t, expectedWitness[:], actualWitness[:], "Incorrect witness")
expectedStateHash := crypto.Keccak256Hash(actualWitness) expectedStateHash := crypto.Keccak256Hash(actualWitness)
expectedStateHash[0] = vmStatus(c.exited, c.exitCode) expectedStateHash[0] = vmStatus(c.exited, c.exitCode)
......
...@@ -7,6 +7,7 @@ type LocalContext common.Hash ...@@ -7,6 +7,7 @@ type LocalContext common.Hash
type StepWitness struct { type StepWitness struct {
// encoded state witness // encoded state witness
State []byte State []byte
StateHash common.Hash
MemProof []byte MemProof []byte
......
...@@ -6,7 +6,6 @@ import ( ...@@ -6,7 +6,6 @@ import (
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum-optimism/optimism/cannon/mipsevm"
"github.com/ethereum-optimism/optimism/op-challenger/game/fault/types" "github.com/ethereum-optimism/optimism/op-challenger/game/fault/types"
) )
...@@ -22,26 +21,23 @@ func NewPrestateProvider(prestate string) *CannonPrestateProvider { ...@@ -22,26 +21,23 @@ func NewPrestateProvider(prestate string) *CannonPrestateProvider {
return &CannonPrestateProvider{prestate: prestate} return &CannonPrestateProvider{prestate: prestate}
} }
func (p *CannonPrestateProvider) absolutePreState() ([]byte, error) { func (p *CannonPrestateProvider) absolutePreState() ([]byte, common.Hash, error) {
state, err := parseState(p.prestate) state, err := parseState(p.prestate)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot load absolute pre-state: %w", err) return nil, common.Hash{}, fmt.Errorf("cannot load absolute pre-state: %w", err)
} }
return state.EncodeWitness(), nil witness, hash := state.EncodeWitness()
return witness, hash, nil
} }
func (p *CannonPrestateProvider) AbsolutePreStateCommitment(_ context.Context) (common.Hash, error) { func (p *CannonPrestateProvider) AbsolutePreStateCommitment(_ context.Context) (common.Hash, error) {
if p.prestateCommitment != (common.Hash{}) { if p.prestateCommitment != (common.Hash{}) {
return p.prestateCommitment, nil return p.prestateCommitment, nil
} }
state, err := p.absolutePreState() _, hash, err := p.absolutePreState()
if err != nil { if err != nil {
return common.Hash{}, fmt.Errorf("cannot load absolute pre-state: %w", err) return common.Hash{}, fmt.Errorf("cannot load absolute pre-state: %w", err)
} }
hash, err := mipsevm.StateWitness(state).StateHash()
if err != nil {
return common.Hash{}, fmt.Errorf("cannot hash absolute pre-state: %w", err)
}
p.prestateCommitment = hash p.prestateCommitment = hash
return hash, nil return hash, nil
} }
...@@ -56,8 +56,7 @@ func TestAbsolutePreStateCommitment(t *testing.T) { ...@@ -56,8 +56,7 @@ func TestAbsolutePreStateCommitment(t *testing.T) {
Step: 0, Step: 0,
Registers: [32]uint32{}, Registers: [32]uint32{},
} }
expected, err := state.EncodeWitness().StateHash() _, expected := state.EncodeWitness()
require.NoError(t, err)
require.Equal(t, expected, actual) require.Equal(t, expected, actual)
}) })
......
...@@ -132,11 +132,7 @@ func (p *CannonTraceProvider) loadProof(ctx context.Context, i uint64) (*utils.P ...@@ -132,11 +132,7 @@ func (p *CannonTraceProvider) loadProof(ctx context.Context, i uint64) (*utils.P
p.lastStep = state.Step - 1 p.lastStep = state.Step - 1
// Extend the trace out to the full length using a no-op instruction that doesn't change any state // Extend the trace out to the full length using a no-op instruction that doesn't change any state
// No execution is done, so no proof-data or oracle values are required. // No execution is done, so no proof-data or oracle values are required.
witness := state.EncodeWitness() witness, witnessHash := state.EncodeWitness()
witnessHash, err := mipsevm.StateWitness(witness).StateHash()
if err != nil {
return nil, fmt.Errorf("cannot hash witness: %w", err)
}
proof := &utils.ProofData{ proof := &utils.ProofData{
ClaimValue: witnessHash, ClaimValue: witnessHash,
StateData: hexutil.Bytes(witness), StateData: hexutil.Bytes(witness),
......
...@@ -57,8 +57,7 @@ func TestGet(t *testing.T) { ...@@ -57,8 +57,7 @@ func TestGet(t *testing.T) {
value, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, big.NewInt(7000))) value, err := provider.Get(context.Background(), PositionFromTraceIndex(provider, big.NewInt(7000)))
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, generator.generated, 7000, "should have tried to generate the proof") require.Contains(t, generator.generated, 7000, "should have tried to generate the proof")
stateHash, err := generator.finalState.EncodeWitness().StateHash() _, stateHash := generator.finalState.EncodeWitness()
require.NoError(t, err)
require.Equal(t, stateHash, value) require.Equal(t, stateHash, value)
}) })
...@@ -149,7 +148,7 @@ func TestGetStepData(t *testing.T) { ...@@ -149,7 +148,7 @@ func TestGetStepData(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, generator.generated, 7000, "should have tried to generate the proof") require.Contains(t, generator.generated, 7000, "should have tried to generate the proof")
witness := generator.finalState.EncodeWitness() witness, _ := generator.finalState.EncodeWitness()
require.EqualValues(t, witness, preimage) require.EqualValues(t, witness, preimage)
require.Equal(t, []byte{}, proof) require.Equal(t, []byte{}, proof)
require.Nil(t, data) require.Nil(t, data)
...@@ -190,7 +189,8 @@ func TestGetStepData(t *testing.T) { ...@@ -190,7 +189,8 @@ func TestGetStepData(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Empty(t, generator.generated, "should not have to generate the proof again") require.Empty(t, generator.generated, "should not have to generate the proof again")
require.EqualValues(t, initGenerator.finalState.EncodeWitness(), preimage) encodedWitness, _ := initGenerator.finalState.EncodeWitness()
require.EqualValues(t, encodedWitness, preimage)
require.Empty(t, proof) require.Empty(t, proof)
require.Nil(t, data) require.Nil(t, data)
}) })
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment