Commit 748773f1 authored by protolambda

mipsevm: step witness type, update syscall handling

parent e482f254
...@@ -3,7 +3,6 @@ package mipsevm ...@@ -3,7 +3,6 @@ package mipsevm
import ( import (
"bytes" "bytes"
"debug/elf" "debug/elf"
"encoding/binary"
"io" "io"
"math/big" "math/big"
"os" "os"
...@@ -64,7 +63,7 @@ func TestEVM(t *testing.T) { ...@@ -64,7 +63,7 @@ func TestEVM(t *testing.T) {
err = LoadUnicorn(state, mu) err = LoadUnicorn(state, mu)
require.NoError(t, err, "load state into unicorn") require.NoError(t, err, "load state into unicorn")
us, err := NewUnicornState(mu, state, os.Stdout, os.Stderr) us, err := NewUnicornState(mu, state, nil, os.Stdout, os.Stderr)
require.NoError(t, err, "hook unicorn to state") require.NoError(t, err, "hook unicorn to state")
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
...@@ -74,8 +73,8 @@ func TestEVM(t *testing.T) { ...@@ -74,8 +73,8 @@ func TestEVM(t *testing.T) {
insn := state.Memory.GetMemory(state.PC) insn := state.Memory.GetMemory(state.PC)
t.Logf("step: %4d pc: 0x%08x insn: 0x%08x", state.Step, state.PC, insn) t.Logf("step: %4d pc: 0x%08x insn: 0x%08x", state.Step, state.PC, insn)
stateData, proofData := us.Step(true) stepWitness := us.Step(true)
input := FormatStepInput(stateData, proofData) input := stepWitness.EncodeStepInput()
startingGas := uint64(30_000_000) startingGas := uint64(30_000_000)
// we take a snapshot so we can clean up the state, and isolate the logs of this instruction run. // we take a snapshot so we can clean up the state, and isolate the logs of this instruction run.
...@@ -108,21 +107,6 @@ func TestEVM(t *testing.T) { ...@@ -108,21 +107,6 @@ func TestEVM(t *testing.T) {
} }
} }
// FormatStepInput assembles the input bytes for a step call from the encoded
// state and the proof data. The layout is: StepBytes4, the keccak256 hash of
// stateData, two 32-byte offset words, then stateData and proofData, each
// preceded by a 32-byte length word.
func FormatStepInput(stateData, proofData []byte) []byte {
	const headSize = 32 * 3 // hash word plus the two offset words
	stateLen := uint32(len(stateData))
	proofLen := uint32(len(proofData))
	stateHash := crypto.Keccak256Hash(stateData)

	var input []byte
	input = append(input, StepBytes4...)
	for _, chunk := range [][]byte{
		stateHash[:],
		uint32ToBytes32(headSize),                 // state data offset in bytes
		uint32ToBytes32(headSize + 32 + stateLen), // proof data offset in bytes
		uint32ToBytes32(stateLen),                 // state data length in bytes
		stateData,
		uint32ToBytes32(proofLen), // proof data length in bytes
		proofData,
	} {
		input = append(input, chunk...)
	}
	return input
}
func TestMinimalEVM(t *testing.T) { func TestMinimalEVM(t *testing.T) {
contracts, err := LoadContracts() contracts, err := LoadContracts()
require.NoError(t, err) require.NoError(t, err)
...@@ -151,7 +135,7 @@ func TestMinimalEVM(t *testing.T) { ...@@ -151,7 +135,7 @@ func TestMinimalEVM(t *testing.T) {
err = LoadUnicorn(state, mu) err = LoadUnicorn(state, mu)
require.NoError(t, err, "load state into unicorn") require.NoError(t, err, "load state into unicorn")
var stdOutBuf, stdErrBuf bytes.Buffer var stdOutBuf, stdErrBuf bytes.Buffer
us, err := NewUnicornState(mu, state, io.MultiWriter(&stdOutBuf, os.Stdout), io.MultiWriter(&stdErrBuf, os.Stderr)) us, err := NewUnicornState(mu, state, nil, io.MultiWriter(&stdOutBuf, os.Stdout), io.MultiWriter(&stdErrBuf, os.Stderr))
require.NoError(t, err, "hook unicorn to state") require.NoError(t, err, "hook unicorn to state")
env, evmState := NewEVMEnv(contracts, addrs) env, evmState := NewEVMEnv(contracts, addrs)
...@@ -169,8 +153,8 @@ func TestMinimalEVM(t *testing.T) { ...@@ -169,8 +153,8 @@ func TestMinimalEVM(t *testing.T) {
t.Logf("step: %4d pc: 0x%08x insn: 0x%08x", state.Step, state.PC, insn) t.Logf("step: %4d pc: 0x%08x insn: 0x%08x", state.Step, state.PC, insn)
} }
stateData, proofData := us.Step(true) stepWitness := us.Step(true)
input := FormatStepInput(stateData, proofData) input := stepWitness.EncodeStepInput()
startingGas := uint64(30_000_000) startingGas := uint64(30_000_000)
// we take a snapshot so we can clean up the state, and isolate the logs of this instruction run. // we take a snapshot so we can clean up the state, and isolate the logs of this instruction run.
...@@ -204,9 +188,3 @@ func TestMinimalEVM(t *testing.T) { ...@@ -204,9 +188,3 @@ func TestMinimalEVM(t *testing.T) {
require.Equal(t, "hello world!", stdOutBuf.String(), "stdout says hello") require.Equal(t, "hello world!", stdOutBuf.String(), "stdout says hello")
require.Equal(t, "", stdErrBuf.String(), "stderr silent") require.Equal(t, "", stdErrBuf.String(), "stderr silent")
} }
// uint32ToBytes32 returns v as a 32-byte big-endian word: the value occupies
// the last four bytes and the leading 28 bytes are zero.
func uint32ToBytes32(v uint32) []byte {
	var word [32]byte
	const tail = len(word) - 4
	binary.BigEndian.PutUint32(word[tail:], v)
	return word[:]
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
) )
type State struct { type State struct {
...@@ -24,6 +25,13 @@ type State struct { ...@@ -24,6 +25,13 @@ type State struct {
Step uint64 `json:"step"` Step uint64 `json:"step"`
Registers [32]uint32 `json:"registers"` Registers [32]uint32 `json:"registers"`
// LastHint is optional metadata, and not part of the VM state itself.
// It is used to remember the last pre-image hint,
// so a VM can start from any state without fetching prior pre-images,
// and instead just repeat the last hint on setup,
// to make sure pre-image requests can be served.
LastHint hexutil.Bytes `json:"lastHint,omitempty"`
} }
func (s *State) EncodeWitness() []byte { func (s *State) EncodeWitness() []byte {
...@@ -49,5 +57,3 @@ func (s *State) EncodeWitness() []byte { ...@@ -49,5 +57,3 @@ func (s *State) EncodeWitness() []byte {
} }
return out return out
} }
// TODO convert access-list to calldata and state-sets for EVM
...@@ -58,7 +58,7 @@ func TestState(t *testing.T) { ...@@ -58,7 +58,7 @@ func TestState(t *testing.T) {
err = LoadUnicorn(state, mu) err = LoadUnicorn(state, mu)
require.NoError(t, err, "load state into unicorn") require.NoError(t, err, "load state into unicorn")
us, err := NewUnicornState(mu, state, os.Stdout, os.Stderr) us, err := NewUnicornState(mu, state, nil, os.Stdout, os.Stderr)
require.NoError(t, err, "hook unicorn to state") require.NoError(t, err, "hook unicorn to state")
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
...@@ -92,7 +92,7 @@ func TestMinimal(t *testing.T) { ...@@ -92,7 +92,7 @@ func TestMinimal(t *testing.T) {
err = LoadUnicorn(state, mu) err = LoadUnicorn(state, mu)
require.NoError(t, err, "load state into unicorn") require.NoError(t, err, "load state into unicorn")
var stdOutBuf, stdErrBuf bytes.Buffer var stdOutBuf, stdErrBuf bytes.Buffer
us, err := NewUnicornState(mu, state, io.MultiWriter(&stdOutBuf, os.Stdout), io.MultiWriter(&stdErrBuf, os.Stderr)) us, err := NewUnicornState(mu, state, nil, io.MultiWriter(&stdOutBuf, os.Stdout), io.MultiWriter(&stdErrBuf, os.Stderr))
require.NoError(t, err, "hook unicorn to state") require.NoError(t, err, "hook unicorn to state")
for i := 0; i < 400_000; i++ { for i := 0; i < 400_000; i++ {
......
...@@ -10,6 +10,11 @@ import ( ...@@ -10,6 +10,11 @@ import (
uc "github.com/unicorn-engine/unicorn/bindings/go/unicorn" uc "github.com/unicorn-engine/unicorn/bindings/go/unicorn"
) )
// PreimageOracle is the pre-image backend passed to NewUnicornState and kept
// on the UnicornState for the lifetime of the emulator.
type PreimageOracle interface {
// Hint receives the bytes the guest program writes to the hint-write
// file descriptor (fdHintWrite) via the write syscall.
Hint(v []byte)
// GetPreimage presumably returns the pre-image registered under key k.
// NOTE(review): no caller is visible in this change; confirm the contract.
GetPreimage(k [32]byte) []byte
}
type UnicornState struct { type UnicornState struct {
sync.Mutex sync.Mutex
...@@ -24,16 +29,41 @@ type UnicornState struct { ...@@ -24,16 +29,41 @@ type UnicornState struct {
memProofEnabled bool memProofEnabled bool
memProof [28 * 32]byte memProof [28 * 32]byte
preimageOracle PreimageOracle
// number of bytes last read from the oracle.
// The read data is preimage[state.PreimageOffset-lastPreimageRead : state.PreimageOffset]
// when inspecting the post-step state.
lastPreimageRead uint32
// cached pre-image data for state.PreimageKey
lastPreimage []byte
onStep func() onStep func()
} }
// TODO add pre-image oracle const (
func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*UnicornState, error) { fdStdin = 0
fdStdout = 1
fdStderr = 2
fdHintRead = 3
fdHintWrite = 4
fdPreimageRead = 5
fdPreimageWrite = 6
)
// Error numbers intended for failed syscalls. The only reference visible here
// is a commented-out assignment to v1 in the write handler's default case.
// NOTE(review): the values look like the Linux errno codes EBADF (9) and
// EINVAL (22); confirm against the MIPS errno table.
const (
MipsEBADF = 0x9
MipsEINVAL = 0x16
)
func NewUnicornState(mu uc.Unicorn, state *State, po PreimageOracle, stdOut, stdErr io.Writer) (*UnicornState, error) {
m := &UnicornState{ m := &UnicornState{
mu: mu, mu: mu,
state: state, state: state,
stdOut: stdOut, stdOut: stdOut,
stdErr: stdErr, stdErr: stdErr,
preimageOracle: po,
} }
st := m.state st := m.state
...@@ -43,33 +73,48 @@ func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*Un ...@@ -43,33 +73,48 @@ func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*Un
log.Fatal("invalid interrupt ", intno, " at step ", st.Step) log.Fatal("invalid interrupt ", intno, " at step ", st.Step)
} }
syscallNum, _ := mu.RegRead(uc.MIPS_REG_V0) syscallNum := st.Registers[2] // v0
v0 := uint32(0)
//v1 := uint32(0)
a0 := st.Registers[4]
a1 := st.Registers[5]
a2 := st.Registers[6]
fmt.Printf("syscall: %d\n", syscallNum) fmt.Printf("syscall: %d\n", syscallNum)
v0 := uint64(0)
switch syscallNum { switch syscallNum {
case 4004: // write case 4004: // write
fd, _ := mu.RegRead(uc.MIPS_REG_A0) fd := a0
addr, _ := mu.RegRead(uc.MIPS_REG_A1) addr := a1
count, _ := mu.RegRead(uc.MIPS_REG_A2) count := a2
switch fd { switch fd {
case 1: case fdStdout:
_, _ = io.Copy(stdOut, st.Memory.ReadMemoryRange(uint32(addr), uint32(count))) _, _ = io.Copy(stdOut, st.Memory.ReadMemoryRange(addr, count))
case 2: v0 = count
_, _ = io.Copy(stdErr, st.Memory.ReadMemoryRange(uint32(addr), uint32(count))) case fdStderr:
_, _ = io.Copy(stdErr, st.Memory.ReadMemoryRange(addr, count))
v0 = count
case fdHintWrite:
hint, _ := io.ReadAll(st.Memory.ReadMemoryRange(addr, count))
v0 = count
po.Hint(hint)
case fdPreimageWrite:
// TODO
v0 = count
default: default:
v0 = 0xFFffFFff
//v1 = MipsEBADF
// ignore other output data // ignore other output data
} }
case 4090: // mmap case 4090: // mmap
a0, _ := mu.RegRead(uc.MIPS_REG_A0) sz := a1
sz, _ := mu.RegRead(uc.MIPS_REG_A1)
if sz&pageAddrMask != 0 { // adjust size to align with page size if sz&pageAddrMask != 0 { // adjust size to align with page size
sz += pageSize - (sz & pageAddrMask) sz += pageSize - (sz & pageAddrMask)
} }
if a0 == 0 { if a0 == 0 {
v0 = uint64(st.Heap) v0 = st.Heap
fmt.Printf("mmap heap 0x%x size 0x%x\n", v0, sz) fmt.Printf("mmap heap 0x%x size 0x%x\n", v0, sz)
st.Heap += uint32(sz) st.Heap += sz
} else { } else {
v0 = a0 v0 = a0
fmt.Printf("mmap hint 0x%x size 0x%x\n", v0, sz) fmt.Printf("mmap hint 0x%x size 0x%x\n", v0, sz)
...@@ -77,9 +122,9 @@ func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*Un ...@@ -77,9 +122,9 @@ func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*Un
// Go does this thing where it first gets memory with PROT_NONE, // Go does this thing where it first gets memory with PROT_NONE,
// and then mmaps with a hint with prot=3 (PROT_READ|WRITE). // and then mmaps with a hint with prot=3 (PROT_READ|WRITE).
// We can ignore the NONE case, to avoid duplicate/overlapping mmap calls to unicorn. // We can ignore the NONE case, to avoid duplicate/overlapping mmap calls to unicorn.
prot, _ := mu.RegRead(uc.MIPS_REG_A2) prot := a2
if prot != 0 { if prot != 0 {
if err := mu.MemMap(v0, sz); err != nil { if err := mu.MemMap(uint64(v0), uint64(sz)); err != nil {
log.Fatalf("mmap fail: %v", err) log.Fatalf("mmap fail: %v", err)
} }
} }
...@@ -91,7 +136,7 @@ func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*Un ...@@ -91,7 +136,7 @@ func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*Un
st.ExitCode = uint8(v0) st.ExitCode = uint8(v0)
return return
} }
mu.RegWrite(uc.MIPS_REG_V0, v0) mu.RegWrite(uc.MIPS_REG_V0, uint64(v0))
mu.RegWrite(uc.MIPS_REG_A3, 0) mu.RegWrite(uc.MIPS_REG_A3, 0)
}, 0, ^uint64(0)) }, 0, ^uint64(0))
if err != nil { if err != nil {
...@@ -165,15 +210,16 @@ func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*Un ...@@ -165,15 +210,16 @@ func NewUnicornState(mu uc.Unicorn, state *State, stdOut, stdErr io.Writer) (*Un
return m, nil return m, nil
} }
func (m *UnicornState) Step(proof bool) (stateWitness []byte, memProof []byte) { func (m *UnicornState) Step(proof bool) (wit *StepWitness) {
m.memProofEnabled = proof m.memProofEnabled = proof
m.lastMemAccess = ^uint32(0) m.lastMemAccess = ^uint32(0)
if proof { if proof {
stateWitness = m.state.EncodeWitness()
insnProof := m.state.Memory.MerkleProof(m.state.PC) insnProof := m.state.Memory.MerkleProof(m.state.PC)
memProof = append(memProof, insnProof[:]...) wit = &StepWitness{
state: m.state.EncodeWitness(),
memProof: insnProof[:],
}
} }
insn := m.state.Memory.GetMemory(m.state.PC) insn := m.state.Memory.GetMemory(m.state.PC)
...@@ -222,7 +268,12 @@ func (m *UnicornState) Step(proof bool) (stateWitness []byte, memProof []byte) { ...@@ -222,7 +268,12 @@ func (m *UnicornState) Step(proof bool) (stateWitness []byte, memProof []byte) {
}) })
if proof { if proof {
memProof = append(memProof, m.memProof[:]...) wit.memProof = append(wit.memProof, m.memProof[:]...)
if m.lastPreimageRead > 0 {
wit.preimageOffset = m.state.PreimageOffset
wit.preimageKey = m.state.PreimageKey
wit.preimageValue = m.lastPreimage
}
} }
// count it // count it
......
package mipsevm
import (
"encoding/binary"
"github.com/ethereum/go-ethereum/crypto"
)
// StepWitness collects the data produced by UnicornState.Step when proofs are
// enabled. It is the input to EncodeStepInput and EncodePreimageOracleInput.
type StepWitness struct {
// state is the result of State.EncodeWitness(), captured before the
// instruction is executed.
state []byte
// memProof starts as the merkle proof for the instruction at the current PC;
// the proof buffer recorded by the emulator during the step is appended to it.
memProof []byte
preimageKey [32]byte // zeroed when no pre-image is accessed
// preimageValue is the emulator's cached pre-image for the state's pre-image
// key. It is only set when bytes were read from the oracle during the step.
preimageValue []byte
// preimageOffset is copied from the state's PreimageOffset after the step.
// It is only set when bytes were read from the oracle during the step.
preimageOffset uint32
}
// uint32ToBytes32 encodes v as a 32-byte big-endian word, left-padded with
// zeros so that the value sits in the final four bytes.
func uint32ToBytes32(v uint32) []byte {
	word := make([]byte, 32)
	binary.BigEndian.PutUint32(word[28:], v)
	return word
}
// EncodeStepInput serializes the witness into the input bytes for a step call.
// The layout is: StepBytes4, the keccak256 hash of the encoded state, two
// 32-byte offset words, then the state and the memory proof, each preceded by
// a 32-byte length word.
func (wit *StepWitness) EncodeStepInput() []byte {
	var input []byte
	put := func(chunk []byte) { input = append(input, chunk...) }
	word := func(v uint32) { put(uint32ToBytes32(v)) }

	stateSize := uint32(len(wit.state))
	digest := crypto.Keccak256Hash(wit.state)

	input = append(input, StepBytes4...)
	put(digest[:])
	word(32 * 3)                    // state data offset in bytes
	word(32*3 + 32 + stateSize)     // proof data offset in bytes
	word(stateSize)                 // state data length in bytes
	put(wit.state)
	word(uint32(len(wit.memProof))) // proof data length in bytes
	put(wit.memProof)
	return input
}
// EncodePreimageOracleInput is meant to produce the input for the pre-image
// oracle. The encoding is not implemented yet, so it currently returns nil
// whether or not a pre-image key is set.
func (wit *StepWitness) EncodePreimageOracleInput() []byte {
	var zeroKey [32]byte
	if wit.preimageKey == zeroKey {
		// An all-zero key marks a step that accessed no pre-image.
		return nil
	}
	// TODO
	return nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment