Commit 9e74a79a authored by protolambda's avatar protolambda

mipsevm,contracts: fix Oracle contract, fix oracle witness data encoding, add...

mipsevm,contracts: fix Oracle contract, fix oracle witness data encoding, add mem-proof for syscall mem read, implement multi-contract source-map tracing
parent c96647f5
...@@ -12,9 +12,8 @@ contract Oracle { ...@@ -12,9 +12,8 @@ contract Oracle {
require(preimagePartOk[key][offset], "preimage must exist"); require(preimagePartOk[key][offset], "preimage must exist");
datLen = 32; datLen = 32;
uint256 length = preimageLengths[key]; uint256 length = preimageLengths[key];
// TODO: insert length prefix before data if(offset + 32 >= length + 8) { // add 8 for the length-prefix part
if(offset + 32 >= length) { datLen = length + 8 - offset;
datLen = length - offset;
} }
dat = preimageParts[key][offset]; dat = preimageParts[key][offset];
} }
...@@ -36,18 +35,17 @@ contract Oracle { ...@@ -36,18 +35,17 @@ contract Oracle {
bytes32 key; bytes32 key;
bytes32 part; bytes32 part;
assembly { assembly {
// calldata layout: 4 (sel) + 0x20 (part offset) + 0x20 (start offset) + 0x20 (size) + preimage payload size := calldataload(0x44) // len(sig) + len(partOffset) + len(preimage offset) = 4 + 32 + 32 = 0x64
let startOffset := calldataload(0x24) if iszero(lt(partOffset, add(size, 8))) { // revert if part offset >= size+8 (i.e. parts must be within bounds)
if not(eq(startOffset, 0x44)) { // must always point to expected location of the size value.
revert(0, 0)
}
size := calldataload(0x44)
if iszero(lt(partOffset, size)) { // revert if part offset >= size (i.e. parts must be within bounds)
revert(0, 0) revert(0, 0)
} }
let ptr := 0x80 // we leave solidity slots 0x40 and 0x60 untouched, and everything after as scratch-memory. let ptr := 0x80 // we leave solidity slots 0x40 and 0x60 untouched, and everything after as scratch-memory.
calldatacopy(ptr, 0x64, size) // copy preimage payload into memory so we can hash and read it. mstore(ptr, shl(192, size)) // put size as big-endian uint64 at start of pre-image
part := mload(add(ptr, partOffset)) // this will be zero-padded at the end, since memory at end is clean. ptr := add(ptr, 8)
calldatacopy(ptr, preimage.offset, size) // copy preimage payload into memory so we can hash and read it.
// Note that it includes the 8-byte big-endian uint64 length prefix.
// this will be zero-padded at the end, since memory at end is clean.
part := mload(add(sub(ptr, 8), partOffset))
let h := keccak256(ptr, size) // compute preimage keccak256 hash let h := keccak256(ptr, size) // compute preimage keccak256 hash
key := or(and(h, not(shl(248, 0xFF))), shl(248, 2)) // mask out prefix byte, replace with type 2 byte key := or(and(h, not(shl(248, 0xFF))), shl(248, 2)) // mask out prefix byte, replace with type 2 byte
} }
......
...@@ -20,15 +20,21 @@ import ( ...@@ -20,15 +20,21 @@ import (
"github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/params"
) )
var StepBytes4 = crypto.Keccak256Hash([]byte("Step(bytes32,bytes,bytes)")).Bytes()[:4] var (
StepBytes4 = crypto.Keccak256([]byte("Step(bytes32,bytes,bytes)"))[:4]
CheatBytes4 = crypto.Keccak256([]byte("cheat(uint256,bytes32,bytes32,uint256)"))[:4]
LoadKeccak256PreimagePartBytes4 = crypto.Keccak256([]byte("loadKeccak256PreimagePart(uint256,bytes)"))[:4]
)
func LoadContracts() (*Contracts, error) { func LoadContracts() (*Contracts, error) {
mips, err := LoadContract("MIPS") mips, err := LoadContract("MIPS")
if err != nil { if err != nil {
return nil, err return nil, err
} }
oracle, err := LoadContract("Oracle")
return &Contracts{ return &Contracts{
MIPS: mips, MIPS: mips,
Oracle: oracle,
}, nil }, nil
} }
...@@ -59,17 +65,22 @@ func (c *Contract) SourceMap(sourcePaths []string) (*SourceMap, error) { ...@@ -59,17 +65,22 @@ func (c *Contract) SourceMap(sourcePaths []string) (*SourceMap, error) {
} }
type Contracts struct { type Contracts struct {
MIPS *Contract MIPS *Contract
Oracle *Contract
} }
type Addresses struct { type Addresses struct {
MIPS common.Address MIPS common.Address
Oracle common.Address
Sender common.Address
FeeRecipient common.Address
} }
func NewEVMEnv(contracts *Contracts, addrs *Addresses) (*vm.EVM, *state.StateDB) { func NewEVMEnv(contracts *Contracts, addrs *Addresses) (*vm.EVM, *state.StateDB) {
chainCfg := params.MainnetChainConfig chainCfg := params.MainnetChainConfig
bc := &testChain{} offsetBlocks := uint64(1000) // blocks after shanghai fork
header := bc.GetHeader(common.Hash{}, 100) bc := &testChain{startTime: *chainCfg.ShanghaiTime + offsetBlocks*12}
header := bc.GetHeader(common.Hash{}, 17034870+offsetBlocks)
db := rawdb.NewMemoryDatabase() db := rawdb.NewMemoryDatabase()
statedb := state.NewDatabase(db) statedb := state.NewDatabase(db)
state, err := state.New(types.EmptyRootHash, statedb, nil) state, err := state.New(types.EmptyRootHash, statedb, nil)
...@@ -81,13 +92,17 @@ func NewEVMEnv(contracts *Contracts, addrs *Addresses) (*vm.EVM, *state.StateDB) ...@@ -81,13 +92,17 @@ func NewEVMEnv(contracts *Contracts, addrs *Addresses) (*vm.EVM, *state.StateDB)
env := vm.NewEVM(blockContext, vm.TxContext{}, state, chainCfg, vmCfg) env := vm.NewEVM(blockContext, vm.TxContext{}, state, chainCfg, vmCfg)
// pre-deploy the contracts // pre-deploy the contracts
env.StateDB.SetCode(addrs.MIPS, contracts.MIPS.DeployedBytecode.Object) env.StateDB.SetCode(addrs.MIPS, contracts.MIPS.DeployedBytecode.Object)
// TODO: any state to set, or immutables to replace, to link the contracts together? env.StateDB.SetCode(addrs.Oracle, contracts.Oracle.DeployedBytecode.Object)
env.StateDB.SetState(addrs.MIPS, common.Hash{}, addrs.Oracle.Hash())
rules := env.ChainConfig().Rules(header.Number, true, header.Time)
env.StateDB.Prepare(rules, addrs.Sender, addrs.FeeRecipient, &addrs.MIPS, vm.ActivePrecompiles(rules), nil)
return env, state return env, state
} }
type testChain struct { type testChain struct {
startTime uint64
} }
func (d *testChain) Engine() consensus.Engine { func (d *testChain) Engine() consensus.Engine {
...@@ -109,7 +124,7 @@ func (d *testChain) GetHeader(h common.Hash, n uint64) *types.Header { ...@@ -109,7 +124,7 @@ func (d *testChain) GetHeader(h common.Hash, n uint64) *types.Header {
Number: new(big.Int).SetUint64(n), Number: new(big.Int).SetUint64(n),
GasLimit: 30_000_000, GasLimit: 30_000_000,
GasUsed: 0, GasUsed: 0,
Time: 1337, Time: d.startTime + n*12,
Extra: nil, Extra: nil,
MixDigest: common.Hash{}, MixDigest: common.Hash{},
Nonce: types.BlockNonce{}, Nonce: types.BlockNonce{},
......
...@@ -17,21 +17,32 @@ import ( ...@@ -17,21 +17,32 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestEVM(t *testing.T) { func testContractsSetup(t *testing.T) (*Contracts, *Addresses, *SourceMapTracer) {
testFiles, err := os.ReadDir("test/bin")
require.NoError(t, err)
contracts, err := LoadContracts() contracts, err := LoadContracts()
require.NoError(t, err) require.NoError(t, err)
// the first unlisted source seems to be the ABIDecoderV2 code that the compiler inserts mipsSrcMap, err := contracts.MIPS.SourceMap([]string{"../contracts/src/MIPS.sol"})
mipsSrcMap, err := contracts.MIPS.SourceMap([]string{"../contracts/src/MIPS.sol", "~compiler?", "../contracts/src/MIPS.sol"}) require.NoError(t, err)
oracleSrcMap, err := contracts.Oracle.SourceMap([]string{"../contracts/src/Oracle.sol"})
require.NoError(t, err) require.NoError(t, err)
addrs := &Addresses{ addrs := &Addresses{
MIPS: common.Address{0: 0xff, 19: 1}, MIPS: common.Address{0: 0xff, 19: 1},
Oracle: common.Address{0: 0xff, 19: 2},
Sender: common.Address{0x13, 0x37},
FeeRecipient: common.Address{0xaa},
} }
tracer := NewSourceMapTracer(map[common.Address]*SourceMap{addrs.MIPS: mipsSrcMap, addrs.Oracle: oracleSrcMap}, os.Stdout)
return contracts, addrs, tracer
}
func TestEVM(t *testing.T) {
testFiles, err := os.ReadDir("test/bin")
require.NoError(t, err)
contracts, addrs, tracer := testContractsSetup(t)
sender := common.Address{0x13, 0x37} sender := common.Address{0x13, 0x37}
//tracer = logger.NewMarkdownLogger(&logger.Config{}, os.Stdout)
for _, f := range testFiles { for _, f := range testFiles {
t.Run(f.Name(), func(t *testing.T) { t.Run(f.Name(), func(t *testing.T) {
...@@ -41,8 +52,7 @@ func TestEVM(t *testing.T) { ...@@ -41,8 +52,7 @@ func TestEVM(t *testing.T) {
env, evmState := NewEVMEnv(contracts, addrs) env, evmState := NewEVMEnv(contracts, addrs)
env.Config.Debug = false env.Config.Debug = false
//env.Config.Tracer = logger.NewMarkdownLogger(&logger.Config{}, os.Stdout) env.Config.Tracer = tracer
env.Config.Tracer = mipsSrcMap.Tracer(os.Stdout)
fn := path.Join("test/bin", f.Name()) fn := path.Join("test/bin", f.Name())
programMem, err := os.ReadFile(fn) programMem, err := os.ReadFile(fn)
...@@ -108,16 +118,7 @@ func TestEVM(t *testing.T) { ...@@ -108,16 +118,7 @@ func TestEVM(t *testing.T) {
} }
func TestHelloEVM(t *testing.T) { func TestHelloEVM(t *testing.T) {
contracts, err := LoadContracts() contracts, addrs, tracer := testContractsSetup(t)
require.NoError(t, err)
// the first unlisted source seems to be the ABIDecoderV2 code that the compiler inserts
mipsSrcMap, err := contracts.MIPS.SourceMap([]string{"../contracts/src/MIPS.sol", "~compiler?", "../contracts/src/MIPS.sol"})
require.NoError(t, err)
addrs := &Addresses{
MIPS: common.Address{0: 0xff, 19: 1},
}
sender := common.Address{0x13, 0x37} sender := common.Address{0x13, 0x37}
elfProgram, err := elf.Open("../example/bin/hello.elf") elfProgram, err := elf.Open("../example/bin/hello.elf")
...@@ -140,8 +141,7 @@ func TestHelloEVM(t *testing.T) { ...@@ -140,8 +141,7 @@ func TestHelloEVM(t *testing.T) {
env, evmState := NewEVMEnv(contracts, addrs) env, evmState := NewEVMEnv(contracts, addrs)
env.Config.Debug = false env.Config.Debug = false
//env.Config.Tracer = logger.NewMarkdownLogger(&logger.Config{}, os.Stdout) env.Config.Tracer = tracer
env.Config.Tracer = mipsSrcMap.Tracer(os.Stdout)
start := time.Now() start := time.Now()
for i := 0; i < 400_000; i++ { for i := 0; i < 400_000; i++ {
...@@ -188,3 +188,75 @@ func TestHelloEVM(t *testing.T) { ...@@ -188,3 +188,75 @@ func TestHelloEVM(t *testing.T) {
require.Equal(t, "hello world!", stdOutBuf.String(), "stdout says hello") require.Equal(t, "hello world!", stdOutBuf.String(), "stdout says hello")
require.Equal(t, "", stdErrBuf.String(), "stderr silent") require.Equal(t, "", stdErrBuf.String(), "stderr silent")
} }
func TestClaimEVM(t *testing.T) {
contracts, addrs, tracer := testContractsSetup(t)
elfProgram, err := elf.Open("../example/bin/claim.elf")
require.NoError(t, err, "open ELF file")
state, err := LoadELF(elfProgram)
require.NoError(t, err, "load ELF into state")
err = patchVM(elfProgram, state)
require.NoError(t, err, "apply Go runtime patches")
mu, err := NewUnicorn()
require.NoError(t, err, "load unicorn")
defer mu.Close()
err = LoadUnicorn(state, mu)
require.NoError(t, err, "load state into unicorn")
oracle, expectedStdOut, expectedStdErr := claimTestOracle(t)
var stdOutBuf, stdErrBuf bytes.Buffer
us, err := NewUnicornState(mu, state, oracle, io.MultiWriter(&stdOutBuf, os.Stdout), io.MultiWriter(&stdErrBuf, os.Stderr))
require.NoError(t, err, "hook unicorn to state")
env, evmState := NewEVMEnv(contracts, addrs)
env.Config.Debug = false
env.Config.Tracer = tracer
for i := 0; i < 2000_000; i++ {
if us.state.Exited {
break
}
insn := state.Memory.GetMemory(state.PC)
if i%1000 == 0 { // avoid spamming test logs, we are executing many steps
t.Logf("step: %4d pc: 0x%08x insn: 0x%08x", state.Step, state.PC, insn)
}
stepWitness := us.Step(true)
input := stepWitness.EncodeStepInput()
startingGas := uint64(30_000_000)
// we take a snapshot so we can clean up the state, and isolate the logs of this instruction run.
snap := env.StateDB.Snapshot()
// prepare pre-image oracle data, if any
if stepWitness.HasPreimage() {
poInput, err := stepWitness.EncodePreimageOracleInput()
require.NoError(t, err, "encode preimage oracle input")
_, leftOverGas, err := env.Call(vm.AccountRef(addrs.Sender), addrs.Oracle, poInput, startingGas, big.NewInt(0))
require.NoErrorf(t, err, "evm should not fail, took %d gas", startingGas-leftOverGas)
}
ret, leftOverGas, err := env.Call(vm.AccountRef(addrs.Sender), addrs.MIPS, input, startingGas, big.NewInt(0))
require.NoErrorf(t, err, "evm should not fail, took %d gas", startingGas-leftOverGas)
require.Len(t, ret, 32, "expecting 32-byte state hash")
// remember state hash, to check it against state
postHash := common.Hash(*(*[32]byte)(ret))
logs := evmState.Logs()
require.Equal(t, 1, len(logs), "expecting a log with post-state")
evmPost := logs[0].Data
require.Equal(t, crypto.Keccak256Hash(evmPost), postHash, "logged state must be accurate")
env.StateDB.RevertToSnapshot(snap)
}
require.True(t, state.Exited, "must complete program")
require.Equal(t, uint8(0), state.ExitCode, "exit with 0")
require.Equal(t, expectedStdOut, stdOutBuf.String(), "stdout")
require.Equal(t, expectedStdErr, stdErrBuf.String(), "stderr")
}
...@@ -175,13 +175,13 @@ func ParseSourceMap(sources []string, bytecode []byte, sourceMap string) (*Sourc ...@@ -175,13 +175,13 @@ func ParseSourceMap(sources []string, bytecode []byte, sourceMap string) (*Sourc
return srcMap, nil return srcMap, nil
} }
func (s *SourceMap) Tracer(out io.Writer) *SourceMapTracer { func NewSourceMapTracer(srcMaps map[common.Address]*SourceMap, out io.Writer) *SourceMapTracer {
return &SourceMapTracer{s, out} return &SourceMapTracer{srcMaps, out}
} }
type SourceMapTracer struct { type SourceMapTracer struct {
srcMap *SourceMap srcMaps map[common.Address]*SourceMap
out io.Writer out io.Writer
} }
func (s *SourceMapTracer) CaptureTxStart(gasLimit uint64) {} func (s *SourceMapTracer) CaptureTxStart(gasLimit uint64) {}
...@@ -198,12 +198,25 @@ func (s *SourceMapTracer) CaptureEnter(typ vm.OpCode, from common.Address, to co ...@@ -198,12 +198,25 @@ func (s *SourceMapTracer) CaptureEnter(typ vm.OpCode, from common.Address, to co
func (s *SourceMapTracer) CaptureExit(output []byte, gasUsed uint64, err error) {} func (s *SourceMapTracer) CaptureExit(output []byte, gasUsed uint64, err error) {}
func (s *SourceMapTracer) info(codeAddr *common.Address, pc uint64) string {
info := "non-contract"
if codeAddr != nil {
srcMap, ok := s.srcMaps[*codeAddr]
if ok {
info = srcMap.FormattedInfo(pc)
} else {
info = "unknown-contract"
}
}
return info
}
func (s *SourceMapTracer) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { func (s *SourceMapTracer) CaptureState(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) {
fmt.Fprintf(s.out, "%-40s : pc %x opcode %s\n", s.srcMap.FormattedInfo(pc), pc, op.String()) fmt.Fprintf(s.out, "%-40s : pc %x opcode %s\n", s.info(scope.Contract.CodeAddr, pc), pc, op.String())
} }
func (s *SourceMapTracer) CaptureFault(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, depth int, err error) { func (s *SourceMapTracer) CaptureFault(pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, depth int, err error) {
fmt.Fprintf(s.out, "%-40s: pc %x opcode %s FAULT %v\n", s.srcMap.FormattedInfo(pc), pc, op.String(), err) fmt.Fprintf(s.out, "%-40s: pc %x opcode %s FAULT %v\n", s.info(scope.Contract.CodeAddr, pc), pc, op.String(), err)
fmt.Println("----") fmt.Println("----")
fmt.Fprintf(s.out, "calldata: %x\n", scope.Contract.Input) fmt.Fprintf(s.out, "calldata: %x\n", scope.Contract.Input)
fmt.Println("----") fmt.Println("----")
......
...@@ -11,7 +11,7 @@ type State struct { ...@@ -11,7 +11,7 @@ type State struct {
Memory *Memory `json:"memory"` Memory *Memory `json:"memory"`
PreimageKey common.Hash `json:"preimageKey"` PreimageKey common.Hash `json:"preimageKey"`
PreimageOffset uint32 `json:"preimageOffset"` PreimageOffset uint32 `json:"preimageOffset"` // note that the offset includes the 8-byte length prefix
PC uint32 `json:"pc"` PC uint32 `json:"pc"`
NextPC uint32 `json:"nextPC"` NextPC uint32 `json:"nextPC"`
......
...@@ -132,22 +132,7 @@ func (t *testOracle) GetPreimage(k [32]byte) []byte { ...@@ -132,22 +132,7 @@ func (t *testOracle) GetPreimage(k [32]byte) []byte {
var _ PreimageOracle = (*testOracle)(nil) var _ PreimageOracle = (*testOracle)(nil)
func TestClaim(t *testing.T) { func claimTestOracle(t *testing.T) (po PreimageOracle, stdOut string, stdErr string) {
elfProgram, err := elf.Open("../example/bin/claim.elf")
require.NoError(t, err, "open ELF file")
state, err := LoadELF(elfProgram)
require.NoError(t, err, "load ELF into state")
err = patchVM(elfProgram, state)
require.NoError(t, err, "apply Go runtime patches")
mu, err := NewUnicorn()
require.NoError(t, err, "load unicorn")
defer mu.Close()
err = LoadUnicorn(state, mu)
require.NoError(t, err, "load state into unicorn")
s := uint64(1000) s := uint64(1000)
a := uint64(3) a := uint64(3)
b := uint64(4) b := uint64(4)
...@@ -198,6 +183,27 @@ func TestClaim(t *testing.T) { ...@@ -198,6 +183,27 @@ func TestClaim(t *testing.T) {
}, },
} }
return oracle, fmt.Sprintf("computing %d * %d + %d\nclaim %d is good!\n", s, a, b, s*a+b), "started!"
}
func TestClaim(t *testing.T) {
elfProgram, err := elf.Open("../example/bin/claim.elf")
require.NoError(t, err, "open ELF file")
state, err := LoadELF(elfProgram)
require.NoError(t, err, "load ELF into state")
err = patchVM(elfProgram, state)
require.NoError(t, err, "apply Go runtime patches")
mu, err := NewUnicorn()
require.NoError(t, err, "load unicorn")
defer mu.Close()
err = LoadUnicorn(state, mu)
require.NoError(t, err, "load state into unicorn")
oracle, expectedStdOut, expectedStdErr := claimTestOracle(t)
var stdOutBuf, stdErrBuf bytes.Buffer var stdOutBuf, stdErrBuf bytes.Buffer
us, err := NewUnicornState(mu, state, oracle, io.MultiWriter(&stdOutBuf, os.Stdout), io.MultiWriter(&stdErrBuf, os.Stderr)) us, err := NewUnicornState(mu, state, oracle, io.MultiWriter(&stdOutBuf, os.Stdout), io.MultiWriter(&stdErrBuf, os.Stderr))
require.NoError(t, err, "hook unicorn to state") require.NoError(t, err, "hook unicorn to state")
...@@ -212,6 +218,6 @@ func TestClaim(t *testing.T) { ...@@ -212,6 +218,6 @@ func TestClaim(t *testing.T) {
require.True(t, state.Exited, "must complete program") require.True(t, state.Exited, "must complete program")
require.Equal(t, uint8(0), state.ExitCode, "exit with 0") require.Equal(t, uint8(0), state.ExitCode, "exit with 0")
require.Equal(t, fmt.Sprintf("computing %d * %d + %d\nclaim %d is good!\n", s, a, b, s*a+b), stdOutBuf.String(), "stdout says hello") require.Equal(t, expectedStdOut, stdOutBuf.String(), "stdout")
require.Equal(t, "started!", stdErrBuf.String(), "stderr silent") require.Equal(t, expectedStdErr, stdErrBuf.String(), "stderr")
} }
...@@ -83,6 +83,16 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std ...@@ -83,6 +83,16 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std
return return
} }
trackMemAccess := func(effAddr uint32) {
if m.memProofEnabled && m.lastMemAccess != effAddr {
if m.lastMemAccess != ^uint32(0) {
panic(fmt.Errorf("unexpected different mem access at %08x, already have access at %08x buffered", effAddr, m.lastMemAccess))
}
m.lastMemAccess = effAddr
m.memProof = m.state.Memory.MerkleProof(effAddr)
}
}
var err error var err error
_, err = mu.HookAdd(uc.HOOK_INTR, func(mu uc.Unicorn, intno uint32) { _, err = mu.HookAdd(uc.HOOK_INTR, func(mu uc.Unicorn, intno uint32) {
if intno != 17 { if intno != 17 {
...@@ -137,6 +147,7 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std ...@@ -137,6 +147,7 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std
// leave v0 and v1 zero: read nothing, no error // leave v0 and v1 zero: read nothing, no error
case fdPreimageRead: // pre-image oracle case fdPreimageRead: // pre-image oracle
effAddr := a1 & 0xFFffFFfc effAddr := a1 & 0xFFffFFfc
trackMemAccess(effAddr)
mem := st.Memory.GetMemory(effAddr) mem := st.Memory.GetMemory(effAddr)
dat, datLen := readPreimage(st.PreimageKey, st.PreimageOffset) dat, datLen := readPreimage(st.PreimageKey, st.PreimageOffset)
fmt.Printf("reading pre-image data: addr: %08x, offset: %d, datLen: %d, data: %x, key: %s count: %d\n", a1, st.PreimageOffset, datLen, dat[:datLen], st.PreimageKey, a2) fmt.Printf("reading pre-image data: addr: %08x, offset: %d, datLen: %d, data: %x, key: %s count: %d\n", a1, st.PreimageOffset, datLen, dat[:datLen], st.PreimageKey, a2)
...@@ -190,7 +201,9 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std ...@@ -190,7 +201,9 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std
} }
v0 = a2 v0 = a2
case fdPreimageWrite: case fdPreimageWrite:
mem := st.Memory.GetMemory(a1 & 0xFFffFFfc) effAddr := a1 & 0xFFffFFfc
trackMemAccess(effAddr)
mem := st.Memory.GetMemory(effAddr)
key := st.PreimageKey key := st.PreimageKey
alignment := a1 & 3 alignment := a1 & 3
space := 4 - alignment space := 4 - alignment
...@@ -244,13 +257,7 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std ...@@ -244,13 +257,7 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std
_, err = mu.HookAdd(uc.HOOK_MEM_READ, func(mu uc.Unicorn, access int, addr64 uint64, size int, value int64) { _, err = mu.HookAdd(uc.HOOK_MEM_READ, func(mu uc.Unicorn, access int, addr64 uint64, size int, value int64) {
effAddr := uint32(addr64 & 0xFFFFFFFC) // pass effective addr to tracer effAddr := uint32(addr64 & 0xFFFFFFFC) // pass effective addr to tracer
if m.memProofEnabled && m.lastMemAccess != effAddr { trackMemAccess(effAddr)
if m.lastMemAccess != ^uint32(0) {
panic(fmt.Errorf("unexpected different mem access at %08x, already have access at %08x buffered", effAddr, m.lastMemAccess))
}
m.lastMemAccess = effAddr
m.memProof = m.state.Memory.MerkleProof(effAddr)
}
}, 0, ^uint64(0)) }, 0, ^uint64(0))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to set up mem-write hook: %w", err) return nil, fmt.Errorf("failed to set up mem-write hook: %w", err)
...@@ -283,13 +290,7 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std ...@@ -283,13 +290,7 @@ func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, std
} else { } else {
log.Fatal("bad size write to ram") log.Fatal("bad size write to ram")
} }
if m.memProofEnabled && m.lastMemAccess != effAddr { trackMemAccess(effAddr)
if m.lastMemAccess != ^uint32(0) {
panic(fmt.Errorf("unexpected different mem access at %08x, already have access at %08x buffered", effAddr, m.lastMemAccess))
}
m.lastMemAccess = effAddr
m.memProof = m.state.Memory.MerkleProof(effAddr)
}
// only set memory after making the proof: we need the pre-state // only set memory after making the proof: we need the pre-state
st.Memory.SetMemory(effAddr, post) st.Memory.SetMemory(effAddr, post)
}, 0, ^uint64(0)) }, 0, ^uint64(0))
......
...@@ -2,8 +2,12 @@ package mipsevm ...@@ -2,8 +2,12 @@ package mipsevm
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum-optimism/cannon/preimage"
) )
type StepWitness struct { type StepWitness struct {
...@@ -12,7 +16,7 @@ type StepWitness struct { ...@@ -12,7 +16,7 @@ type StepWitness struct {
memProof []byte memProof []byte
preimageKey [32]byte // zeroed when no pre-image is accessed preimageKey [32]byte // zeroed when no pre-image is accessed
preimageValue []byte preimageValue []byte // including the 8-byte length prefix
preimageOffset uint32 preimageOffset uint32
} }
...@@ -37,10 +41,42 @@ func (wit *StepWitness) EncodeStepInput() []byte { ...@@ -37,10 +41,42 @@ func (wit *StepWitness) EncodeStepInput() []byte {
return input return input
} }
func (wit *StepWitness) EncodePreimageOracleInput() []byte { func (wit *StepWitness) HasPreimage() bool {
return wit.preimageKey != ([32]byte{})
}
func (wit *StepWitness) EncodePreimageOracleInput() ([]byte, error) {
if wit.preimageKey == ([32]byte{}) { if wit.preimageKey == ([32]byte{}) {
return nil return nil, errors.New("cannot encode pre-image oracle input, witness has no pre-image to proof")
}
switch preimage.KeyType(wit.preimageKey[0]) {
case preimage.LocalKeyType:
// We have no on-chain form of preparing the bootstrap pre-images onchain yet.
// So instead we cheat them in.
// In production usage there should be an on-chain contract that exposes this,
// rather than going through the global keccak256 oracle.
var input []byte
input = append(input, CheatBytes4...)
input = append(input, uint32ToBytes32(wit.preimageOffset)...)
input = append(input, wit.preimageKey[:]...)
var tmp [32]byte
copy(tmp[:], wit.preimageValue[wit.preimageOffset:])
input = append(input, tmp[:]...)
input = append(input, uint32ToBytes32(uint32(len(wit.preimageValue))-8)...)
// TODO: do we want to pad the end to a multiple of 32 bytes?
return input, nil
case preimage.Keccak256KeyType:
var input []byte
input = append(input, LoadKeccak256PreimagePartBytes4...)
input = append(input, uint32ToBytes32(wit.preimageOffset)...)
input = append(input, uint32ToBytes32(32+32)...) // partOffset, calldata offset
input = append(input, uint32ToBytes32(uint32(len(wit.preimageValue))-8)...)
input = append(input, wit.preimageValue[8:]...)
// TODO: do we want to pad the end to a multiple of 32 bytes?
return input, nil
default:
return nil, fmt.Errorf("unsupported pre-image type %d, cannot prepare preimage with key %x offset %d for oracle",
wit.preimageKey[0], wit.preimageKey, wit.preimageOffset)
} }
// TODO
return nil
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment