Commit 873a4eab authored by mbaxter's avatar mbaxter Committed by GitHub

cannon: Port audit fixes (#11697)

* cannon: Require boolean exited field to be encoded as 0 or 1

* cannon: Port mmap solidity tests for MTCannon

* cannon: Port srav solidity tests

* cannon: Run semver-lock, snapshot tasks

* cannon: Cut extra validation on thread.exited

* cannon: Run semver lock

* cannon: Update IMIPS2 interface
parent 85c79a55
...@@ -148,8 +148,8 @@ ...@@ -148,8 +148,8 @@
"sourceCodeHash": "0xb6e219e8c2d81d75c48a1459907609e9096fe032a7447c88cd3e0d134752ac8e" "sourceCodeHash": "0xb6e219e8c2d81d75c48a1459907609e9096fe032a7447c88cd3e0d134752ac8e"
}, },
"src/cannon/MIPS2.sol": { "src/cannon/MIPS2.sol": {
"initCodeHash": "0x36b7c32cf9eba05e6db44910a25c800b801c075f8e053eca9515c6e0e4d8a902", "initCodeHash": "0xf5e2bca4ba0c504ffa68f1ce5fbf4349b1fa892034777d77803d9111aed279fa",
"sourceCodeHash": "0xa307c44a2d67bc84e75f4b7341345ed236da2e63c1f3f442416f14cd262126bf" "sourceCodeHash": "0xe8d06d4e2c3cf6e0682e4c152429cd61f0fd963acb1190df1bba1727d90ef6b7"
}, },
"src/cannon/PreimageOracle.sol": { "src/cannon/PreimageOracle.sol": {
"initCodeHash": "0xce7a1c3265e457a05d17b6d1a2ef93c4639caac3733c9cf88bfd192eae2c5788", "initCodeHash": "0xce7a1c3265e457a05d17b6d1a2ef93c4639caac3733c9cf88bfd192eae2c5788",
......
...@@ -51,5 +51,10 @@ ...@@ -51,5 +51,10 @@
], ],
"stateMutability": "view", "stateMutability": "view",
"type": "function" "type": "function"
},
{
"inputs": [],
"name": "InvalidExitedValue",
"type": "error"
} }
] ]
\ No newline at end of file
...@@ -51,8 +51,8 @@ contract MIPS2 is ISemver { ...@@ -51,8 +51,8 @@ contract MIPS2 is ISemver {
} }
/// @notice The semantic version of the MIPS2 contract. /// @notice The semantic version of the MIPS2 contract.
/// @custom:semver 1.0.0-beta.5 /// @custom:semver 1.0.0-beta.6
string public constant version = "1.0.0-beta.5"; string public constant version = "1.0.0-beta.6";
/// @notice The preimage oracle contract. /// @notice The preimage oracle contract.
IPreimageOracle internal immutable ORACLE; IPreimageOracle internal immutable ORACLE;
...@@ -91,7 +91,7 @@ contract MIPS2 is ISemver { ...@@ -91,7 +91,7 @@ contract MIPS2 is ISemver {
unchecked { unchecked {
State memory state; State memory state;
ThreadState memory thread; ThreadState memory thread;
uint32 exited;
assembly { assembly {
if iszero(eq(state, STATE_MEM_OFFSET)) { if iszero(eq(state, STATE_MEM_OFFSET)) {
// expected state mem offset check // expected state mem offset check
...@@ -131,6 +131,7 @@ contract MIPS2 is ISemver { ...@@ -131,6 +131,7 @@ contract MIPS2 is ISemver {
c, m := putField(c, m, 4) // heap c, m := putField(c, m, 4) // heap
c, m := putField(c, m, 1) // exitCode c, m := putField(c, m, 1) // exitCode
c, m := putField(c, m, 1) // exited c, m := putField(c, m, 1) // exited
exited := mload(sub(m, 32))
c, m := putField(c, m, 8) // step c, m := putField(c, m, 8) // step
c, m := putField(c, m, 8) // stepsSinceLastContextSwitch c, m := putField(c, m, 8) // stepsSinceLastContextSwitch
c, m := putField(c, m, 4) // wakeup c, m := putField(c, m, 4) // wakeup
...@@ -139,6 +140,7 @@ contract MIPS2 is ISemver { ...@@ -139,6 +140,7 @@ contract MIPS2 is ISemver {
c, m := putField(c, m, 32) // rightThreadStack c, m := putField(c, m, 32) // rightThreadStack
c, m := putField(c, m, 4) // nextThreadID c, m := putField(c, m, 4) // nextThreadID
} }
st.assertExitedIsValid(exited);
if (state.exited) { if (state.exited) {
// thread state is unchanged // thread state is unchanged
...@@ -459,6 +461,7 @@ contract MIPS2 is ISemver { ...@@ -459,6 +461,7 @@ contract MIPS2 is ISemver {
/// @notice Computes the hash of the MIPS state. /// @notice Computes the hash of the MIPS state.
/// @return out_ The hashed MIPS state. /// @return out_ The hashed MIPS state.
function outputState() internal returns (bytes32 out_) { function outputState() internal returns (bytes32 out_) {
uint32 exited;
assembly { assembly {
// copies 'size' bytes, right-aligned in word at 'from', to 'to', incl. trailing data // copies 'size' bytes, right-aligned in word at 'from', to 'to', incl. trailing data
function copyMem(from, to, size) -> fromOut, toOut { function copyMem(from, to, size) -> fromOut, toOut {
...@@ -481,7 +484,7 @@ contract MIPS2 is ISemver { ...@@ -481,7 +484,7 @@ contract MIPS2 is ISemver {
from, to := copyMem(from, to, 4) // heap from, to := copyMem(from, to, 4) // heap
let exitCode := mload(from) let exitCode := mload(from)
from, to := copyMem(from, to, 1) // exitCode from, to := copyMem(from, to, 1) // exitCode
let exited := mload(from) exited := mload(from)
from, to := copyMem(from, to, 1) // exited from, to := copyMem(from, to, 1) // exited
from, to := copyMem(from, to, 8) // step from, to := copyMem(from, to, 8) // step
from, to := copyMem(from, to, 8) // stepsSinceLastContextSwitch from, to := copyMem(from, to, 8) // stepsSinceLastContextSwitch
...@@ -516,6 +519,8 @@ contract MIPS2 is ISemver { ...@@ -516,6 +519,8 @@ contract MIPS2 is ISemver {
out_ := keccak256(start, sub(to, start)) out_ := keccak256(start, sub(to, start))
out_ := or(and(not(shl(248, 0xFF)), out_), shl(248, status)) out_ := or(and(not(shl(248, 0xFF)), out_), shl(248, status))
} }
st.assertExitedIsValid(exited);
} }
/// @notice Updates the current thread stack root via inner thread root in calldata /// @notice Updates the current thread stack root via inner thread root in calldata
......
...@@ -6,5 +6,7 @@ import { ISemver } from "src/universal/ISemver.sol"; ...@@ -6,5 +6,7 @@ import { ISemver } from "src/universal/ISemver.sol";
/// @title IMIPS2 /// @title IMIPS2
/// @notice Interface for the MIPS2 contract. /// @notice Interface for the MIPS2 contract.
interface IMIPS2 is ISemver { interface IMIPS2 is ISemver {
error InvalidExitedValue();
function step(bytes memory _stateData, bytes memory _proof, bytes32 _localContext) external returns (bytes32); function step(bytes memory _stateData, bytes memory _proof, bytes32 _localContext) external returns (bytes32);
} }
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
pragma solidity 0.8.15; pragma solidity 0.8.15;
import { InvalidExitedValue } from "src/cannon/libraries/CannonErrors.sol";
library MIPSState { library MIPSState {
struct CpuScalars { struct CpuScalars {
uint32 pc; uint32 pc;
...@@ -8,4 +10,10 @@ library MIPSState { ...@@ -8,4 +10,10 @@ library MIPSState {
uint32 lo; uint32 lo;
uint32 hi; uint32 hi;
} }
function assertExitedIsValid(uint32 exited) internal pure {
if (exited > 1) {
revert InvalidExitedValue();
}
}
} }
...@@ -65,7 +65,6 @@ contract MIPS_Test is CommonTest { ...@@ -65,7 +65,6 @@ contract MIPS_Test is CommonTest {
/// invalid (anything greater than 1). /// invalid (anything greater than 1).
/// @param _exited Value to set the exited field to. /// @param _exited Value to set the exited field to.
function testFuzz_step_invalidExitedValue_fails(uint8 _exited) external { function testFuzz_step_invalidExitedValue_fails(uint8 _exited) external {
// Assume
// Make sure the value of _exited is invalid. // Make sure the value of _exited is invalid.
_exited = uint8(bound(uint256(_exited), 2, type(uint8).max)); _exited = uint8(bound(uint256(_exited), 2, type(uint8).max));
...@@ -77,6 +76,7 @@ contract MIPS_Test is CommonTest { ...@@ -77,6 +76,7 @@ contract MIPS_Test is CommonTest {
// Compute the encoded state and manipulate it. // Compute the encoded state and manipulate it.
bytes memory enc = encodeState(state); bytes memory enc = encodeState(state);
assembly { assembly {
// Push offset by an additional 32 bytes (0x20) to account for length prefix
mstore8(add(add(enc, 0x20), 89), _exited) mstore8(add(add(enc, 0x20), 89), _exited)
} }
......
...@@ -5,7 +5,9 @@ import { CommonTest } from "test/setup/CommonTest.sol"; ...@@ -5,7 +5,9 @@ import { CommonTest } from "test/setup/CommonTest.sol";
import { MIPS2 } from "src/cannon/MIPS2.sol"; import { MIPS2 } from "src/cannon/MIPS2.sol";
import { PreimageOracle } from "src/cannon/PreimageOracle.sol"; import { PreimageOracle } from "src/cannon/PreimageOracle.sol";
import { MIPSSyscalls as sys } from "src/cannon/libraries/MIPSSyscalls.sol"; import { MIPSSyscalls as sys } from "src/cannon/libraries/MIPSSyscalls.sol";
import { MIPSInstructions as ins } from "src/cannon/libraries/MIPSInstructions.sol";
import "src/dispute/lib/Types.sol"; import "src/dispute/lib/Types.sol";
import { InvalidExitedValue } from "src/cannon/libraries/CannonErrors.sol";
contract ThreadStack { contract ThreadStack {
bytes32 internal constant EMPTY_THREAD_ROOT = hex"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5"; bytes32 internal constant EMPTY_THREAD_ROOT = hex"ad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5";
...@@ -191,6 +193,34 @@ contract MIPS2_Test is CommonTest { ...@@ -191,6 +193,34 @@ contract MIPS2_Test is CommonTest {
assertNotEq(post, bytes32(0)); assertNotEq(post, bytes32(0));
} }
/// @notice Tests that the mips step function fails when the value of the exited field is
/// invalid (anything greater than 1).
/// @param _exited Value to set the exited field to.
function testFuzz_step_invalidExitedValueInState_fails(uint8 _exited) external {
// Make sure the value of _exited is invalid.
_exited = uint8(bound(uint256(_exited), 2, type(uint8).max));
// Setup state
uint32 insn = encodespec(17, 18, 8, 0x20); // Arbitrary instruction: add t0, s1, s2
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
// Set up step data
bytes memory encodedThread = encodeThread(thread);
bytes memory threadWitness = abi.encodePacked(encodedThread, EMPTY_THREAD_ROOT);
bytes memory proofData = bytes.concat(threadWitness, memProof);
bytes memory stateData = encodeState(state);
assembly {
// Manipulate state data
// Push offset by an additional 32 bytes (0x20) to account for length prefix
mstore8(add(add(stateData, 0x20), 73), _exited)
}
// Call the step function and expect a revert.
vm.expectRevert(InvalidExitedValue.selector);
mips.step(stateData, proofData, 0);
}
function test_invalidThreadWitness_reverts() public { function test_invalidThreadWitness_reverts() public {
MIPS2.State memory state; MIPS2.State memory state;
MIPS2.ThreadState memory thread; MIPS2.ThreadState memory thread;
...@@ -1076,6 +1106,166 @@ contract MIPS2_Test is CommonTest { ...@@ -1076,6 +1106,166 @@ contract MIPS2_Test is CommonTest {
assertEq(postState, outputState(expect), "unexpected post state"); assertEq(postState, outputState(expect), "unexpected post state");
} }
function test_mmap_succeeds_simple() external {
uint32 insn = 0x0000000c; // syscall
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
state.heap = 4096;
thread.nextPC = 4;
thread.registers[2] = sys.SYS_MMAP; // syscall num
thread.registers[4] = 0x0; // a0
thread.registers[5] = 4095; // a1
updateThreadStacks(state, thread);
// Set up step data
bytes memory threadWitness = abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT);
bytes memory encodedState = encodeState(state);
MIPS2.State memory expect = copyState(state);
MIPS2.ThreadState memory expectThread = copyThread(thread);
expect.memRoot = state.memRoot;
expect.step = state.step + 1;
expect.stepsSinceLastContextSwitch = state.stepsSinceLastContextSwitch + 1;
expect.heap = state.heap + 4096;
expectThread.pc = thread.nextPC;
expectThread.nextPC = thread.nextPC + 4;
expectThread.registers[2] = state.heap; // return old heap
expectThread.registers[7] = 0; // No error
updateThreadStacks(expect, expectThread);
bytes32 postState = mips.step(encodedState, bytes.concat(threadWitness, memProof), 0);
assertEq(postState, outputState(expect), "unexpected post state");
}
function test_mmap_succeeds_justWithinMemLimit() external {
uint32 insn = 0x0000000c; // syscall
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
state.heap = sys.HEAP_END - 4096; // Set up to increase heap to its limit
thread.nextPC = 4;
thread.registers[2] = sys.SYS_MMAP; // syscall num
thread.registers[4] = 0x0; // a0
thread.registers[5] = 4095; // a1
updateThreadStacks(state, thread);
// Set up step data
bytes memory threadWitness = abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT);
bytes memory encodedState = encodeState(state);
MIPS2.State memory expect = copyState(state);
MIPS2.ThreadState memory expectThread = copyThread(thread);
expect.memRoot = state.memRoot;
expect.step += 1;
expect.stepsSinceLastContextSwitch += 1;
expect.heap = sys.HEAP_END;
expectThread.pc = thread.nextPC;
expectThread.nextPC = thread.nextPC + 4;
expectThread.registers[2] = state.heap; // Return the old heap value
expectThread.registers[7] = 0; // No error
updateThreadStacks(expect, expectThread);
bytes32 postState = mips.step(encodedState, bytes.concat(threadWitness, memProof), 0);
assertEq(postState, outputState(expect), "unexpected post state");
}
function test_mmap_fails() external {
uint32 insn = 0x0000000c; // syscall
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
state.heap = sys.HEAP_END - 4096; // Set up to increase heap beyond its limit
thread.nextPC = 4;
thread.registers[2] = sys.SYS_MMAP; // syscall num
thread.registers[4] = 0x0; // a0
thread.registers[5] = 4097; // a1
updateThreadStacks(state, thread);
// Set up step data
bytes memory threadWitness = abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT);
bytes memory encodedState = encodeState(state);
MIPS2.State memory expect = copyState(state);
MIPS2.ThreadState memory expectThread = copyThread(thread);
expect.memRoot = state.memRoot;
expect.step += 1;
expect.stepsSinceLastContextSwitch += 1;
expectThread.pc = thread.nextPC;
expectThread.nextPC = thread.nextPC + 4;
expectThread.registers[2] = sys.SYS_ERROR_SIGNAL; // signal an stdError
expectThread.registers[7] = sys.EINVAL; // Return error value
expectThread.registers[4] = thread.registers[4]; // a0
expectThread.registers[5] = thread.registers[5]; // a1
updateThreadStacks(expect, expectThread);
bytes32 postState = mips.step(encodedState, bytes.concat(threadWitness, memProof), 0);
assertEq(postState, outputState(expect), "unexpected post state");
}
function test_srav_succeeds() external {
uint32 insn = encodespec(0xa, 0x9, 0x8, 7); // srav t0, t1, t2
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0xdeafbeef; // t1
thread.registers[10] = 12; // t2
updateThreadStacks(state, thread);
// Set up step data
bytes memory threadWitness = abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT);
bytes memory encodedState = encodeState(state);
MIPS2.State memory expect = copyState(state);
MIPS2.ThreadState memory expectThread = copyThread(thread);
expect.memRoot = state.memRoot;
expect.step += 1;
expect.stepsSinceLastContextSwitch += 1;
expectThread.pc = thread.nextPC;
expectThread.nextPC = thread.nextPC + 4;
expectThread.registers[8] = 0xfffdeafb; // t0
updateThreadStacks(expect, expectThread);
bytes32 postState = mips.step(encodedState, bytes.concat(threadWitness, memProof), 0);
assertEq(postState, outputState(expect), "unexpected post state");
}
/// @notice Tests that the SRAV instruction succeeds when it includes extra bits in the shift
/// amount beyond the lower 5 bits that are actually used for the shift. Extra bits
/// need to be ignored but the instruction should still succeed.
/// @param _rs Value to set in the shift register $rs.
function testFuzz_srav_withExtraBits_succeeds(uint32 _rs) external {
// Assume
// Force _rs to have more than 5 bits set.
_rs = uint32(bound(uint256(_rs), 0x20, type(uint32).max));
uint32 insn = encodespec(0xa, 0x9, 0x8, 7); // srav t0, t1, t2
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
constructMIPSState(0, insn, 0x4, 0);
thread.registers[9] = 0xdeadbeef; // t1
thread.registers[10] = _rs; // t2
updateThreadStacks(state, thread);
// Set up step data
bytes memory threadWitness = abi.encodePacked(encodeThread(thread), EMPTY_THREAD_ROOT);
bytes memory encodedState = encodeState(state);
// Calculate shamt
uint32 shamt = thread.registers[10] & 0x1F;
MIPS2.State memory expect = copyState(state);
MIPS2.ThreadState memory expectThread = copyThread(thread);
expect.memRoot = state.memRoot;
expect.step += 1;
expect.stepsSinceLastContextSwitch += 1;
expectThread.pc = thread.nextPC;
expectThread.nextPC = thread.nextPC + 4;
expectThread.registers[8] = ins.signExtend(thread.registers[9] >> shamt, 32 - shamt); // t0
updateThreadStacks(expect, expectThread);
bytes32 postState = mips.step(encodedState, bytes.concat(threadWitness, memProof), 0);
assertEq(postState, outputState(expect), "unexpected post state");
}
function test_add_succeeds() public { function test_add_succeeds() public {
uint32 insn = encodespec(17, 18, 8, 0x20); // add t0, s1, s2 uint32 insn = encodespec(17, 18, 8, 0x20); // add t0, s1, s2
(MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) = (MIPS2.State memory state, MIPS2.ThreadState memory thread, bytes memory memProof) =
...@@ -2175,8 +2365,11 @@ contract MIPS2_Test is CommonTest { ...@@ -2175,8 +2365,11 @@ contract MIPS2_Test is CommonTest {
} }
} }
function outputState(MIPS2.State memory state) internal pure returns (bytes32 out_) { event ExpectedOutputState(bytes encoded, MIPS2.State state);
function outputState(MIPS2.State memory state) internal returns (bytes32 out_) {
bytes memory enc = encodeState(state); bytes memory enc = encodeState(state);
emit ExpectedOutputState(enc, state);
VMStatus status = vmStatus(state); VMStatus status = vmStatus(state);
out_ = keccak256(enc); out_ = keccak256(enc);
assembly { assembly {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment