Commit f27fd058 authored by protolambda's avatar protolambda

mipsevm: merkle-tree-cached memory and merkle proof generation

parent b2d410d6
...@@ -49,8 +49,8 @@ func TestEVM(t *testing.T) { ...@@ -49,8 +49,8 @@ func TestEVM(t *testing.T) {
fn := path.Join("test/bin", f.Name()) fn := path.Join("test/bin", f.Name())
programMem, err := os.ReadFile(fn) programMem, err := os.ReadFile(fn)
state := &State{PC: 0, NextPC: 4, Memory: make(map[uint32]*Page)} state := &State{PC: 0, NextPC: 4, Memory: NewMemory()}
err = state.SetMemoryRange(0, bytes.NewReader(programMem)) err = state.Memory.SetMemoryRange(0, bytes.NewReader(programMem))
require.NoError(t, err, "load program into state") require.NoError(t, err, "load program into state")
// set the return address ($ra) to jump into when test completes // set the return address ($ra) to jump into when test completes
...@@ -63,44 +63,32 @@ func TestEVM(t *testing.T) { ...@@ -63,44 +63,32 @@ func TestEVM(t *testing.T) {
require.NoError(t, mu.MemMap(baseAddrStart, ((baseAddrEnd-baseAddrStart)&^pageAddrMask)+pageSize)) require.NoError(t, mu.MemMap(baseAddrStart, ((baseAddrEnd-baseAddrStart)&^pageAddrMask)+pageSize))
require.NoError(t, mu.MemMap(endAddr&^pageAddrMask, pageSize)) require.NoError(t, mu.MemMap(endAddr&^pageAddrMask, pageSize))
al := &AccessList{} al := &AccessList{mem: state.Memory}
err = LoadUnicorn(state, mu) err = LoadUnicorn(state, mu)
require.NoError(t, err, "load state into unicorn") require.NoError(t, err, "load state into unicorn")
err = HookUnicorn(state, mu, os.Stdout, os.Stderr, al) err = HookUnicorn(state, mu, os.Stdout, os.Stderr, al)
require.NoError(t, err, "hook unicorn to state") require.NoError(t, err, "hook unicorn to state")
so := NewStateCache()
var stateData []byte var stateData []byte
var insn uint32 var insn uint32
var pc uint32 var pc uint32
var post []byte var post []byte
preCode := func() { preCode := func() {
insn = state.GetMemory(state.PC) insn = state.Memory.GetMemory(state.PC)
pc = state.PC pc = state.PC
fmt.Printf("PRE - pc: %08x insn: %08x\n", pc, insn) fmt.Printf("PRE - pc: %08x insn: %08x\n", pc, insn)
// remember the pre-state, to repeat it in the EVM during the post processing step // remember the pre-state, to repeat it in the EVM during the post processing step
stateData = state.EncodeWitness(so) stateData = state.EncodeWitness()
if post != nil { if post != nil {
require.Equal(t, hexutil.Bytes(stateData).String(), hexutil.Bytes(post).String(), require.Equal(t, hexutil.Bytes(stateData).String(), hexutil.Bytes(post).String(),
"unicorn produced different state than EVM") "unicorn produced different state than EVM")
} }
al.Reset() // reset access list
} }
postCode := func() { postCode := func() {
fmt.Printf("POST - pc: %08x insn: %08x\n", pc, insn) fmt.Printf("POST - pc: %08x insn: %08x\n", pc, insn)
var proofData []byte proofData := append([]byte(nil), al.proofData...)
proofData = binary.BigEndian.AppendUint32(proofData, insn)
if len(al.memReads) > 0 {
proofData = binary.BigEndian.AppendUint32(proofData, al.memReads[0].PreValue)
} else if len(al.memWrites) > 0 {
proofData = binary.BigEndian.AppendUint32(proofData, al.memWrites[0].PreValue)
} else {
proofData = append(proofData, make([]byte, 4)...)
}
proofData = append(proofData, make([]byte, 32-4-4)...)
stateHash := crypto.Keccak256Hash(stateData) stateHash := crypto.Keccak256Hash(stateData)
var input []byte var input []byte
...@@ -148,7 +136,7 @@ func TestEVM(t *testing.T) { ...@@ -148,7 +136,7 @@ func TestEVM(t *testing.T) {
require.NoError(t, err, "must run steps without error") require.NoError(t, err, "must run steps without error")
// inspect test result // inspect test result
done, result := state.GetMemory(baseAddrEnd+4), state.GetMemory(baseAddrEnd+8) done, result := state.Memory.GetMemory(baseAddrEnd+4), state.Memory.GetMemory(baseAddrEnd+8)
require.Equal(t, done, uint32(1), "must be done") require.Equal(t, done, uint32(1), "must be done")
require.Equal(t, result, uint32(1), "must have success result") require.Equal(t, result, uint32(1), "must have success result")
}) })
......
package mipsevm
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math/bits"
"sort"
"github.com/ethereum/go-ethereum/crypto"
)
const (
	// Note: 2**12 = 4 KiB, the minimum page-size in Unicorn for mmap
	// as well as the Go runtime min phys page size.
	pageAddrSize = 12
	// pageKeySize is the number of address bits selecting the page (32-bit address space).
	pageKeySize = 32 - pageAddrSize
	// pageSize is the size of a page in bytes.
	pageSize = 1 << pageAddrSize
	// pageAddrMask masks the byte offset within a page.
	pageAddrMask = pageSize - 1
	// maxPageCount is the total number of pages in the 32-bit address space.
	maxPageCount = 1 << pageKeySize
	// pageKeyMask masks a value to the valid page-index range.
	pageKeyMask = maxPageCount - 1
)
// HashPair computes the keccak256 hash of left ++ right, i.e. the parent
// node of two 32-byte merkle tree nodes.
func HashPair(left, right [32]byte) [32]byte {
	var buf [64]byte
	copy(buf[:32], left[:])
	copy(buf[32:], right[:])
	return crypto.Keccak256Hash(buf[:])
}
// zeroHashes[i] is the merkle root of a fully-zero subtree of height i:
// zeroHashes[0] is the all-zero 32-byte leaf, and
// zeroHashes[i] = HashPair(zeroHashes[i-1], zeroHashes[i-1]).
var zeroHashes = func() [256][32]byte {
	// empty parts of the tree are all zero. Precompute the hash of each full-zero range sub-tree level.
	var out [256][32]byte
	// out[0] is intentionally left as 32 zero bytes: the zero leaf itself.
	for i := 1; i < 256; i++ {
		out[i] = HashPair(out[i-1], out[i-1])
	}
	return out
}()
// Memory is a sparse 32-bit address space backed by 4 KiB pages,
// with memoized merkleization of both the page contents (CachedPage)
// and the tree of page indices above them (Nodes).
type Memory struct {
	// generalized index -> merkle root or nil if invalidated
	Nodes map[uint64]*[32]byte
	// pageIndex -> cached page
	Pages map[uint32]*CachedPage
	// Note: since we don't de-alloc pages, we don't do ref-counting.
	// Once a page exists, it doesn't leave memory
}
// NewMemory returns an empty Memory with initialized node and page maps.
func NewMemory() *Memory {
	m := new(Memory)
	m.Nodes = make(map[uint64]*[32]byte)
	m.Pages = make(map[uint32]*CachedPage)
	return m
}
// Invalidate marks the cached merkle nodes covering [addr, addr+count) as
// stale (nil), so they are recomputed on the next merkleization.
// NOTE(review): count appears to be a byte count; for count < 32 the first
// (deepest) level invalidates nothing. The deepest levels fall inside page
// subtrees that MerkleizeSubtree routes to the page cache rather than Nodes,
// so this looks benign — confirm against MerkleizeSubtree's routing.
func (m *Memory) Invalidate(addr uint32, count uint32) {
	// we invalidate nodes of 32 bytes at a time
	minGindex := ((uint64(1) << 32) | uint64(addr)) >> 5 // gindex of the 32-byte leaf containing addr
	count >>= 5 // number of 32-byte leaves covered
	for minGindex > 0 {
		// invalidate the nodes covering the range at this level
		for i := minGindex; i < minGindex+uint64(count); i++ {
			m.Nodes[i] = nil
		}
		minGindex >>= 1 // move one level up towards the root
		count >>= 1
		if count == 0 {
			count = 1 // always invalidate at least one node per level above
		}
	}
}
// MerkleizeSubtree computes the merkle root of the subtree at the given
// generalized index, using cached nodes and the per-page caches wherever
// possible. Gindex 1 is the root of the full 2^32-byte address space, with
// 32-byte leaves at bit-length 28.
func (m *Memory) MerkleizeSubtree(gindex uint64) [32]byte {
	l := uint64(bits.Len64(gindex)) // depth+1 of the node: root has l == 1, leaves l == 28
	if l > 28 {
		panic("gindex too deep")
	}
	if l > pageKeySize {
		// The node lies within a page subtree: route to the page's own cache.
		// The page root sits at l == pageKeySize+1, so the depth into the
		// page is l-1-pageKeySize (0 at the page root, 7 at the 32-byte leaves).
		depthIntoPage := l - 1 - pageKeySize
		pageIndex := (gindex >> depthIntoPage) & pageKeyMask
		if p, ok := m.Pages[uint32(pageIndex)]; ok {
			pageGindex := (1 << depthIntoPage) | (gindex & ((1 << depthIntoPage) - 1))
			return p.MerkleizeSubtree(pageGindex)
		} else {
			return zeroHashes[28-l] // page does not exist
		}
	}
	n, ok := m.Nodes[gindex]
	if !ok {
		// if the node doesn't exist, the whole sub-tree is zeroed
		return zeroHashes[28-l]
	}
	if n != nil {
		return *n // cached and still valid
	}
	// nil entry means invalidated: recompute from children and re-cache
	left := m.MerkleizeSubtree(gindex << 1)
	right := m.MerkleizeSubtree((gindex << 1) | 1)
	r := HashPair(left, right)
	m.Nodes[gindex] = &r
	return r
}
// MerkleProof returns the 28-item merkle proof for the 32-byte leaf
// containing addr: the leaf value first, then the sibling of each node on
// the path up towards the root.
func (m *Memory) MerkleProof(addr uint32) (out [28 * 32]byte) {
	branch := m.traverseBranch(1, addr, 0)
	// flatten the 28 nodes into a single byte array
	for i := 0; i < 28; i++ {
		copy(out[i*32:], branch[i][:])
	}
	return out
}
// traverseBranch descends from the given parent node (root: parent=1,
// depth=0) towards the 32-byte leaf containing addr, and returns the proof
// in leaf-first order: the leaf node, followed by the sibling of each node
// on the path back up to (but excluding) the root.
func (m *Memory) traverseBranch(parent uint64, addr uint32, depth uint8) (proof [][32]byte) {
	if depth == 32-5 {
		// parent is the 32-byte leaf (depth 27): start the proof with its value
		proof = make([][32]byte, 0, 32-5+1)
		proof = append(proof, m.MerkleizeSubtree(parent))
		return
	}
	if depth > 32-5 {
		panic("traversed too deep")
	}
	self := parent << 1
	sibling := self | 1
	// the address bit for this level (MSB-first: bit 31 at the root) selects
	// which child to descend into; the other child becomes the proof sibling
	if addr&(1<<(31-depth)) != 0 {
		self, sibling = sibling, self
	}
	proof = m.traverseBranch(self, addr, depth+1)
	siblingNode := m.MerkleizeSubtree(sibling)
	proof = append(proof, siblingNode)
	return
}
// MerkleRoot returns the merkle root of the entire 32-bit address space (gindex 1).
func (m *Memory) MerkleRoot() [32]byte {
	return m.MerkleizeSubtree(1)
}
// SetMemory writes the 32-bit value v at addr (big-endian).
// addr must be aligned to 4 bytes; unaligned access panics.
// All cached merkle nodes covering the modified word are invalidated.
func (m *Memory) SetMemory(addr uint32, v uint32) {
	// addr must be aligned to 4 bytes
	if addr&0x3 != 0 {
		panic(fmt.Errorf("unaligned memory access: %x", addr))
	}
	pageIndex := addr >> pageAddrSize
	pageAddr := addr & pageAddrMask
	p, ok := m.Pages[pageIndex]
	if !ok {
		// allocate the page if we have not already.
		// Go may mmap relatively large ranges, but we only allocate the pages just in time.
		// A fresh page starts with a fully-invalid cache, and AllocPage
		// already invalidates the node path up to the root.
		p = m.AllocPage(pageIndex)
	} else {
		// invalidate both the page-internal merkle cache and the tree branch
		// above the page, now that the value changed
		p.Invalidate(pageAddr)
		m.Invalidate(addr, 4)
	}
	binary.BigEndian.PutUint32(p.Data[pageAddr:pageAddr+4], v)
}
// GetMemory reads the big-endian 32-bit value at addr.
// addr must be aligned to 4 bytes; unaligned access panics.
// Reads from pages that were never allocated return 0.
func (m *Memory) GetMemory(addr uint32) uint32 {
	if addr&0x3 != 0 { // enforce 4-byte alignment
		panic(fmt.Errorf("unaligned memory access: %x", addr))
	}
	page, ok := m.Pages[addr>>pageAddrSize]
	if !ok {
		return 0 // unallocated memory reads as zero
	}
	offset := addr & pageAddrMask
	return binary.BigEndian.Uint32(page.Data[offset : offset+4])
}
// AllocPage creates and registers an empty page at the given page index,
// invalidating every cached merkle node on the path from the page to the root.
func (m *Memory) AllocPage(pageIndex uint32) *CachedPage {
	page := &CachedPage{Data: new(Page)}
	m.Pages[pageIndex] = page
	// mark the branch from the new page up to the root as needing recomputation
	for gindex := (1 << pageKeySize) | uint64(pageIndex); gindex > 0; gindex >>= 1 {
		m.Nodes[gindex] = nil
	}
	return page
}
// pageEntry is the JSON serialization form of a single memory page.
type pageEntry struct {
	Index uint32 `json:"index"`
	Data *Page `json:"data"`
}
// MarshalJSON encodes the memory as a list of {index, data} page entries,
// sorted by page index for deterministic output.
func (m *Memory) MarshalJSON() ([]byte, error) {
	entries := make([]pageEntry, 0, len(m.Pages))
	for idx, page := range m.Pages {
		entries = append(entries, pageEntry{Index: idx, Data: page.Data})
	}
	sort.Slice(entries, func(a, b int) bool {
		return entries[a].Index < entries[b].Index
	})
	return json.Marshal(entries)
}
// UnmarshalJSON decodes a list of {index, data} page entries, replacing the
// current contents: node and page caches are reset and rebuilt from scratch.
// Entries with missing page data or duplicate page indices are rejected.
func (m *Memory) UnmarshalJSON(data []byte) error {
	var pages []pageEntry
	if err := json.Unmarshal(data, &pages); err != nil {
		return err
	}
	m.Nodes = make(map[uint64]*[32]byte)
	m.Pages = make(map[uint32]*CachedPage)
	for i, p := range pages {
		// guard against entries without "data": storing a nil *Page would
		// make any later page access panic with a nil dereference
		if p.Data == nil {
			return fmt.Errorf("missing page data, entry %d, page index %d", i, p.Index)
		}
		if _, ok := m.Pages[p.Index]; ok {
			return fmt.Errorf("cannot load duplicate page, entry %d, page index %d", i, p.Index)
		}
		m.Pages[p.Index] = &CachedPage{Data: p.Data}
	}
	return nil
}
// SetMemoryRange streams the contents of r into memory starting at addr,
// allocating pages as needed, until the reader is exhausted (io.EOF).
// Every touched page has its internal merkle cache invalidated, and the
// cached tree nodes above it are invalidated as well.
func (m *Memory) SetMemoryRange(addr uint32, r io.Reader) error {
	for {
		pageIndex := addr >> pageAddrSize
		pageAddr := addr & pageAddrMask
		p, ok := m.Pages[pageIndex]
		if !ok {
			// AllocPage also invalidates the node path to the root
			p = m.AllocPage(pageIndex)
		} else {
			// existing page: the branch above it may hold cached hashes;
			// invalidate it, since the page contents are about to change
			for k := (1 << pageKeySize) | uint64(pageIndex); k > 0; k >>= 1 {
				m.Nodes[k] = nil
			}
		}
		p.InvalidateFull()
		n, err := r.Read(p.Data[pageAddr:])
		if err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
		addr += uint32(n)
	}
}
// memReader is an io.Reader over the [addr, addr+count) range of a Memory.
type memReader struct {
	m *Memory
	addr uint32 // next address to read from
	count uint32 // remaining bytes in the range
}
// Read copies up to len(dest) bytes of memory into dest, never crossing a
// page boundary within a single call. Unallocated pages read as zeroes.
// Once the full range has been consumed, io.EOF is returned.
func (r *memReader) Read(dest []byte) (n int, err error) {
	if r.count == 0 {
		return 0, io.EOF
	}
	// Serve bytes from the current page only; the range may wrap around the
	// address space and need not be aligned.
	pageIndex := r.addr >> pageAddrSize
	offset := r.addr & pageAddrMask
	limit := uint32(pageSize)
	if endAddr := r.addr + r.count; endAddr>>pageAddrSize == pageIndex {
		limit = endAddr & pageAddrMask // range ends within this page
	}
	if page, ok := r.m.Pages[pageIndex]; ok {
		n = copy(dest, page.Data[offset:limit])
	} else {
		n = copy(dest, make([]byte, limit-offset)) // default to zeroes
	}
	r.addr += uint32(n)
	r.count -= uint32(n)
	return n, nil
}
// ReadMemoryRange returns an io.Reader over [addr, addr+count) of this memory.
func (m *Memory) ReadMemoryRange(addr uint32, count uint32) io.Reader {
	return &memReader{m: m, addr: addr, count: count}
}
package mipsevm
import (
"encoding/hex"
"fmt"
"github.com/ethereum/go-ethereum/crypto"
)
// Page is one fixed-size (4 KiB) chunk of backing memory.
type Page [pageSize]byte
// MarshalText hex-encodes the full page contents.
func (p *Page) MarshalText() ([]byte, error) {
	out := make([]byte, hex.EncodedLen(len(p)))
	hex.Encode(out, p[:])
	return out, nil
}
// UnmarshalText decodes exactly pageSize*2 hex characters into the page.
func (p *Page) UnmarshalText(dat []byte) error {
	if got := len(dat); got != pageSize*2 {
		return fmt.Errorf("expected %d hex chars, but got %d", pageSize*2, got)
	}
	_, err := hex.Decode(p[:], dat)
	return err
}
// CachedPage pairs a Page with a memoized binary merkle tree of its
// contents: 32-byte leaves, pairwise keccak256 hashing, root at Cache[1].
type CachedPage struct {
	Data *Page
	// intermediate nodes only
	Cache [pageSize / 32][32]byte
	// true if the intermediate node is valid
	Ok [pageSize / 32]bool
}
// Invalidate marks the cached merkle nodes on the path from the leaf
// containing pageAddr up to the page root as stale, so MerkleRoot
// recomputes them.
func (p *CachedPage) Invalidate(pageAddr uint32) {
	if pageAddr >= pageSize {
		panic("invalid page addr")
	}
	k := (1 << pageAddrSize) | pageAddr // byte-level generalized index within the page
	// first cache layer caches nodes that has two 32 byte leaf nodes.
	k >>= 5 + 1
	for k > 0 {
		p.Ok[k] = false
		k >>= 1
	}
}
// InvalidateFull marks the entire cached merkle tree of the page as stale.
func (p *CachedPage) InvalidateFull() {
	var fresh [pageSize / 32]bool
	p.Ok = fresh // reset everything to false
}
// MerkleRoot computes (and memoizes) the merkle root of the page contents:
// 32-byte leaves, pairwise keccak256 hashing, root stored at Cache[1].
func (p *CachedPage) MerkleRoot() [32]byte {
	// hash the bottom layer: each node Cache[64..127] covers 64 bytes of data
	for i := uint64(0); i < pageSize; i += 64 {
		j := pageSize/32/2 + i/64
		if p.Ok[j] {
			continue
		}
		p.Cache[j] = crypto.Keccak256Hash(p.Data[i : i+64])
		p.Ok[j] = true
	}
	// hash the cache layers bottom-up: parent j = i>>1 of sibling pair (i, i+1)
	for i := pageSize/32 - 2; i > 0; i -= 2 {
		j := i >> 1
		if p.Ok[j] {
			continue
		}
		p.Cache[j] = HashPair(p.Cache[i], p.Cache[i+1])
		p.Ok[j] = true
	}
	return p.Cache[1]
}
// MerkleizeSubtree returns the merkle node of the page at the given
// page-local generalized index: the root is gindex 1, and the 32-byte
// leaves occupy gindexes pageSize/32 .. pageSize/32*2-1.
func (p *CachedPage) MerkleizeSubtree(gindex uint64) [32]byte {
	_ = p.MerkleRoot() // fill cache
	if gindex >= pageSize/32 {
		if gindex >= pageSize/32*2 {
			panic("gindex too deep")
		}
		// it's pointing to a bottom node: read the 32-byte leaf directly
		// from the page data (mask off the gindex marker bit to get the
		// leaf index 0..127)
		nodeIndex := gindex & (pageSize/32 - 1)
		return *(*[32]byte)(p.Data[nodeIndex*32 : nodeIndex*32+32])
	}
	return p.Cache[gindex]
}
...@@ -16,7 +16,7 @@ func LoadELF(f *elf.File) (*State, error) { ...@@ -16,7 +16,7 @@ func LoadELF(f *elf.File) (*State, error) {
LO: 0, LO: 0,
Heap: 0x20000000, Heap: 0x20000000,
Registers: [32]uint32{}, Registers: [32]uint32{},
Memory: make(map[uint32]*Page), Memory: NewMemory(),
ExitCode: 0, ExitCode: 0,
Exited: false, Exited: false,
Step: 0, Step: 0,
...@@ -43,7 +43,7 @@ func LoadELF(f *elf.File) (*State, error) { ...@@ -43,7 +43,7 @@ func LoadELF(f *elf.File) (*State, error) {
if prog.Vaddr+prog.Memsz >= uint64(1<<32) { if prog.Vaddr+prog.Memsz >= uint64(1<<32) {
return nil, fmt.Errorf("program %d out of 32-bit mem range: %x - %x (size: %x)", i, prog.Vaddr, prog.Vaddr+prog.Memsz, prog.Memsz) return nil, fmt.Errorf("program %d out of 32-bit mem range: %x - %x (size: %x)", i, prog.Vaddr, prog.Vaddr+prog.Memsz, prog.Memsz)
} }
if err := s.SetMemoryRange(uint32(prog.Vaddr), r); err != nil { if err := s.Memory.SetMemoryRange(uint32(prog.Vaddr), r); err != nil {
return nil, fmt.Errorf("failed to read program segment %d: %w", i, err) return nil, fmt.Errorf("failed to read program segment %d: %w", i, err)
} }
} }
...@@ -70,14 +70,14 @@ func patchVM(f *elf.File, st *State) error { ...@@ -70,14 +70,14 @@ func patchVM(f *elf.File, st *State) error {
// MIPS32 patch: ret (pseudo instruction) // MIPS32 patch: ret (pseudo instruction)
// 03e00008 = jr $ra = ret (pseudo instruction) // 03e00008 = jr $ra = ret (pseudo instruction)
// 00000000 = invalid, make sure it never enters the actual function // 00000000 = invalid, make sure it never enters the actual function
if err := st.SetMemoryRange(uint32(s.Value), bytes.NewReader([]byte{ if err := st.Memory.SetMemoryRange(uint32(s.Value), bytes.NewReader([]byte{
0x03, 0xe0, 0x00, 0x08, 0x03, 0xe0, 0x00, 0x08,
0, 0, 0, 0, 0, 0, 0, 0,
})); err != nil { })); err != nil {
return fmt.Errorf("failed to patch Go runtime.gcenable: %w", err) return fmt.Errorf("failed to patch Go runtime.gcenable: %w", err)
} }
case "runtime.MemProfileRate": case "runtime.MemProfileRate":
if err := st.SetMemoryRange(uint32(s.Value), bytes.NewReader(make([]byte, 4))); err != nil { // disable mem profiling, to avoid a lot of unnecessary floating point ops if err := st.Memory.SetMemoryRange(uint32(s.Value), bytes.NewReader(make([]byte, 4))); err != nil { // disable mem profiling, to avoid a lot of unnecessary floating point ops
return err return err
} }
} }
...@@ -86,7 +86,7 @@ func patchVM(f *elf.File, st *State) error { ...@@ -86,7 +86,7 @@ func patchVM(f *elf.File, st *State) error {
// setup stack pointer // setup stack pointer
sp := uint32(0x7f_ff_d0_00) sp := uint32(0x7f_ff_d0_00)
// allocate 1 page for the initial stack data, and 16KB = 4 pages for the stack to grow // allocate 1 page for the initial stack data, and 16KB = 4 pages for the stack to grow
if err := st.SetMemoryRange(sp-4*pageSize, bytes.NewReader(make([]byte, 5*pageSize))); err != nil { if err := st.Memory.SetMemoryRange(sp-4*pageSize, bytes.NewReader(make([]byte, 5*pageSize))); err != nil {
return fmt.Errorf("failed to allocate page for stack content") return fmt.Errorf("failed to allocate page for stack content")
} }
st.Registers[29] = sp st.Registers[29] = sp
...@@ -94,7 +94,7 @@ func patchVM(f *elf.File, st *State) error { ...@@ -94,7 +94,7 @@ func patchVM(f *elf.File, st *State) error {
storeMem := func(addr uint32, v uint32) { storeMem := func(addr uint32, v uint32) {
var dat [4]byte var dat [4]byte
binary.BigEndian.PutUint32(dat[:], v) binary.BigEndian.PutUint32(dat[:], v)
_ = st.SetMemoryRange(addr, bytes.NewReader(dat[:])) _ = st.Memory.SetMemoryRange(addr, bytes.NewReader(dat[:]))
} }
// init argc, argv, aux on stack // init argc, argv, aux on stack
...@@ -107,7 +107,7 @@ func patchVM(f *elf.File, st *State) error { ...@@ -107,7 +107,7 @@ func patchVM(f *elf.File, st *State) error {
storeMem(sp+4*7, sp+4*9) // auxv[3] = address of 16 bytes containing random value storeMem(sp+4*7, sp+4*9) // auxv[3] = address of 16 bytes containing random value
storeMem(sp+4*8, 0) // auxv[term] = 0 storeMem(sp+4*8, 0) // auxv[term] = 0
_ = st.SetMemoryRange(sp+4*9, bytes.NewReader([]byte("4;byfairdiceroll"))) // 16 bytes of "randomness" _ = st.Memory.SetMemoryRange(sp+4*9, bytes.NewReader([]byte("4;byfairdiceroll"))) // 16 bytes of "randomness"
return nil return nil
} }
...@@ -2,41 +2,12 @@ package mipsevm ...@@ -2,41 +2,12 @@ package mipsevm
import ( import (
"encoding/binary" "encoding/binary"
"encoding/hex"
"fmt"
"io"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
) )
const (
// Note: 2**12 = 4 KiB, the minimum page-size in Unicorn for mmap
// as well as the Go runtime min phys page size.
pageAddrSize = 12
pageKeySize = 32 - pageAddrSize
pageSize = 1 << pageAddrSize
pageAddrMask = pageSize - 1
maxPageCount = 1 << pageKeySize
)
type Page [pageSize]byte
func (p *Page) MarshalText() ([]byte, error) {
dst := make([]byte, hex.EncodedLen(len(p)))
hex.Encode(dst, p[:])
return dst, nil
}
func (p *Page) UnmarshalText(dat []byte) error {
if len(dat) != pageSize*2 {
return fmt.Errorf("expected %d hex chars, but got %d", pageSize*2, len(dat))
}
_, err := hex.Decode(p[:], dat)
return err
}
type State struct { type State struct {
Memory map[uint32]*Page `json:"memory"` Memory *Memory `json:"memory"`
PreimageKey common.Hash `json:"preimageKey"` PreimageKey common.Hash `json:"preimageKey"`
PreimageOffset uint32 `json:"preimageOffset"` PreimageOffset uint32 `json:"preimageOffset"`
...@@ -55,10 +26,9 @@ type State struct { ...@@ -55,10 +26,9 @@ type State struct {
Registers [32]uint32 `json:"registers"` Registers [32]uint32 `json:"registers"`
} }
func (s *State) EncodeWitness(so StateOracle) []byte { func (s *State) EncodeWitness() []byte {
out := make([]byte, 0) out := make([]byte, 0)
memRoot := s.MerkleizeMemory(so) memRoot := s.Memory.MerkleRoot()
memRoot = common.Hash{31: 42} // TODO need contract to actually write memory
out = append(out, memRoot[:]...) out = append(out, memRoot[:]...)
out = append(out, s.PreimageKey[:]...) out = append(out, s.PreimageKey[:]...)
out = binary.BigEndian.AppendUint32(out, s.PreimageOffset) out = binary.BigEndian.AppendUint32(out, s.PreimageOffset)
...@@ -80,153 +50,4 @@ func (s *State) EncodeWitness(so StateOracle) []byte { ...@@ -80,153 +50,4 @@ func (s *State) EncodeWitness(so StateOracle) []byte {
return out return out
} }
func (s *State) MerkleizeMemory(so StateOracle) [32]byte {
// empty parts of the tree are all zero. Precompute the hash of each full-zero range sub-tree level.
var zeroHashes [256][32]byte
for i := 1; i < 256; i++ {
zeroHashes[i] = so.Remember(zeroHashes[i-1], zeroHashes[i-1])
}
// for each page, remember the generalized indices leading up to that page in the memory tree,
// so we can deduplicate work.
pageBranches := make(map[uint64]struct{})
for pageKey := range s.Memory {
pageGindex := (1 << pageKeySize) | uint64(pageKey)
for i := 0; i < pageKeySize; i++ {
gindex := pageGindex >> i
pageBranches[gindex] = struct{}{}
}
}
// helper func to merkleize a complete power-of-2 subtree, with stack-wise operation
merkleize := func(stackDepth uint64, getItem func(index uint64) [32]byte) [32]byte {
stack := make([][32]byte, stackDepth+1)
for i := uint64(0); i < (1 << stackDepth); i++ {
v := getItem(i)
for j := uint64(0); j <= stackDepth; j++ {
if i&(1<<j) == 0 {
stack[j] = v
break
} else {
v = so.Remember(stack[j], v)
}
}
}
return stack[stackDepth]
}
merkleizePage := func(page *Page) [32]byte {
return merkleize(pageAddrSize-5, func(index uint64) [32]byte { // 32 byte leaf values (5 bits)
return *(*[32]byte)(page[index*32 : index*32+32])
})
}
// Function to merkleize a memory sub-tree. Once it reaches the depth of a specific page, it merkleizes as page.
var merkleizeMemory func(gindex uint64, depth uint64) [32]byte
merkleizeMemory = func(gindex uint64, depth uint64) [32]byte {
if depth == pageKeySize {
pageKey := uint32(gindex & ((1 << pageKeySize) - 1))
return merkleizePage(s.Memory[pageKey])
}
left := gindex << 1
right := left | 1
var leftRoot, rightRoot [32]byte
if _, ok := pageBranches[left]; ok {
leftRoot = merkleizeMemory(left, depth+1)
} else {
leftRoot = zeroHashes[pageKeySize-(depth+1)+(pageAddrSize-5)]
}
if _, ok := pageBranches[right]; ok {
rightRoot = merkleizeMemory(right, depth+1)
} else {
rightRoot = zeroHashes[pageKeySize-(depth+1)+(pageAddrSize-5)]
}
return so.Remember(leftRoot, rightRoot)
}
return merkleizeMemory(1, 0)
}
func (s *State) SetMemory(addr uint32, v uint32) {
// addr must be aligned to 4 bytes
if addr&0x3 != 0 {
panic(fmt.Errorf("unaligned memory access: %x", addr))
}
pageIndex := addr >> pageAddrSize
pageAddr := addr & pageAddrMask
p, ok := s.Memory[pageIndex]
if !ok {
// allocate the page if we have not already.
// Go may mmap relatively large ranges, but we only allocate the pages just in time.
p = &Page{}
s.Memory[pageIndex] = p
}
binary.BigEndian.PutUint32(p[pageAddr:pageAddr+4], v)
}
func (s *State) GetMemory(addr uint32) uint32 {
// addr must be aligned to 4 bytes
if addr&0x3 != 0 {
panic(fmt.Errorf("unaligned memory access: %x", addr))
}
p, ok := s.Memory[addr>>pageAddrSize]
if !ok {
return 0
}
pageAddr := addr & pageAddrMask
return binary.BigEndian.Uint32(p[pageAddr : pageAddr+4])
}
func (s *State) SetMemoryRange(addr uint32, r io.Reader) error {
for {
pageIndex := addr >> pageAddrSize
pageAddr := addr & pageAddrMask
p, ok := s.Memory[pageIndex]
if !ok {
p = &Page{}
s.Memory[pageIndex] = p
}
n, err := r.Read(p[pageAddr:])
if err != nil {
if err == io.EOF {
return nil
}
return err
}
addr += uint32(n)
}
}
type memReader struct {
state *State
addr uint32
count uint32
}
func (r *memReader) Read(dest []byte) (n int, err error) {
if r.count == 0 {
return 0, io.EOF
}
// Keep iterating over memory until we have all our data.
// It may wrap around the address range, and may not be aligned
endAddr := r.addr + r.count
pageIndex := r.addr >> pageAddrSize
start := r.addr & pageAddrMask
end := uint32(pageSize)
if pageIndex == (endAddr >> pageAddrSize) {
end = endAddr & pageAddrMask
}
p, ok := r.state.Memory[pageIndex]
if ok {
n = copy(dest, p[start:end])
} else {
n = copy(dest, make([]byte, end-start)) // default to zeroes
}
r.addr += uint32(n)
r.count -= uint32(n)
return n, nil
}
func (s *State) ReadMemoryRange(addr uint32, count uint32) io.Reader {
return &memReader{state: s, addr: addr, count: count}
}
// TODO convert access-list to calldata and state-sets for EVM // TODO convert access-list to calldata and state-sets for EVM
package mipsevm
import (
"fmt"
"github.com/ethereum/go-ethereum/crypto"
)
// StateOracle abstracts merkle-node storage: Get resolves a parent hash to
// its two 32-byte children, and Remember stores a child pair and returns
// the parent hash.
type StateOracle interface {
	Get(key [32]byte) (a, b [32]byte)
	Remember(a, b [32]byte) [32]byte
}
// StateCache is an in-memory StateOracle: a bidirectional mapping between
// parent hashes and their (left, right) child pairs.
type StateCache struct {
	data map[[32]byte][2][32]byte
	reverse map[[2][32]byte][32]byte
}
// NewStateCache returns an empty StateCache, ready for use.
func NewStateCache() *StateCache {
	c := new(StateCache)
	c.data = make(map[[32]byte][2][32]byte)
	c.reverse = make(map[[2][32]byte][32]byte)
	return c
}
// Get returns the child pair previously remembered under key.
// It panics if the key is unknown.
func (s *StateCache) Get(key [32]byte) (a, b [32]byte) {
	children, known := s.data[key]
	if !known {
		panic(fmt.Errorf("missing key %x", key))
	}
	a, b = children[0], children[1]
	return
}
// Remember stores the (left, right) pair under its keccak256 parent hash
// and returns that hash. A previously-seen pair reuses the cached key
// without re-hashing.
func (s *StateCache) Remember(left [32]byte, right [32]byte) [32]byte {
	pair := [2][32]byte{left, right}
	if cached, ok := s.reverse[pair]; ok {
		return cached
	}
	parent := crypto.Keccak256Hash(left[:], right[:])
	s.data[parent] = pair
	s.reverse[pair] = parent
	return parent
}
...@@ -37,8 +37,8 @@ func TestState(t *testing.T) { ...@@ -37,8 +37,8 @@ func TestState(t *testing.T) {
//state, err := LoadELF(elfProgram) //state, err := LoadELF(elfProgram)
//require.NoError(t, err, "must load ELF into state") //require.NoError(t, err, "must load ELF into state")
programMem, err := os.ReadFile(fn) programMem, err := os.ReadFile(fn)
state := &State{PC: 0, NextPC: 4, Memory: make(map[uint32]*Page)} state := &State{PC: 0, NextPC: 4, Memory: NewMemory()}
err = state.SetMemoryRange(0, bytes.NewReader(programMem)) err = state.Memory.SetMemoryRange(0, bytes.NewReader(programMem))
require.NoError(t, err, "load program into state") require.NoError(t, err, "load program into state")
// set the return address ($ra) to jump into when test completes // set the return address ($ra) to jump into when test completes
...@@ -74,7 +74,7 @@ func TestState(t *testing.T) { ...@@ -74,7 +74,7 @@ func TestState(t *testing.T) {
err = RunUnicorn(mu, state.PC, 1000) err = RunUnicorn(mu, state.PC, 1000)
require.NoError(t, err, "must run steps without error") require.NoError(t, err, "must run steps without error")
// inspect test result // inspect test result
done, result := state.GetMemory(baseAddrEnd+4), state.GetMemory(baseAddrEnd+8) done, result := state.Memory.GetMemory(baseAddrEnd+4), state.Memory.GetMemory(baseAddrEnd+8)
require.Equal(t, done, uint32(1), "must be done") require.Equal(t, done, uint32(1), "must be done")
require.Equal(t, result, uint32(1), "must have success result") require.Equal(t, result, uint32(1), "must have success result")
}) })
......
package mipsevm package mipsevm
import "fmt"
type MemEntry struct { type MemEntry struct {
EffAddr uint32 EffAddr uint32
PreValue uint32 PreValue uint32
} }
type AccessList struct { type AccessList struct {
memReads []MemEntry mem *Memory
memWrites []MemEntry
memAccessAddr uint32
proofData []byte
} }
func (al *AccessList) Reset() { func (al *AccessList) Reset() {
al.memReads = al.memReads[:0] al.memAccessAddr = ^uint32(0)
al.memWrites = al.memWrites[:0] al.proofData = nil
} }
func (al *AccessList) OnRead(effAddr uint32, preValue uint32) { func (al *AccessList) OnRead(effAddr uint32) {
// if it matches the last, it's a duplicate; this happens because of multiple callbacks for the same effective addr. if al.memAccessAddr == effAddr {
if len(al.memReads) > 0 && al.memReads[len(al.memReads)-1].EffAddr == effAddr {
return return
} }
al.memReads = append(al.memReads, MemEntry{EffAddr: effAddr, PreValue: preValue}) if al.memAccessAddr != ^uint32(0) {
panic(fmt.Errorf("bad read of %08x, already have %08x", effAddr, al.memAccessAddr))
}
al.memAccessAddr = effAddr
proof := al.mem.MerkleProof(effAddr)
al.proofData = append(al.proofData, proof[:]...)
} }
func (al *AccessList) OnWrite(effAddr uint32, preValue uint32) { func (al *AccessList) OnWrite(effAddr uint32) {
// if it matches the last, it's a duplicate; this happens because of multiple callbacks for the same effective addr. if al.memAccessAddr == effAddr {
if len(al.memWrites) > 0 && al.memWrites[len(al.memWrites)-1].EffAddr == effAddr {
return return
} }
al.memWrites = append(al.memWrites, MemEntry{EffAddr: effAddr, PreValue: preValue}) if al.memAccessAddr != ^uint32(0) {
panic(fmt.Errorf("bad write of %08x, already have %08x", effAddr, al.memAccessAddr))
}
proof := al.mem.MerkleProof(effAddr)
al.proofData = append(al.proofData, proof[:]...)
}
func (al *AccessList) PreInstruction(pc uint32) {
proof := al.mem.MerkleProof(pc)
al.proofData = append(al.proofData, proof[:]...)
} }
var _ Tracer = (*AccessList)(nil) var _ Tracer = (*AccessList)(nil)
...@@ -36,18 +53,22 @@ var _ Tracer = (*AccessList)(nil) ...@@ -36,18 +53,22 @@ var _ Tracer = (*AccessList)(nil)
type Tracer interface { type Tracer interface {
// OnRead remembers reads from the given effAddr. // OnRead remembers reads from the given effAddr.
// Warning: the addr is an effective-addr, i.e. always aligned. // Warning: the addr is an effective-addr, i.e. always aligned.
// But unicorn will fire it multiple times, for each byte that was changed within the effective addr boundaries. // But unicorn may fire it multiple times, for each byte that was changed within the effective addr boundaries.
OnRead(effAddr uint32, value uint32) OnRead(effAddr uint32)
// OnWrite remembers writes to the given effAddr. // OnWrite remembers writes to the given effAddr.
// Warning: the addr is an effective-addr, i.e. always aligned. // Warning: the addr is an effective-addr, i.e. always aligned.
// But unicorn will fire it multiple times, for each byte that was changed within the effective addr boundaries. // But unicorn may fire it multiple times, for each byte that was changed within the effective addr boundaries.
OnWrite(effAddr uint32, value uint32) OnWrite(effAddr uint32)
PreInstruction(pc uint32)
} }
type NoOpTracer struct{} type NoOpTracer struct{}
func (n NoOpTracer) OnRead(effAddr uint32, value uint32) {} func (n NoOpTracer) OnRead(effAddr uint32) {}
func (n NoOpTracer) OnWrite(effAddr uint32) {}
func (n NoOpTracer) OnWrite(effAddr uint32, value uint32) {} func (n NoOpTracer) PreInstruction(pc uint32) {}
var _ Tracer = NoOpTracer{} var _ Tracer = NoOpTracer{}
...@@ -15,12 +15,12 @@ func NewUnicorn() (uc.Unicorn, error) { ...@@ -15,12 +15,12 @@ func NewUnicorn() (uc.Unicorn, error) {
func LoadUnicorn(st *State, mu uc.Unicorn) error { func LoadUnicorn(st *State, mu uc.Unicorn) error {
// mmap and write each page of memory state into unicorn // mmap and write each page of memory state into unicorn
for pageIndex, page := range st.Memory { for pageIndex, page := range st.Memory.Pages {
addr := uint64(pageIndex) << pageAddrSize addr := uint64(pageIndex) << pageAddrSize
if err := mu.MemMap(addr, pageSize); err != nil { if err := mu.MemMap(addr, pageSize); err != nil {
return fmt.Errorf("failed to mmap page at addr 0x%x: %w", addr, err) return fmt.Errorf("failed to mmap page at addr 0x%x: %w", addr, err)
} }
if err := mu.MemWrite(addr, page[:]); err != nil { if err := mu.MemWrite(addr, page.Data[:]); err != nil {
return fmt.Errorf("failed to write page at addr 0x%x: %w", addr, err) return fmt.Errorf("failed to write page at addr 0x%x: %w", addr, err)
} }
} }
...@@ -55,9 +55,9 @@ func HookUnicorn(st *State, mu uc.Unicorn, stdOut, stdErr io.Writer, tr Tracer) ...@@ -55,9 +55,9 @@ func HookUnicorn(st *State, mu uc.Unicorn, stdOut, stdErr io.Writer, tr Tracer)
count, _ := mu.RegRead(uc.MIPS_REG_A2) count, _ := mu.RegRead(uc.MIPS_REG_A2)
switch fd { switch fd {
case 1: case 1:
_, _ = io.Copy(stdOut, st.ReadMemoryRange(uint32(addr), uint32(count))) _, _ = io.Copy(stdOut, st.Memory.ReadMemoryRange(uint32(addr), uint32(count)))
case 2: case 2:
_, _ = io.Copy(stdErr, st.ReadMemoryRange(uint32(addr), uint32(count))) _, _ = io.Copy(stdErr, st.Memory.ReadMemoryRange(uint32(addr), uint32(count)))
default: default:
// ignore other output data // ignore other output data
} }
...@@ -110,7 +110,7 @@ func HookUnicorn(st *State, mu uc.Unicorn, stdOut, stdErr io.Writer, tr Tracer) ...@@ -110,7 +110,7 @@ func HookUnicorn(st *State, mu uc.Unicorn, stdOut, stdErr io.Writer, tr Tracer)
_, err = mu.HookAdd(uc.HOOK_MEM_READ, func(mu uc.Unicorn, access int, addr64 uint64, size int, value int64) { _, err = mu.HookAdd(uc.HOOK_MEM_READ, func(mu uc.Unicorn, access int, addr64 uint64, size int, value int64) {
effAddr := uint32(addr64 & 0xFFFFFFFC) // pass effective addr to tracer effAddr := uint32(addr64 & 0xFFFFFFFC) // pass effective addr to tracer
tr.OnRead(effAddr, st.GetMemory(effAddr)) tr.OnRead(effAddr)
}, 0, ^uint64(0)) }, 0, ^uint64(0))
if err != nil { if err != nil {
return fmt.Errorf("failed to set up mem-write hook: %w", err) return fmt.Errorf("failed to set up mem-write hook: %w", err)
...@@ -124,22 +124,22 @@ func HookUnicorn(st *State, mu uc.Unicorn, stdOut, stdErr io.Writer, tr Tracer) ...@@ -124,22 +124,22 @@ func HookUnicorn(st *State, mu uc.Unicorn, stdOut, stdErr io.Writer, tr Tracer)
panic("invalid mem size") panic("invalid mem size")
} }
effAddr := uint32(addr64 & 0xFFFFFFFC) effAddr := uint32(addr64 & 0xFFFFFFFC)
tr.OnWrite(effAddr, st.GetMemory(effAddr)) tr.OnWrite(effAddr)
rt := value rt := value
rs := addr64 & 3 rs := addr64 & 3
if size == 1 { if size == 1 {
mem := st.GetMemory(effAddr) mem := st.Memory.GetMemory(effAddr)
val := uint32((rt & 0xFF) << (24 - (rs&3)*8)) val := uint32((rt & 0xFF) << (24 - (rs&3)*8))
mask := 0xFFFFFFFF ^ uint32(0xFF<<(24-(rs&3)*8)) mask := 0xFFFFFFFF ^ uint32(0xFF<<(24-(rs&3)*8))
st.SetMemory(effAddr, (mem&mask)|val) st.Memory.SetMemory(effAddr, (mem&mask)|val)
} else if size == 2 { } else if size == 2 {
mem := st.GetMemory(effAddr) mem := st.Memory.GetMemory(effAddr)
val := uint32((rt & 0xFFFF) << (16 - (rs&2)*8)) val := uint32((rt & 0xFFFF) << (16 - (rs&2)*8))
mask := 0xFFFFFFFF ^ uint32(0xFFFF<<(16-(rs&2)*8)) mask := 0xFFFFFFFF ^ uint32(0xFFFF<<(16-(rs&2)*8))
st.SetMemory(effAddr, (mem&mask)|val) st.Memory.SetMemory(effAddr, (mem&mask)|val)
} else if size == 4 { } else if size == 4 {
st.SetMemory(effAddr, uint32(rt)) st.Memory.SetMemory(effAddr, uint32(rt))
} else { } else {
log.Fatal("bad size write to ram") log.Fatal("bad size write to ram")
} }
...@@ -168,7 +168,7 @@ func HookUnicorn(st *State, mu uc.Unicorn, stdOut, stdErr io.Writer, tr Tracer) ...@@ -168,7 +168,7 @@ func HookUnicorn(st *State, mu uc.Unicorn, stdOut, stdErr io.Writer, tr Tracer)
if st.PC == prevPC+4 { if st.PC == prevPC+4 {
st.NextPC = prevPC + 8 st.NextPC = prevPC + 8
prevInsn := st.GetMemory(prevPC) prevInsn := st.Memory.GetMemory(prevPC)
opcode := prevInsn >> 26 opcode := prevInsn >> 26
switch opcode { switch opcode {
case 2, 3: // J/JAL case 2, 3: // J/JAL
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment