Commit d9dfc098 authored by protolambda's avatar protolambda

contracts: simplify MIPS.sol with in-memory State to avoid MPT proofs of the...

contracts: simplify MIPS.sol with in-memory State to avoid MPT proofs of the tiny state like registers, which is smaller than the state-proof data anyway. This has a stack-too-deep error due to variable-count limit, but that can be fixed
parent 61893c31
// SPDX-License-Identifier: MIT
pragma solidity ^0.7.3;
import "./MIPSMemory.sol";
pragma experimental ABIEncoderV2;
// https://inst.eecs.berkeley.edu/~cs61c/resources/MIPS_Green_Sheet.pdf
// https://uweb.engr.arizona.edu/~ece369/Resources/spim/MIPSReference.pdf
......@@ -16,62 +16,33 @@ import "./MIPSMemory.sol";
// Then, you call Step. Step will revert if state is missing. If all state is present, it will return the next hash
contract MIPS {
MIPSMemory public immutable m;
uint32 constant public REG_OFFSET = 0xc0000000;
uint32 constant public REG_ZERO = REG_OFFSET;
uint32 constant public REG_LR = REG_OFFSET + 0x1f*4;
uint32 constant public REG_PC = REG_OFFSET + 0x20*4;
uint32 constant public REG_HI = REG_OFFSET + 0x21*4;
uint32 constant public REG_LO = REG_OFFSET + 0x22*4;
uint32 constant public REG_HEAP = REG_OFFSET + 0x23*4;
struct State {
bytes32 memRoot;
bytes32 preimageKey;
uint32 preimageOffset;
uint32 constant public HEAP_START = 0x20000000;
uint32 constant public BRK_START = 0x40000000;
constructor() {
m = new MIPSMemory();
uint32[32] registers;
uint32 pc;
uint32 nextPC; // State is executing a branch/jump delay slot if nextPC != pc+4
uint32 lr;
uint32 lo;
uint32 hi;
uint32 heap;
uint8 exitCode;
bool exited;
uint64 step;
}
bool constant public debug = true;
event DidStep(bytes32 stateHash);
event DidWriteMemory(uint32 addr, uint32 value);
event TryReadMemory(uint32 addr);
event DidReadMemory(uint32 addr, uint32 value);
function WriteMemory(bytes32 stateHash, uint32 addr, uint32 value) internal returns (bytes32) {
if (address(m) != address(0)) {
emit DidWriteMemory(addr, value);
bytes32 newStateHash = m.WriteMemory(stateHash, addr, value);
require(m.ReadMemory(newStateHash, addr) == value, "memory readback check failed");
return newStateHash;
}
assembly {
// TODO: this is actually doing an SLOAD first
sstore(addr, value)
}
return stateHash;
}
// total State size: 32+32+4+32*4+5*4+1+1+8 = 226 bytes
function ReadMemory(bytes32 stateHash, uint32 addr) internal returns (uint32 ret) {
if (address(m) != address(0)) {
emit TryReadMemory(addr);
ret = m.ReadMemory(stateHash, addr);
//emit DidReadMemory(addr, ret);
return ret;
}
assembly {
ret := sload(addr)
}
}
uint32 constant public HEAP_START = 0x20000000;
uint32 constant public BRK_START = 0x40000000;
function Steps(bytes32 stateHash, uint count) public returns (bytes32) {
for (uint i = 0; i < count; i++) {
stateHash = Step(stateHash);
}
return stateHash;
}
// event DidStep(bytes32 stateHash);
// event DidWriteMemory(uint32 addr, uint32 value);
// event TryReadMemory(uint32 addr);
// event DidReadMemory(uint32 addr, uint32 value);
function SE(uint32 dat, uint32 idx) internal pure returns (uint32) {
bool isSigned = (dat >> (idx-1)) != 0;
......@@ -80,80 +51,47 @@ contract MIPS {
return uint32(dat&mask | (isSigned ? signed : 0));
}
function handleSyscall(bytes32 stateHash) internal returns (bytes32, bool) {
uint32 syscall_no = ReadMemory(stateHash, REG_OFFSET+2*4);
uint32 v0 = 0;
bool exit = false;
if (syscall_no == 4090) {
// mmap
uint32 a0 = ReadMemory(stateHash, REG_OFFSET+4*4);
if (a0 == 0) {
uint32 sz = ReadMemory(stateHash, REG_OFFSET+5*4);
uint32 hr = ReadMemory(stateHash, REG_HEAP);
v0 = HEAP_START + hr;
stateHash = WriteMemory(stateHash, REG_HEAP, hr+sz);
} else {
v0 = a0;
}
} else if (syscall_no == 4045) {
// brk
v0 = BRK_START;
} else if (syscall_no == 4120) {
// clone (not supported)
v0 = 1;
} else if (syscall_no == 4246) {
// exit group
exit = true;
}
stateHash = WriteMemory(stateHash, REG_OFFSET+2*4, v0);
stateHash = WriteMemory(stateHash, REG_OFFSET+7*4, 0);
return (stateHash, exit);
}
function Step(bytes32 stateHash) public returns (bytes32 newStateHash) {
uint32 pc = ReadMemory(stateHash, REG_PC);
if (pc == 0x5ead0000) {
// will revert if any required input state is missing
function Step(bytes32 stateHash, bytes memory stateData, bytes calldata proof) public returns (bytes32) {
require(stateHash == keccak256(stateData), "stateHash must match input");
State memory state = abi.decode(stateData, (State)); // TODO not efficient, need to write a "decodePacked" for State
if(state.exited) { // don't change state once exited
return stateHash;
}
newStateHash = stepPC(stateHash, pc, pc+4);
if (address(m) != address(0)) {
emit DidStep(newStateHash);
}
}
// will revert if any required input state is missing
function stepPC(bytes32 stateHash, uint32 pc, uint32 nextPC) internal returns (bytes32) {
uint32 pc = state.pc;
// instruction fetch
uint32 insn = ReadMemory(stateHash, pc);
uint32 insn; // TODO proof the memory read against memRoot
assembly {
insn := shr(sub(256, 32), calldataload(add(proof.offset, 0x20)))
}
uint32 opcode = insn >> 26; // 6-bits
uint32 func = insn & 0x3f; // 6-bits
// j-type j/jal
if (opcode == 2 || opcode == 3) {
stateHash = stepPC(stateHash, nextPC,
SE(insn&0x03FFFFFF, 26) << 2);
state.pc = state.nextPC;
state.nextPC = SE(insn&0x03FFFFFF, 26) << 2;
if (opcode == 3) {
stateHash = WriteMemory(stateHash, REG_LR, pc+8);
state.lr = pc+8; // set the link-register to the instr after the delay slot instruction.
}
return stateHash;
return keccak256(abi.encode(state));
}
// register fetch
uint32 storeAddr = REG_ZERO;
uint32 rs;
uint32 rt;
uint32 rtReg = REG_OFFSET + ((insn >> 14) & 0x7C);
uint32 rs; // source register
uint32 rt; // target register
uint32 rtReg = ((insn >> 14) & 0x7C);
// R-type or I-type (stores rt)
rs = ReadMemory(stateHash, REG_OFFSET + ((insn >> 19) & 0x7C));
storeAddr = REG_OFFSET + ((insn >> 14) & 0x7C);
rs = state.registers[(insn >> 19) & 0x7C];
uint32 storeReg = (insn >> 14) & 0x7C;
if (opcode == 0 || opcode == 0x1c) {
// R-type (stores rd)
rt = ReadMemory(stateHash, rtReg);
storeAddr = REG_OFFSET + ((insn >> 9) & 0x7C);
rt = state.registers[rtReg];
storeReg = (insn >> 9) & 0x7C;
} else if (opcode < 0x20) {
// rt is SignExtImm
// don't sign extend for andi, ori, xori
......@@ -166,17 +104,17 @@ contract MIPS {
}
} else if (opcode >= 0x28 || opcode == 0x22 || opcode == 0x26) {
// store rt value with store
rt = ReadMemory(stateHash, rtReg);
rt = state.registers[rtReg];
// store actual rt with lwl and lwr
storeAddr = rtReg;
storeReg = rtReg;
}
if ((opcode >= 4 && opcode < 8) || opcode == 1) {
bool shouldBranch = false;
if (opcode == 4 || opcode == 5) { // beq/bne
rt = ReadMemory(stateHash, rtReg);
rt = state.registers[rtReg];
shouldBranch = (rs == rt && opcode == 4) || (rs != rt && opcode == 5);
} else if (opcode == 6) { shouldBranch = int32(rs) <= 0; // blez
} else if (opcode == 7) { shouldBranch = int32(rs) > 0; // bgtz
......@@ -187,23 +125,28 @@ contract MIPS {
if (rtv == 1) shouldBranch = int32(rs) >= 0; // bgez
}
state.pc = state.nextPC; // execute the delay slot first
if (shouldBranch) {
return stepPC(stateHash, nextPC,
pc + 4 + (SE(insn&0xFFFF, 16)<<2));
state.nextPC = pc + 4 + (SE(insn&0xFFFF, 16)<<2); // then continue with the instruction the branch jumps to.
} else {
state.nextPC = state.nextPC + 4; // branch not taken
}
// branch not taken
return stepPC(stateHash, nextPC, nextPC+4);
return keccak256(abi.encode(state));
}
uint32 storeAddr = 0xFF_FF_FF_FF;
// memory fetch (all I-type)
// we do the load for stores also
uint32 mem;
if (opcode >= 0x20) {
// M[R[rs]+SignExtImm]
uint32 SignExtImm = SE(insn&0xFFFF, 16);
rs += SignExtImm;
rs += SE(insn&0xFFFF, 16);
uint32 addr = rs & 0xFFFFFFFC;
mem = ReadMemory(stateHash, addr);
// TODO proof memory read at addr
assembly {
mem := and(shr(sub(256, 64), calldataload(add(proof.offset, 0x20))), 0xFFFFFFFF)
}
if (opcode >= 0x28 && opcode != 0x30) {
// store
storeAddr = addr;
......@@ -213,80 +156,108 @@ contract MIPS {
// ALU
uint32 val = execute(insn, rs, rt, mem);
// TODO: this block can be before the execute call, and then share the mem read/writing
if (opcode == 0 && func >= 8 && func < 0x1c) {
if (func == 8 || func == 9) {
// jr/jalr
stateHash = stepPC(stateHash, nextPC, rs);
state.pc = state.nextPC;
state.nextPC = rs;
if (func == 9) {
stateHash = WriteMemory(stateHash, REG_LR, pc+8);
state.lr = pc+8; // set the link-register to the instr after the delay slot instruction.
}
return stateHash;
return keccak256(abi.encode(state));
}
// handle movz and movn when they don't write back
if (func == 0xa && rt != 0) { // movz
storeAddr = REG_ZERO;
storeReg = 0;
}
if (func == 0xb && rt == 0) { // movn
storeAddr = REG_ZERO;
storeReg = 0;
}
// syscall (can read and write)
if (func == 0xC) {
//revert("unhandled syscall");
bool exit;
(stateHash, exit) = handleSyscall(stateHash);
if (exit) {
nextPC = 0x5ead0000;
uint32 syscall_no = state.registers[2];
uint32 v0 = 0;
if (syscall_no == 4090) {
// mmap
uint32 a0 = state.registers[4];
if (a0 == 0) {
uint32 sz = state.registers[5];
uint32 hr = state.heap;
v0 = HEAP_START + hr;
state.heap = hr+sz;
} else {
v0 = a0;
}
} else if (syscall_no == 4045) {
// brk
v0 = BRK_START;
} else if (syscall_no == 4120) {
// clone (not supported)
v0 = 1;
} else if (syscall_no == 4246) {
// exit group
state.exited = true;
state.exitCode = uint8(state.registers[4]);
return keccak256(abi.encode(state));
}
// TODO: pre-image oracle read/write
state.registers[2] = v0;
state.registers[7] = 0;
}
// lo and hi registers
// can write back
if (func >= 0x10 && func < 0x1c) {
if (func == 0x10) val = ReadMemory(stateHash, REG_HI); // mfhi
else if (func == 0x11) storeAddr = REG_HI; // mthi
else if (func == 0x12) val = ReadMemory(stateHash, REG_LO); // mflo
else if (func == 0x13) storeAddr = REG_LO; // mtlo
uint32 hi;
if (func == 0x18) { // mult
if (func == 0x10) val = state.hi; // mfhi
else if (func == 0x11) state.hi = rs; // mthi
else if (func == 0x12) val = state.lo; // mflo
else if (func == 0x13) state.lo = rs; // mtlo
else if (func == 0x18) { // mult
uint64 acc = uint64(int64(int32(rs))*int64(int32(rt)));
hi = uint32(acc>>32);
val = uint32(acc);
state.hi = uint32(acc>>32);
state.lo = uint32(acc);
} else if (func == 0x19) { // multu
uint64 acc = uint64(uint64(rs)*uint64(rt));
hi = uint32(acc>>32);
val = uint32(acc);
state.hi = uint32(acc>>32);
state.lo = uint32(acc);
} else if (func == 0x1a) { // div
hi = uint32(int32(rs)%int32(rt));
val = uint32(int32(rs)/int32(rt));
state.hi = uint32(int32(rs)%int32(rt));
state.lo = uint32(int32(rs)/int32(rt));
} else if (func == 0x1b) { // divu
hi = rs%rt;
val = rs/rt;
}
// lo/hi writeback
if (func >= 0x18 && func < 0x1c) {
stateHash = WriteMemory(stateHash, REG_HI, hi);
storeAddr = REG_LO;
state.hi = rs%rt;
state.lo = rs/rt;
}
}
}
// stupid sc, write a 1 to rt
if (opcode == 0x38 && rtReg != REG_ZERO) {
stateHash = WriteMemory(stateHash, rtReg, 1);
if (opcode == 0x38 && rtReg != 0) {
state.registers[rtReg] = 1;
}
// write back
if (storeAddr != REG_ZERO) {
stateHash = WriteMemory(stateHash, storeAddr, val);
if (storeReg != 0) {
state.registers[storeReg] = val;
}
// write memory
if (storeAddr != 0xFF_FF_FF_FF) {
// TODO: write back memory change.
// Note that we already read the same memory leaf earlier.
// We can use that to shorten the proof significantly,
// by just walking back up to construct the root with the same witness data.
state.memRoot = bytes32(uint256(42));
}
stateHash = WriteMemory(stateHash, REG_PC, nextPC);
state.pc = state.nextPC;
state.nextPC = state.nextPC + 4;
return stateHash;
return keccak256(abi.encode(state));
}
function execute(uint32 insn, uint32 rs, uint32 rt, uint32 mem) internal pure returns (uint32) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment