Commit d9dfc098 authored by protolambda's avatar protolambda

contracts: simplify MIPS.sol with in-memory State to avoid MPT proofs of the...

contracts: simplify MIPS.sol with in-memory State to avoid MPT proofs of the tiny state like registers, which is smaller than the state-proof data anyway. This has a stack-too-deep error due to variable-count limit, but that can be fixed
parent 61893c31
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
pragma solidity ^0.7.3; pragma solidity ^0.7.3;
import "./MIPSMemory.sol"; pragma experimental ABIEncoderV2;
// https://inst.eecs.berkeley.edu/~cs61c/resources/MIPS_Green_Sheet.pdf // https://inst.eecs.berkeley.edu/~cs61c/resources/MIPS_Green_Sheet.pdf
// https://uweb.engr.arizona.edu/~ece369/Resources/spim/MIPSReference.pdf // https://uweb.engr.arizona.edu/~ece369/Resources/spim/MIPSReference.pdf
...@@ -16,62 +16,33 @@ import "./MIPSMemory.sol"; ...@@ -16,62 +16,33 @@ import "./MIPSMemory.sol";
// Then, you call Step. Step will revert if state is missing. If all state is present, it will return the next hash // Then, you call Step. Step will revert if state is missing. If all state is present, it will return the next hash
contract MIPS { contract MIPS {
MIPSMemory public immutable m;
uint32 constant public REG_OFFSET = 0xc0000000; struct State {
uint32 constant public REG_ZERO = REG_OFFSET; bytes32 memRoot;
uint32 constant public REG_LR = REG_OFFSET + 0x1f*4; bytes32 preimageKey;
uint32 constant public REG_PC = REG_OFFSET + 0x20*4; uint32 preimageOffset;
uint32 constant public REG_HI = REG_OFFSET + 0x21*4;
uint32 constant public REG_LO = REG_OFFSET + 0x22*4; uint32[32] registers;
uint32 constant public REG_HEAP = REG_OFFSET + 0x23*4; uint32 pc;
uint32 nextPC; // State is executing a branch/jump delay slot if nextPC != pc+4
uint32 constant public HEAP_START = 0x20000000; uint32 lr;
uint32 constant public BRK_START = 0x40000000; uint32 lo;
uint32 hi;
constructor() { uint32 heap;
m = new MIPSMemory(); uint8 exitCode;
bool exited;
uint64 step;
} }
bool constant public debug = true; // total State size: 32+32+4+32*4+5*4+1+1+8 = 226 bytes
event DidStep(bytes32 stateHash);
event DidWriteMemory(uint32 addr, uint32 value);
event TryReadMemory(uint32 addr);
event DidReadMemory(uint32 addr, uint32 value);
function WriteMemory(bytes32 stateHash, uint32 addr, uint32 value) internal returns (bytes32) { uint32 constant public HEAP_START = 0x20000000;
if (address(m) != address(0)) { uint32 constant public BRK_START = 0x40000000;
emit DidWriteMemory(addr, value);
bytes32 newStateHash = m.WriteMemory(stateHash, addr, value);
require(m.ReadMemory(newStateHash, addr) == value, "memory readback check failed");
return newStateHash;
}
assembly {
// TODO: this is actually doing an SLOAD first
sstore(addr, value)
}
return stateHash;
}
function ReadMemory(bytes32 stateHash, uint32 addr) internal returns (uint32 ret) {
if (address(m) != address(0)) {
emit TryReadMemory(addr);
ret = m.ReadMemory(stateHash, addr);
//emit DidReadMemory(addr, ret);
return ret;
}
assembly {
ret := sload(addr)
}
}
function Steps(bytes32 stateHash, uint count) public returns (bytes32) { // event DidStep(bytes32 stateHash);
for (uint i = 0; i < count; i++) { // event DidWriteMemory(uint32 addr, uint32 value);
stateHash = Step(stateHash); // event TryReadMemory(uint32 addr);
} // event DidReadMemory(uint32 addr, uint32 value);
return stateHash;
}
function SE(uint32 dat, uint32 idx) internal pure returns (uint32) { function SE(uint32 dat, uint32 idx) internal pure returns (uint32) {
bool isSigned = (dat >> (idx-1)) != 0; bool isSigned = (dat >> (idx-1)) != 0;
...@@ -80,80 +51,47 @@ contract MIPS { ...@@ -80,80 +51,47 @@ contract MIPS {
return uint32(dat&mask | (isSigned ? signed : 0)); return uint32(dat&mask | (isSigned ? signed : 0));
} }
function handleSyscall(bytes32 stateHash) internal returns (bytes32, bool) { // will revert if any required input state is missing
uint32 syscall_no = ReadMemory(stateHash, REG_OFFSET+2*4); function Step(bytes32 stateHash, bytes memory stateData, bytes calldata proof) public returns (bytes32) {
uint32 v0 = 0; require(stateHash == keccak256(stateData), "stateHash must match input");
bool exit = false; State memory state = abi.decode(stateData, (State)); // TODO not efficient, need to write a "decodePacked" for State
if(state.exited) { // don't change state once exited
if (syscall_no == 4090) {
// mmap
uint32 a0 = ReadMemory(stateHash, REG_OFFSET+4*4);
if (a0 == 0) {
uint32 sz = ReadMemory(stateHash, REG_OFFSET+5*4);
uint32 hr = ReadMemory(stateHash, REG_HEAP);
v0 = HEAP_START + hr;
stateHash = WriteMemory(stateHash, REG_HEAP, hr+sz);
} else {
v0 = a0;
}
} else if (syscall_no == 4045) {
// brk
v0 = BRK_START;
} else if (syscall_no == 4120) {
// clone (not supported)
v0 = 1;
} else if (syscall_no == 4246) {
// exit group
exit = true;
}
stateHash = WriteMemory(stateHash, REG_OFFSET+2*4, v0);
stateHash = WriteMemory(stateHash, REG_OFFSET+7*4, 0);
return (stateHash, exit);
}
function Step(bytes32 stateHash) public returns (bytes32 newStateHash) {
uint32 pc = ReadMemory(stateHash, REG_PC);
if (pc == 0x5ead0000) {
return stateHash; return stateHash;
} }
newStateHash = stepPC(stateHash, pc, pc+4);
if (address(m) != address(0)) {
emit DidStep(newStateHash);
}
}
// will revert if any required input state is missing uint32 pc = state.pc;
function stepPC(bytes32 stateHash, uint32 pc, uint32 nextPC) internal returns (bytes32) {
// instruction fetch // instruction fetch
uint32 insn = ReadMemory(stateHash, pc); uint32 insn; // TODO proof the memory read against memRoot
assembly {
insn := shr(sub(256, 32), calldataload(add(proof.offset, 0x20)))
}
uint32 opcode = insn >> 26; // 6-bits uint32 opcode = insn >> 26; // 6-bits
uint32 func = insn & 0x3f; // 6-bits uint32 func = insn & 0x3f; // 6-bits
// j-type j/jal // j-type j/jal
if (opcode == 2 || opcode == 3) { if (opcode == 2 || opcode == 3) {
stateHash = stepPC(stateHash, nextPC, state.pc = state.nextPC;
SE(insn&0x03FFFFFF, 26) << 2); state.nextPC = SE(insn&0x03FFFFFF, 26) << 2;
if (opcode == 3) { if (opcode == 3) {
stateHash = WriteMemory(stateHash, REG_LR, pc+8); state.lr = pc+8; // set the link-register to the instr after the delay slot instruction.
} }
return stateHash; return keccak256(abi.encode(state));
} }
// register fetch // register fetch
uint32 storeAddr = REG_ZERO; uint32 rs; // source register
uint32 rs; uint32 rt; // target register
uint32 rt; uint32 rtReg = ((insn >> 14) & 0x7C);
uint32 rtReg = REG_OFFSET + ((insn >> 14) & 0x7C);
// R-type or I-type (stores rt) // R-type or I-type (stores rt)
rs = ReadMemory(stateHash, REG_OFFSET + ((insn >> 19) & 0x7C)); rs = state.registers[(insn >> 19) & 0x7C];
storeAddr = REG_OFFSET + ((insn >> 14) & 0x7C); uint32 storeReg = (insn >> 14) & 0x7C;
if (opcode == 0 || opcode == 0x1c) { if (opcode == 0 || opcode == 0x1c) {
// R-type (stores rd) // R-type (stores rd)
rt = ReadMemory(stateHash, rtReg); rt = state.registers[rtReg];
storeAddr = REG_OFFSET + ((insn >> 9) & 0x7C); storeReg = (insn >> 9) & 0x7C;
} else if (opcode < 0x20) { } else if (opcode < 0x20) {
// rt is SignExtImm // rt is SignExtImm
// don't sign extend for andi, ori, xori // don't sign extend for andi, ori, xori
...@@ -166,17 +104,17 @@ contract MIPS { ...@@ -166,17 +104,17 @@ contract MIPS {
} }
} else if (opcode >= 0x28 || opcode == 0x22 || opcode == 0x26) { } else if (opcode >= 0x28 || opcode == 0x22 || opcode == 0x26) {
// store rt value with store // store rt value with store
rt = ReadMemory(stateHash, rtReg); rt = state.registers[rtReg];
// store actual rt with lwl and lwr // store actual rt with lwl and lwr
storeAddr = rtReg; storeReg = rtReg;
} }
if ((opcode >= 4 && opcode < 8) || opcode == 1) { if ((opcode >= 4 && opcode < 8) || opcode == 1) {
bool shouldBranch = false; bool shouldBranch = false;
if (opcode == 4 || opcode == 5) { // beq/bne if (opcode == 4 || opcode == 5) { // beq/bne
rt = ReadMemory(stateHash, rtReg); rt = state.registers[rtReg];
shouldBranch = (rs == rt && opcode == 4) || (rs != rt && opcode == 5); shouldBranch = (rs == rt && opcode == 4) || (rs != rt && opcode == 5);
} else if (opcode == 6) { shouldBranch = int32(rs) <= 0; // blez } else if (opcode == 6) { shouldBranch = int32(rs) <= 0; // blez
} else if (opcode == 7) { shouldBranch = int32(rs) > 0; // bgtz } else if (opcode == 7) { shouldBranch = int32(rs) > 0; // bgtz
...@@ -187,23 +125,28 @@ contract MIPS { ...@@ -187,23 +125,28 @@ contract MIPS {
if (rtv == 1) shouldBranch = int32(rs) >= 0; // bgez if (rtv == 1) shouldBranch = int32(rs) >= 0; // bgez
} }
state.pc = state.nextPC; // execute the delay slot first
if (shouldBranch) { if (shouldBranch) {
return stepPC(stateHash, nextPC, state.nextPC = pc + 4 + (SE(insn&0xFFFF, 16)<<2); // then continue with the instruction the branch jumps to.
pc + 4 + (SE(insn&0xFFFF, 16)<<2)); } else {
state.nextPC = state.nextPC + 4; // branch not taken
} }
// branch not taken return keccak256(abi.encode(state));
return stepPC(stateHash, nextPC, nextPC+4);
} }
uint32 storeAddr = 0xFF_FF_FF_FF;
// memory fetch (all I-type) // memory fetch (all I-type)
// we do the load for stores also // we do the load for stores also
uint32 mem; uint32 mem;
if (opcode >= 0x20) { if (opcode >= 0x20) {
// M[R[rs]+SignExtImm] // M[R[rs]+SignExtImm]
uint32 SignExtImm = SE(insn&0xFFFF, 16); rs += SE(insn&0xFFFF, 16);
rs += SignExtImm;
uint32 addr = rs & 0xFFFFFFFC; uint32 addr = rs & 0xFFFFFFFC;
mem = ReadMemory(stateHash, addr); // TODO proof memory read at addr
assembly {
mem := and(shr(sub(256, 64), calldataload(add(proof.offset, 0x20))), 0xFFFFFFFF)
}
if (opcode >= 0x28 && opcode != 0x30) { if (opcode >= 0x28 && opcode != 0x30) {
// store // store
storeAddr = addr; storeAddr = addr;
...@@ -213,80 +156,108 @@ contract MIPS { ...@@ -213,80 +156,108 @@ contract MIPS {
// ALU // ALU
uint32 val = execute(insn, rs, rt, mem); uint32 val = execute(insn, rs, rt, mem);
// TODO: this block can be before the execute call, and then share the mem read/writing
if (opcode == 0 && func >= 8 && func < 0x1c) { if (opcode == 0 && func >= 8 && func < 0x1c) {
if (func == 8 || func == 9) { if (func == 8 || func == 9) {
// jr/jalr // jr/jalr
stateHash = stepPC(stateHash, nextPC, rs); state.pc = state.nextPC;
state.nextPC = rs;
if (func == 9) { if (func == 9) {
stateHash = WriteMemory(stateHash, REG_LR, pc+8); state.lr = pc+8; // set the link-register to the instr after the delay slot instruction.
} }
return stateHash; return keccak256(abi.encode(state));
} }
// handle movz and movn when they don't write back // handle movz and movn when they don't write back
if (func == 0xa && rt != 0) { // movz if (func == 0xa && rt != 0) { // movz
storeAddr = REG_ZERO; storeReg = 0;
} }
if (func == 0xb && rt == 0) { // movn if (func == 0xb && rt == 0) { // movn
storeAddr = REG_ZERO; storeReg = 0;
} }
// syscall (can read and write) // syscall (can read and write)
if (func == 0xC) { if (func == 0xC) {
//revert("unhandled syscall"); uint32 syscall_no = state.registers[2];
bool exit; uint32 v0 = 0;
(stateHash, exit) = handleSyscall(stateHash);
if (exit) { if (syscall_no == 4090) {
nextPC = 0x5ead0000; // mmap
uint32 a0 = state.registers[4];
if (a0 == 0) {
uint32 sz = state.registers[5];
uint32 hr = state.heap;
v0 = HEAP_START + hr;
state.heap = hr+sz;
} else {
v0 = a0;
}
} else if (syscall_no == 4045) {
// brk
v0 = BRK_START;
} else if (syscall_no == 4120) {
// clone (not supported)
v0 = 1;
} else if (syscall_no == 4246) {
// exit group
state.exited = true;
state.exitCode = uint8(state.registers[4]);
return keccak256(abi.encode(state));
} }
// TODO: pre-image oracle read/write
state.registers[2] = v0;
state.registers[7] = 0;
} }
// lo and hi registers // lo and hi registers
// can write back // can write back
if (func >= 0x10 && func < 0x1c) { if (func >= 0x10 && func < 0x1c) {
if (func == 0x10) val = ReadMemory(stateHash, REG_HI); // mfhi if (func == 0x10) val = state.hi; // mfhi
else if (func == 0x11) storeAddr = REG_HI; // mthi else if (func == 0x11) state.hi = rs; // mthi
else if (func == 0x12) val = ReadMemory(stateHash, REG_LO); // mflo else if (func == 0x12) val = state.lo; // mflo
else if (func == 0x13) storeAddr = REG_LO; // mtlo else if (func == 0x13) state.lo = rs; // mtlo
else if (func == 0x18) { // mult
uint32 hi;
if (func == 0x18) { // mult
uint64 acc = uint64(int64(int32(rs))*int64(int32(rt))); uint64 acc = uint64(int64(int32(rs))*int64(int32(rt)));
hi = uint32(acc>>32); state.hi = uint32(acc>>32);
val = uint32(acc); state.lo = uint32(acc);
} else if (func == 0x19) { // multu } else if (func == 0x19) { // multu
uint64 acc = uint64(uint64(rs)*uint64(rt)); uint64 acc = uint64(uint64(rs)*uint64(rt));
hi = uint32(acc>>32); state.hi = uint32(acc>>32);
val = uint32(acc); state.lo = uint32(acc);
} else if (func == 0x1a) { // div } else if (func == 0x1a) { // div
hi = uint32(int32(rs)%int32(rt)); state.hi = uint32(int32(rs)%int32(rt));
val = uint32(int32(rs)/int32(rt)); state.lo = uint32(int32(rs)/int32(rt));
} else if (func == 0x1b) { // divu } else if (func == 0x1b) { // divu
hi = rs%rt; state.hi = rs%rt;
val = rs/rt; state.lo = rs/rt;
}
// lo/hi writeback
if (func >= 0x18 && func < 0x1c) {
stateHash = WriteMemory(stateHash, REG_HI, hi);
storeAddr = REG_LO;
} }
} }
} }
// stupid sc, write a 1 to rt // stupid sc, write a 1 to rt
if (opcode == 0x38 && rtReg != REG_ZERO) { if (opcode == 0x38 && rtReg != 0) {
stateHash = WriteMemory(stateHash, rtReg, 1); state.registers[rtReg] = 1;
} }
// write back // write back
if (storeAddr != REG_ZERO) { if (storeReg != 0) {
stateHash = WriteMemory(stateHash, storeAddr, val); state.registers[storeReg] = val;
}
// write memory
if (storeAddr != 0xFF_FF_FF_FF) {
// TODO: write back memory change.
// Note that we already read the same memory leaf earlier.
// We can use that to shorten the proof significantly,
// by just walking back up to construct the root with the same witness data.
state.memRoot = bytes32(uint256(42));
} }
stateHash = WriteMemory(stateHash, REG_PC, nextPC); state.pc = state.nextPC;
state.nextPC = state.nextPC + 4;
return stateHash; return keccak256(abi.encode(state));
} }
function execute(uint32 insn, uint32 rs, uint32 rt, uint32 mem) internal pure returns (uint32) { function execute(uint32 insn, uint32 rs, uint32 rt, uint32 mem) internal pure returns (uint32) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment