Commit ca583f7f authored by Adrian Sutton's avatar Adrian Sutton Committed by GitHub

op-program: Add chain ID to hints when using interop (#13735)

* op-program: Add L2 chain ID to hints when using interop

* op-program: Add L2 chain ID to OutputRoot hints when using interop

* Fix label in error message.
Co-authored-by: default avatarInphi <mlaw2501@gmail.com>

---------
Co-authored-by: default avatarInphi <mlaw2501@gmail.com>
parent 66a67e02
......@@ -319,7 +319,6 @@ func TestInteropFaultProofs(gt *testing.T) {
agreedClaim: step1Expected,
disputedClaim: step2Expected,
expectValid: true,
skip: true,
},
{
name: "PaddingStep",
......
......@@ -39,42 +39,42 @@ func NewCachingOracle(oracle Oracle) *CachingOracle {
}
}
func (o *CachingOracle) NodeByHash(nodeHash common.Hash) []byte {
func (o *CachingOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
node, ok := o.nodes.Get(nodeHash)
if ok {
return node
}
node = o.oracle.NodeByHash(nodeHash)
node = o.oracle.NodeByHash(nodeHash, chainID)
o.nodes.Add(nodeHash, node)
return node
}
func (o *CachingOracle) CodeByHash(codeHash common.Hash) []byte {
func (o *CachingOracle) CodeByHash(codeHash common.Hash, chainID uint64) []byte {
code, ok := o.codes.Get(codeHash)
if ok {
return code
}
code = o.oracle.CodeByHash(codeHash)
code = o.oracle.CodeByHash(codeHash, chainID)
o.codes.Add(codeHash, code)
return code
}
func (o *CachingOracle) BlockByHash(blockHash common.Hash) *types.Block {
func (o *CachingOracle) BlockByHash(blockHash common.Hash, chainID uint64) *types.Block {
block, ok := o.blocks.Get(blockHash)
if ok {
return block
}
block = o.oracle.BlockByHash(blockHash)
block = o.oracle.BlockByHash(blockHash, chainID)
o.blocks.Add(blockHash, block)
return block
}
func (o *CachingOracle) OutputByRoot(root common.Hash) eth.Output {
func (o *CachingOracle) OutputByRoot(root common.Hash, chainID uint64) eth.Output {
output, ok := o.outputs.Get(root)
if ok {
return output
}
output = o.oracle.OutputByRoot(root)
output = o.oracle.OutputByRoot(root, chainID)
o.outputs.Add(root, output)
return output
}
......
......@@ -15,6 +15,7 @@ import (
var _ Oracle = (*CachingOracle)(nil)
func TestBlockByHash(t *testing.T) {
chainID := uint64(48294)
stub, _ := test.NewStubOracle(t)
oracle := NewCachingOracle(stub)
......@@ -23,12 +24,12 @@ func TestBlockByHash(t *testing.T) {
// Initial call retrieves from the stub
stub.Blocks[block.Hash()] = block
actual := oracle.BlockByHash(block.Hash())
actual := oracle.BlockByHash(block.Hash(), chainID)
require.Equal(t, block, actual)
// Later calls should retrieve from cache
// Later calls should retrieve from cache (even if chain ID is different)
delete(stub.Blocks, block.Hash())
actual = oracle.BlockByHash(block.Hash())
actual = oracle.BlockByHash(block.Hash(), 9982)
require.Equal(t, block, actual)
}
......@@ -41,12 +42,12 @@ func TestNodeByHash(t *testing.T) {
// Initial call retrieves from the stub
stateStub.Data[hash] = node
actual := oracle.NodeByHash(hash)
actual := oracle.NodeByHash(hash, 1234)
require.Equal(t, node, actual)
// Later calls should retrieve from cache
// Later calls should retrieve from cache (even if chain ID is different)
delete(stateStub.Data, hash)
actual = oracle.NodeByHash(hash)
actual = oracle.NodeByHash(hash, 997845)
require.Equal(t, node, actual)
}
......@@ -59,12 +60,12 @@ func TestCodeByHash(t *testing.T) {
// Initial call retrieves from the stub
stateStub.Code[hash] = node
actual := oracle.CodeByHash(hash)
actual := oracle.CodeByHash(hash, 342)
require.Equal(t, node, actual)
// Later calls should retrieve from cache
// Later calls should retrieve from cache (even if the chain ID is different)
delete(stateStub.Code, hash)
actual = oracle.CodeByHash(hash)
actual = oracle.CodeByHash(hash, 986776)
require.Equal(t, node, actual)
}
......@@ -78,11 +79,11 @@ func TestOutputByRoot(t *testing.T) {
// Initial call retrieves from the stub
root := common.Hash(eth.OutputRoot(output))
stub.Outputs[root] = output
actual := oracle.OutputByRoot(root)
actual := oracle.OutputByRoot(root, 59284)
require.Equal(t, output, actual)
// Later calls should retrieve from cache
// Later calls should retrieve from cache (even if the chain ID is different)
delete(stub.Outputs, root)
actual = oracle.OutputByRoot(root)
actual = oracle.OutputByRoot(root, 9193)
require.Equal(t, output, actual)
}
......@@ -18,12 +18,14 @@ var ErrInvalidKeyLength = errors.New("pre-images must be identified by 32-byte h
type OracleKeyValueStore struct {
db ethdb.KeyValueStore
oracle StateOracle
chainID uint64
}
func NewOracleBackedDB(oracle StateOracle) *OracleKeyValueStore {
func NewOracleBackedDB(oracle StateOracle, chainID uint64) *OracleKeyValueStore {
return &OracleKeyValueStore{
db: memorydb.New(),
oracle: oracle,
chainID: chainID,
}
}
......@@ -38,12 +40,12 @@ func (o *OracleKeyValueStore) Get(key []byte) ([]byte, error) {
if len(key) == codePrefixedKeyLength && bytes.HasPrefix(key, rawdb.CodePrefix) {
key = key[len(rawdb.CodePrefix):]
return o.oracle.CodeByHash(*(*[common.HashLength]byte)(key)), nil
return o.oracle.CodeByHash(*(*[common.HashLength]byte)(key), o.chainID), nil
}
if len(key) != common.HashLength {
return nil, ErrInvalidKeyLength
}
return o.oracle.NodeByHash(*(*[common.HashLength]byte)(key)), nil
return o.oracle.NodeByHash(*(*[common.HashLength]byte)(key), o.chainID), nil
}
func (o *OracleKeyValueStore) NewBatch() ethdb.Batch {
......
......@@ -34,7 +34,7 @@ var _ ethdb.KeyValueStore = (*OracleKeyValueStore)(nil)
func TestGet(t *testing.T) {
t.Run("IncorrectLengthKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle)
db := NewOracleBackedDB(oracle, 1234)
val, err := db.Get([]byte{1, 2, 3})
require.ErrorIs(t, err, ErrInvalidKeyLength)
require.Nil(t, val)
......@@ -42,7 +42,7 @@ func TestGet(t *testing.T) {
t.Run("KeyWithCodePrefix", func(t *testing.T) {
oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle)
db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0x12345678")
prefixedKey := append(rawdb.CodePrefix, key.Bytes()...)
......@@ -56,7 +56,7 @@ func TestGet(t *testing.T) {
t.Run("NormalKeyThatHappensToStartWithCodePrefix", func(t *testing.T) {
oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle)
db := NewOracleBackedDB(oracle, 1234)
key := make([]byte, common.HashLength)
copy(rawdb.CodePrefix, key)
fmt.Println(key[0])
......@@ -73,7 +73,7 @@ func TestGet(t *testing.T) {
expected := []byte{2, 6, 3, 8}
oracle := test.NewStubStateOracle(t)
oracle.Data[key] = expected
db := NewOracleBackedDB(oracle)
db := NewOracleBackedDB(oracle, 1234)
val, err := db.Get(key.Bytes())
require.NoError(t, err)
require.Equal(t, expected, val)
......@@ -83,7 +83,7 @@ func TestGet(t *testing.T) {
func TestPut(t *testing.T) {
t.Run("NewKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle)
db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0xAA4488")
value := []byte{2, 6, 3, 8}
err := db.Put(key.Bytes(), value)
......@@ -95,7 +95,7 @@ func TestPut(t *testing.T) {
})
t.Run("ReplaceKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle)
db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0xAA4488")
value1 := []byte{2, 6, 3, 8}
value2 := []byte{1, 2, 3}
......@@ -117,13 +117,13 @@ func TestSupportsStateDBOperations(t *testing.T) {
genesisBlock := l2Genesis.MustCommit(realDb, trieDB)
loader := test.NewKvStateOracle(t, realDb)
assertStateDataAvailable(t, NewOracleBackedDB(loader), l2Genesis, genesisBlock)
assertStateDataAvailable(t, NewOracleBackedDB(loader, 1234), l2Genesis, genesisBlock)
}
func TestUpdateState(t *testing.T) {
l2Genesis := createGenesis()
oracle := test.NewStubStateOracle(t)
db := rawdb.NewDatabase(NewOracleBackedDB(oracle))
db := rawdb.NewDatabase(NewOracleBackedDB(oracle, 1234))
trieDB := triedb.NewDatabase(db, &triedb.Config{HashDB: hashdb.Defaults})
genesisBlock := l2Genesis.MustCommit(db, trieDB)
......
......@@ -45,12 +45,13 @@ type OracleBackedL2Chain struct {
var _ engineapi.CachingEngineBackend = (*OracleBackedL2Chain)(nil)
func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle engineapi.PrecompileOracle, chainCfg *params.ChainConfig, l2OutputRoot common.Hash) (*OracleBackedL2Chain, error) {
output := oracle.OutputByRoot(l2OutputRoot)
chainID := chainCfg.ChainID.Uint64()
output := oracle.OutputByRoot(l2OutputRoot, chainID)
outputV0, ok := output.(*eth.OutputV0)
if !ok {
return nil, fmt.Errorf("unsupported L2 output version: %d", output.Version())
}
head := oracle.BlockByHash(outputV0.BlockHash)
head := oracle.BlockByHash(outputV0.BlockHash, chainID)
logger.Info("Loaded L2 head", "hash", head.Hash(), "number", head.Number())
return &OracleBackedL2Chain{
log: logger,
......@@ -69,7 +70,7 @@ func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle e
finalized: head.Header(),
oracleHead: head.Header(),
blocks: make(map[common.Hash]*types.Block),
db: NewOracleBackedDB(oracle),
db: NewOracleBackedDB(oracle, chainID),
vmCfg: vm.Config{
PrecompileOverrides: engineapi.CreatePrecompileOverrides(precompileOracle),
},
......@@ -122,7 +123,7 @@ func (o *OracleBackedL2Chain) GetBlockByHash(hash common.Hash) *types.Block {
return block
}
// Retrieve from the oracle
return o.oracle.BlockByHash(hash)
return o.oracle.BlockByHash(hash, o.chainCfg.ChainID.Uint64())
}
func (o *OracleBackedL2Chain) GetBlock(hash common.Hash, number uint64) *types.Block {
......
......@@ -378,7 +378,7 @@ func createBlock(t *testing.T, chain *OracleBackedL2Chain, opts ...blockCreateOp
require.NoError(t, err)
nonce := parentDB.GetNonce(fundedAddress)
config := chain.Config()
db := rawdb.NewDatabase(NewOracleBackedDB(chain.oracle))
db := rawdb.NewDatabase(NewOracleBackedDB(chain.oracle, config.ChainID.Uint64()))
blocks, _ := core.GenerateChain(config, parent, chain.Engine(), db, 1, func(i int, gen *core.BlockGen) {
rawTx := &types.DynamicFeeTx{
ChainID: config.ChainID,
......
......@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
preimage "github.com/ethereum-optimism/optimism/op-preimage"
)
......@@ -19,43 +20,95 @@ const (
HintAgreedPrestate = "agreed-pre-state"
)
type BlockHeaderHint common.Hash
type LegacyBlockHeaderHint common.Hash
var _ preimage.Hint = LegacyBlockHeaderHint{}
func (l LegacyBlockHeaderHint) Hint() string {
return HintL2BlockHeader + " " + (common.Hash)(l).String()
}
type HashAndChainID struct {
Hash common.Hash
ChainID uint64
}
func (h HashAndChainID) Marshal() []byte {
d := make([]byte, 32+8)
copy(d[:32], h.Hash[:])
binary.BigEndian.PutUint64(d[32:], h.ChainID)
return d
}
type BlockHeaderHint HashAndChainID
var _ preimage.Hint = BlockHeaderHint{}
func (l BlockHeaderHint) Hint() string {
return HintL2BlockHeader + " " + (common.Hash)(l).String()
return HintL2BlockHeader + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type TransactionsHint common.Hash
type LegacyTransactionsHint common.Hash
var _ preimage.Hint = LegacyTransactionsHint{}
func (l LegacyTransactionsHint) Hint() string {
return HintL2Transactions + " " + (common.Hash)(l).String()
}
type TransactionsHint HashAndChainID
var _ preimage.Hint = TransactionsHint{}
func (l TransactionsHint) Hint() string {
return HintL2Transactions + " " + (common.Hash)(l).String()
return HintL2Transactions + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type CodeHint common.Hash
type CodeHint HashAndChainID
var _ preimage.Hint = CodeHint{}
func (l CodeHint) Hint() string {
return HintL2Code + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyCodeHint common.Hash
var _ preimage.Hint = LegacyCodeHint{}
func (l LegacyCodeHint) Hint() string {
return HintL2Code + " " + (common.Hash)(l).String()
}
type StateNodeHint common.Hash
type StateNodeHint HashAndChainID
var _ preimage.Hint = StateNodeHint{}
func (l StateNodeHint) Hint() string {
return HintL2StateNode + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyStateNodeHint common.Hash
var _ preimage.Hint = LegacyStateNodeHint{}
func (l LegacyStateNodeHint) Hint() string {
return HintL2StateNode + " " + (common.Hash)(l).String()
}
type L2OutputHint common.Hash
type L2OutputHint HashAndChainID
var _ preimage.Hint = L2OutputHint{}
func (l L2OutputHint) Hint() string {
return HintL2Output + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyL2OutputHint common.Hash
var _ preimage.Hint = LegacyL2OutputHint{}
func (l LegacyL2OutputHint) Hint() string {
return HintL2Output + " " + (common.Hash)(l).String()
}
......
......@@ -19,11 +19,11 @@ type StateOracle interface {
// NodeByHash retrieves the merkle-patricia trie node pre-image for a given hash.
// Trie nodes may be from the world state trie or any account storage trie.
// Contract code is not stored as part of the trie and must be retrieved via CodeByHash
NodeByHash(nodeHash common.Hash) []byte
NodeByHash(nodeHash common.Hash, chainID uint64) []byte
// CodeByHash retrieves the contract code pre-image for a given hash.
// codeHash should be retrieved from the world state account for a contract.
CodeByHash(codeHash common.Hash) []byte
CodeByHash(codeHash common.Hash, chainID uint64) []byte
}
// Oracle defines the high-level API used to retrieve L2 data.
......@@ -32,9 +32,9 @@ type Oracle interface {
StateOracle
// BlockByHash retrieves the block with the given hash.
BlockByHash(blockHash common.Hash) *types.Block
BlockByHash(blockHash common.Hash, chainID uint64) *types.Block
OutputByRoot(root common.Hash) eth.Output
OutputByRoot(root common.Hash, chainID uint64) eth.Output
// BlockDataByHash retrieves the block, including all data used to construct it.
BlockDataByHash(agreedBlockHash, blockHash common.Hash, chainID uint64) *types.Block
......@@ -47,19 +47,25 @@ type Oracle interface {
type PreimageOracle struct {
oracle preimage.Oracle
hint preimage.Hinter
hintL2ChainIDs bool
}
var _ Oracle = (*PreimageOracle)(nil)
func NewPreimageOracle(raw preimage.Oracle, hint preimage.Hinter) *PreimageOracle {
func NewPreimageOracle(raw preimage.Oracle, hint preimage.Hinter, hintL2ChainIDs bool) *PreimageOracle {
return &PreimageOracle{
oracle: raw,
hint: hint,
hintL2ChainIDs: hintL2ChainIDs,
}
}
func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash) *types.Header {
p.hint.Hint(BlockHeaderHint(blockHash))
func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash, chainID uint64) *types.Header {
if p.hintL2ChainIDs {
p.hint.Hint(BlockHeaderHint{Hash: blockHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyBlockHeaderHint(blockHash))
}
headerRlp := p.oracle.Get(preimage.Keccak256Key(blockHash))
var header types.Header
if err := rlp.DecodeBytes(headerRlp, &header); err != nil {
......@@ -68,15 +74,19 @@ func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash) *types.Header
return &header
}
func (p *PreimageOracle) BlockByHash(blockHash common.Hash) *types.Block {
header := p.headerByBlockHash(blockHash)
txs := p.LoadTransactions(blockHash, header.TxHash)
func (p *PreimageOracle) BlockByHash(blockHash common.Hash, chainID uint64) *types.Block {
header := p.headerByBlockHash(blockHash, chainID)
txs := p.LoadTransactions(blockHash, header.TxHash, chainID)
return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs})
}
func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.Hash) []*types.Transaction {
p.hint.Hint(TransactionsHint(blockHash))
func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.Hash, chainID uint64) []*types.Transaction {
if p.hintL2ChainIDs {
p.hint.Hint(TransactionsHint{Hash: blockHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyTransactionsHint(blockHash))
}
opaqueTxs := mpt.ReadTrie(txHash, func(key common.Hash) []byte {
return p.oracle.Get(preimage.Keccak256Key(key))
......@@ -89,18 +99,30 @@ func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.H
return txs
}
func (p *PreimageOracle) NodeByHash(nodeHash common.Hash) []byte {
p.hint.Hint(StateNodeHint(nodeHash))
func (p *PreimageOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
if p.hintL2ChainIDs {
p.hint.Hint(StateNodeHint{Hash: nodeHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyStateNodeHint(nodeHash))
}
return p.oracle.Get(preimage.Keccak256Key(nodeHash))
}
func (p *PreimageOracle) CodeByHash(codeHash common.Hash) []byte {
p.hint.Hint(CodeHint(codeHash))
func (p *PreimageOracle) CodeByHash(codeHash common.Hash, chainID uint64) []byte {
if p.hintL2ChainIDs {
p.hint.Hint(CodeHint{Hash: codeHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyCodeHint(codeHash))
}
return p.oracle.Get(preimage.Keccak256Key(codeHash))
}
func (p *PreimageOracle) OutputByRoot(l2OutputRoot common.Hash) eth.Output {
p.hint.Hint(L2OutputHint(l2OutputRoot))
func (p *PreimageOracle) OutputByRoot(l2OutputRoot common.Hash, chainID uint64) eth.Output {
if p.hintL2ChainIDs {
p.hint.Hint(L2OutputHint{Hash: l2OutputRoot, ChainID: chainID})
} else {
p.hint.Hint(LegacyL2OutputHint(l2OutputRoot))
}
data := p.oracle.Get(preimage.Keccak256Key(l2OutputRoot))
output, err := eth.UnmarshalOutput(data)
if err != nil {
......@@ -116,8 +138,8 @@ func (p *PreimageOracle) BlockDataByHash(agreedBlockHash, blockHash common.Hash,
ChainID: chainID,
}
p.hint.Hint(hint)
header := p.headerByBlockHash(blockHash)
txs := p.LoadTransactions(blockHash, header.TxHash)
header := p.headerByBlockHash(blockHash, chainID)
txs := p.LoadTransactions(blockHash, header.TxHash, chainID)
return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs})
}
......
......@@ -19,29 +19,27 @@ import (
"github.com/ethereum-optimism/optimism/op-service/testutils"
)
func mockPreimageOracle(t *testing.T) (po *PreimageOracle, hintsMock *mock.Mock, preimages map[common.Hash][]byte) {
func mockPreimageOracle(t *testing.T, hintL2ChainIDs bool) (po *PreimageOracle, hintsMock *mock.Mock, preimages map[common.Hash][]byte) {
// Prepare the pre-images
preimages = make(map[common.Hash][]byte)
hintsMock = new(mock.Mock)
po = &PreimageOracle{
oracle: preimage.OracleFn(func(key preimage.Key) []byte {
rawOracle := preimage.OracleFn(func(key preimage.Key) []byte {
v, ok := preimages[key.PreimageKey()]
require.True(t, ok, "preimage must exist")
return v
}),
hint: preimage.HinterFn(func(v preimage.Hint) {
})
hinter := preimage.HinterFn(func(v preimage.Hint) {
hintsMock.MethodCalled("hint", v.Hint())
}),
}
return po, hintsMock, preimages
})
po = NewPreimageOracle(rawOracle, hinter, hintL2ChainIDs)
return
}
// testBlock tests that the given block can be passed through the preimage oracle.
func testBlock(t *testing.T, block *types.Block) {
po, hints, preimages := mockPreimageOracle(t)
func testBlock(t *testing.T, block *types.Block, hintL2ChainIDs bool) {
po, hints, preimages := mockPreimageOracle(t, hintL2ChainIDs)
hdrBytes, err := rlp.EncodeToBytes(block.Header())
require.NoError(t, err)
......@@ -54,12 +52,19 @@ func testBlock(t *testing.T, block *types.Block) {
preimages[preimage.Keccak256Key(crypto.Keccak256Hash(p)).PreimageKey()] = p
}
chainID := uint64(4924)
// Prepare a raw mock pre-image oracle that will serve the pre-image data and handle hints
// Check if blocks with txs work
hints.On("hint", BlockHeaderHint(block.Hash()).Hint()).Once().Return()
hints.On("hint", TransactionsHint(block.Hash()).Hint()).Once().Return()
gotBlock := po.BlockByHash(block.Hash())
if hintL2ChainIDs {
hints.On("hint", BlockHeaderHint{Hash: block.Hash(), ChainID: chainID}.Hint()).Once().Return()
hints.On("hint", TransactionsHint{Hash: block.Hash(), ChainID: chainID}.Hint()).Once().Return()
} else {
hints.On("hint", LegacyBlockHeaderHint(block.Hash()).Hint()).Once().Return()
hints.On("hint", LegacyTransactionsHint(block.Hash()).Hint()).Once().Return()
}
gotBlock := po.BlockByHash(block.Hash(), chainID)
hints.AssertExpectations(t)
require.Equal(t, gotBlock.Hash(), block.Hash())
......@@ -75,8 +80,12 @@ func TestPreimageOracleBlockByHash(t *testing.T) {
for i := 0; i < 10; i++ {
block, _ := testutils.RandomBlock(rng, 10)
t.Run(fmt.Sprintf("legacy_block_%d", i), func(t *testing.T) {
testBlock(t, block, false)
})
t.Run(fmt.Sprintf("block_%d", i), func(t *testing.T) {
testBlock(t, block)
testBlock(t, block, true)
})
}
}
......@@ -85,8 +94,24 @@ func TestPreimageOracleNodeByHash(t *testing.T) {
rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_node_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
node := make([]byte, 123)
rng.Read(node)
h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", LegacyStateNodeHint(h).Hint()).Once().Return()
gotNode := po.NodeByHash(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "node matches")
})
t.Run(fmt.Sprintf("node_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t)
po, hints, preimages := mockPreimageOracle(t, true)
node := make([]byte, 123)
rng.Read(node)
......@@ -94,8 +119,8 @@ func TestPreimageOracleNodeByHash(t *testing.T) {
h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", StateNodeHint(h).Hint()).Once().Return()
gotNode := po.NodeByHash(h)
hints.On("hint", StateNodeHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotNode := po.NodeByHash(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "node matches")
})
......@@ -106,8 +131,24 @@ func TestPreimageOracleCodeByHash(t *testing.T) {
rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_code_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
node := make([]byte, 123)
rng.Read(node)
h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", LegacyCodeHint(h).Hint()).Once().Return()
gotNode := po.CodeByHash(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "code matches")
})
t.Run(fmt.Sprintf("code_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t)
po, hints, preimages := mockPreimageOracle(t, true)
node := make([]byte, 123)
rng.Read(node)
......@@ -115,8 +156,8 @@ func TestPreimageOracleCodeByHash(t *testing.T) {
h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", CodeHint(h).Hint()).Once().Return()
gotNode := po.CodeByHash(h)
hints.On("hint", CodeHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotNode := po.CodeByHash(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "code matches")
})
......@@ -127,14 +168,26 @@ func TestPreimageOracleOutputByRoot(t *testing.T) {
rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_output_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
output := testutils.RandomOutputV0(rng)
h := common.Hash(eth.OutputRoot(output))
preimages[preimage.Keccak256Key(h).PreimageKey()] = output.Marshal()
hints.On("hint", LegacyL2OutputHint(h).Hint()).Once().Return()
gotOutput := po.OutputByRoot(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(output.Marshal()), hexutil.Bytes(gotOutput.Marshal()), "output matches")
})
t.Run(fmt.Sprintf("output_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t)
po, hints, preimages := mockPreimageOracle(t, true)
output := testutils.RandomOutputV0(rng)
h := common.Hash(eth.OutputRoot(output))
preimages[preimage.Keccak256Key(h).PreimageKey()] = output.Marshal()
hints.On("hint", L2OutputHint(h).Hint()).Once().Return()
gotOutput := po.OutputByRoot(h)
hints.On("hint", L2OutputHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotOutput := po.OutputByRoot(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(output.Marshal()), hexutil.Bytes(gotOutput.Marshal()), "output matches")
})
......
......@@ -15,8 +15,8 @@ import (
// Same as l2.StateOracle but need to use our own copy to avoid dependency loops
type stateOracle interface {
NodeByHash(nodeHash common.Hash) []byte
CodeByHash(codeHash common.Hash) []byte
NodeByHash(nodeHash common.Hash, chainID uint64) []byte
CodeByHash(codeHash common.Hash, chainID uint64) []byte
}
type StubBlockOracle struct {
......@@ -56,7 +56,7 @@ func NewStubOracleWithBlocks(t *testing.T, chain []*gethTypes.Block, outputs []e
}
}
func (o StubBlockOracle) BlockByHash(blockHash common.Hash) *gethTypes.Block {
func (o StubBlockOracle) BlockByHash(blockHash common.Hash, chainID uint64) *gethTypes.Block {
block, ok := o.Blocks[blockHash]
if !ok {
o.t.Fatalf("requested unknown block %s", blockHash)
......@@ -64,7 +64,7 @@ func (o StubBlockOracle) BlockByHash(blockHash common.Hash) *gethTypes.Block {
return block
}
func (o StubBlockOracle) OutputByRoot(root common.Hash) eth.Output {
func (o StubBlockOracle) OutputByRoot(root common.Hash, chainID uint64) eth.Output {
output, ok := o.Outputs[root]
if !ok {
o.t.Fatalf("requested unknown output root %s", root)
......@@ -100,7 +100,7 @@ func NewKvStateOracle(t *testing.T, db ethdb.KeyValueStore) *KvStateOracle {
}
}
func (o *KvStateOracle) NodeByHash(nodeHash common.Hash) []byte {
func (o *KvStateOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
val, err := o.Source.Get(nodeHash.Bytes())
if err != nil {
o.t.Fatalf("error retrieving node %v: %v", nodeHash, err)
......@@ -108,7 +108,7 @@ func (o *KvStateOracle) NodeByHash(nodeHash common.Hash) []byte {
return val
}
func (o *KvStateOracle) CodeByHash(hash common.Hash) []byte {
func (o *KvStateOracle) CodeByHash(hash common.Hash, chainID uint64) []byte {
return rawdb.ReadCode(o.Source, hash)
}
......@@ -127,7 +127,7 @@ type StubStateOracle struct {
Code map[common.Hash][]byte
}
func (o *StubStateOracle) NodeByHash(nodeHash common.Hash) []byte {
func (o *StubStateOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
data, ok := o.Data[nodeHash]
if !ok {
o.t.Fatalf("no value for node %v", nodeHash)
......@@ -135,7 +135,7 @@ func (o *StubStateOracle) NodeByHash(nodeHash common.Hash) []byte {
return data
}
func (o *StubStateOracle) CodeByHash(hash common.Hash) []byte {
func (o *StubStateOracle) CodeByHash(hash common.Hash, chainID uint64) []byte {
data, ok := o.Code[hash]
if !ok {
o.t.Fatalf("no value for code %v", hash)
......
......@@ -45,7 +45,7 @@ func RunProgram(logger log.Logger, preimageOracle io.ReadWriter, preimageHinter
pClient := preimage.NewOracleClient(preimageOracle)
hClient := preimage.NewHintWriter(preimageHinter)
l1PreimageOracle := l1.NewCachingOracle(l1.NewPreimageOracle(pClient, hClient))
l2PreimageOracle := l2.NewCachingOracle(l2.NewPreimageOracle(pClient, hClient))
l2PreimageOracle := l2.NewCachingOracle(l2.NewPreimageOracle(pClient, hClient, cfg.InteropEnabled))
if cfg.InteropEnabled {
bootInfo := boot.BootstrapInterop(pClient)
......
......@@ -32,3 +32,7 @@ func NewL2Client(client client.RPC, log log.Logger, metrics caching.Metrics, con
func (s *L2Client) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error) {
return s.OutputV0AtBlock(ctx, blockRoot)
}
func (s *L2Client) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
return s.OutputV0AtBlockNumber(ctx, blockNum)
}
......@@ -117,6 +117,14 @@ func (l *L2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth
return l.canonicalEthClient.OutputByRoot(ctx, blockRoot)
}
// OutputByBlockNumber implements prefetcher.L2Source.
func (l *L2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
if l.ExperimentalEnabled() {
return l.experimentalClient.OutputByNumber(ctx, blockNum)
}
return l.canonicalEthClient.OutputByNumber(ctx, blockNum)
}
// ExecutionWitness implements prefetcher.L2Source.
func (l *L2Source) ExecutionWitness(ctx context.Context, blockNum uint64) (*eth.ExecutionWitness, error) {
if !l.ExperimentalEnabled() {
......
......@@ -9,6 +9,7 @@ import (
"strings"
preimage "github.com/ethereum-optimism/optimism/op-preimage"
clientTypes "github.com/ethereum-optimism/optimism/op-program/client/interop/types"
"github.com/ethereum-optimism/optimism/op-program/client/l1"
"github.com/ethereum-optimism/optimism/op-program/client/l2"
"github.com/ethereum-optimism/optimism/op-program/client/mpt"
......@@ -258,11 +259,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
}
return p.kvStore.Put(preimage.PrecompileKey(inputHash).PreimageKey(), result)
case l2.HintL2BlockHeader, l2.HintL2Transactions:
if len(hintBytes) != 32 {
return fmt.Errorf("invalid L2 header/tx hint: %x", hint)
hash, chainID, err := p.parseHashAndChainID("L2 header/tx", hintBytes)
if err != nil {
return err
}
hash := common.Hash(hintBytes)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
source, err := p.l2Sources.ForChainID(chainID)
if err != nil {
return err
}
......@@ -280,11 +281,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
}
return p.storeTransactions(txs)
case l2.HintL2StateNode:
if len(hintBytes) != 32 {
return fmt.Errorf("invalid L2 state node hint: %x", hint)
hash, chainID, err := p.parseHashAndChainID("L2 state node", hintBytes)
if err != nil {
return err
}
hash := common.Hash(hintBytes)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
source, err := p.l2Sources.ForChainID(chainID)
if err != nil {
return err
}
......@@ -294,11 +295,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
}
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), node)
case l2.HintL2Code:
if len(hintBytes) != 32 {
return fmt.Errorf("invalid L2 code hint: %x", hint)
hash, chainID, err := p.parseHashAndChainID("L2 code", hintBytes)
if err != nil {
return err
}
hash := common.Hash(hintBytes)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
source, err := p.l2Sources.ForChainID(chainID)
if err != nil {
return err
}
......@@ -308,14 +309,15 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
}
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), code)
case l2.HintL2Output:
if len(hintBytes) != 32 {
return fmt.Errorf("invalid L2 output hint: %x", hint)
requestedHash, chainID, err := p.parseHashAndChainID("L2 output", hintBytes)
if err != nil {
return err
}
requestedHash := common.Hash(hintBytes)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
source, err := p.l2Sources.ForChainID(chainID)
if err != nil {
return err
}
if len(p.agreedPrestate) == 0 {
output, err := source.OutputByRoot(ctx, p.l2Head)
if err != nil {
return fmt.Errorf("failed to fetch L2 output root for block %s: %w", p.l2Head, err)
......@@ -325,6 +327,29 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
return fmt.Errorf("output root %v from block %v does not match requested root: %v", hash, p.l2Head, requestedHash)
}
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), output.Marshal())
} else {
prestate, err := clientTypes.UnmarshalTransitionState(p.agreedPrestate)
if err != nil {
return fmt.Errorf("cannot fetch output root, invalid agreed prestate: %w", err)
}
superRoot, err := eth.UnmarshalSuperRoot(prestate.SuperRoot)
if err != nil {
return fmt.Errorf("cannot fetch output root, invalid super root in prestate: %w", err)
}
superV1, ok := superRoot.(*eth.SuperV1)
if !ok {
return fmt.Errorf("cannot fetch output root, unsupported super root version in prestate: %v", superRoot.Version())
}
blockNum, err := source.RollupConfig().TargetBlockNumber(superV1.Timestamp)
if err != nil {
return fmt.Errorf("cannot fetch output root, failed to calculate block number at timestamp %v: %w", superV1.Timestamp, err)
}
output, err := source.OutputByNumber(ctx, blockNum)
if err != nil {
return fmt.Errorf("failed to fetch L2 output root for block %v: %w", blockNum, err)
}
return p.kvStore.Put(preimage.Keccak256Key(eth.OutputRoot(output)).PreimageKey(), output.Marshal())
}
case l2.HintL2BlockData:
if p.executor == nil {
return fmt.Errorf("this prefetcher does not support native block execution")
......@@ -353,6 +378,17 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
return fmt.Errorf("unknown hint type: %v", hintType)
}
func (p *Prefetcher) parseHashAndChainID(hintType string, hintBytes []byte) (common.Hash, uint64, error) {
switch len(hintBytes) {
case 32:
return common.Hash(hintBytes), p.defaultChainID, nil
case 40:
return common.Hash(hintBytes[0:32]), binary.BigEndian.Uint64(hintBytes[32:]), nil
default:
return common.Hash{}, 0, fmt.Errorf("invalid %s hint: %x", hintType, hintBytes)
}
}
type BlockDataKey [32]byte
func (p BlockDataKey) Key() [32]byte {
......
......@@ -34,6 +34,7 @@ import (
var (
ecRecoverInput = common.FromHex("18c547e4f7b0f325ad1e56f57e26c745b09a3e503d86e00e5255ff7f715d3d1c000000000000000000000000000000000000000000000000000000000000001c73b1693892219d736caba55bdb67216e485557ea6b6af75f37096c9aa6a5a75feeb940b1d03b21e36b0e47e79769f095fe2ab855bd91e3a38756b7d75a9c4549")
kzgPointEvalInput = common.FromHex("01e798154708fe7789429634053cbf9f99b619f9f084048927333fce637f549b564c0a11a0f704f4fc3e8acfe0f8245f0ad1347b378fbf96e206da11a5d3630624d25032e67a7e6a4910df5834b8fe70e6bcfeeac0352434196bdf4b2485d5a18f59a8d2a1a625a17f3fea0fe5eb8c896db3764f3185481bc22f91b4aaffcca25f26936857bc3a7c2539ea8ec3a952b7873033e038326e87ed3e1276fd140253fa08e9fc25fb2d9a98527fc22a2c9612fbeafdad446cbc7bcdbdcd780af2c16a")
defaultChainID = uint64(14)
)
func TestNoHint(t *testing.T) {
......@@ -415,6 +416,7 @@ func TestRestrictedPrecompileContracts(t *testing.T) {
func TestFetchL2Block(t *testing.T) {
rng := rand.New(rand.NewSource(123))
chainID := uint64(482948)
block, rcpts := testutils.RandomBlock(rng, 10)
hash := block.Hash()
......@@ -422,19 +424,32 @@ func TestFetchL2Block(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t)
storeBlock(t, kv, block, rcpts)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.BlockByHash(hash)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.BlockByHash(hash, chainID)
require.EqualValues(t, block.Header(), result.Header())
assertTransactionsEqual(t, block.Transactions(), result.Transactions())
})
t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cl, _ := createPrefetcher(t)
prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil)
defer l2Cl.MockL2Client.AssertExpectations(t)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.BlockByHash(hash)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.BlockByHash(hash, chainID)
require.EqualValues(t, block.Header(), result.Header())
assertTransactionsEqual(t, block.Transactions(), result.Transactions())
})
t.Run("WithChainID", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t, 5, 7, 10)
l2Cl := l2Cls.sources[7]
l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.BlockByHash(hash, 7)
require.EqualValues(t, block.Header(), result.Header())
assertTransactionsEqual(t, block.Transactions(), result.Transactions())
})
......@@ -444,23 +459,36 @@ func TestFetchL2Transactions(t *testing.T) {
rng := rand.New(rand.NewSource(123))
block, rcpts := testutils.RandomBlock(rng, 10)
hash := block.Hash()
chainID := rng.Uint64()
t.Run("AlreadyKnown", func(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t)
storeBlock(t, kv, block, rcpts)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.LoadTransactions(hash, block.TxHash())
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.LoadTransactions(hash, block.TxHash(), chainID)
assertTransactionsEqual(t, block.Transactions(), result)
})
t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cl, _ := createPrefetcher(t)
prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil)
defer l2Cl.MockL2Client.AssertExpectations(t)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.LoadTransactions(hash, block.TxHash())
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.LoadTransactions(hash, block.TxHash(), chainID)
assertTransactionsEqual(t, block.Transactions(), result)
})
t.Run("WithChainID", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t, 5, 7, 10)
l2Cl := l2Cls.sources[7]
l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.LoadTransactions(hash, block.TxHash(), 7)
assertTransactionsEqual(t, block.Transactions(), result)
})
}
......@@ -470,23 +498,36 @@ func TestFetchL2Node(t *testing.T) {
node := testutils.RandomData(rng, 30)
hash := crypto.Keccak256Hash(node)
key := preimage.Keccak256Key(hash).PreimageKey()
chainID := rng.Uint64()
t.Run("AlreadyKnown", func(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t)
require.NoError(t, kv.Put(key, node))
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.NodeByHash(hash)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.NodeByHash(hash, chainID)
require.EqualValues(t, node, result)
})
t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cl, _ := createPrefetcher(t)
prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectNodeByHash(hash, node, nil)
defer l2Cl.MockDebugClient.AssertExpectations(t)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.NodeByHash(hash, chainID)
require.EqualValues(t, node, result)
})
t.Run("WithChainID", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t, 5, 9, 99)
l2Cl := l2Cls.sources[9]
l2Cl.ExpectNodeByHash(hash, node, nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.NodeByHash(hash)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.NodeByHash(hash, 9)
require.EqualValues(t, node, result)
})
}
......@@ -496,32 +537,96 @@ func TestFetchL2Code(t *testing.T) {
code := testutils.RandomData(rng, 30)
hash := crypto.Keccak256Hash(code)
key := preimage.Keccak256Key(hash).PreimageKey()
chainID := rng.Uint64()
t.Run("AlreadyKnown", func(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t)
require.NoError(t, kv.Put(key, code))
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.CodeByHash(hash)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.CodeByHash(hash, chainID)
require.EqualValues(t, code, result)
})
t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cl, _ := createPrefetcher(t)
prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectCodeByHash(hash, code, nil)
defer l2Cl.MockDebugClient.AssertExpectations(t)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.CodeByHash(hash)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.CodeByHash(hash, chainID)
require.EqualValues(t, code, result)
})
t.Run("WithChainID", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t, 8, 45, 98, 55)
l2Cl := l2Cls.sources[98]
l2Cl.ExpectCodeByHash(hash, code, nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.CodeByHash(hash, 98)
require.EqualValues(t, code, result)
})
}
func TestFetchL2Output(t *testing.T) {
rng := rand.New(rand.NewSource(123))
output := testutils.RandomOutputV0(rng)
hash := common.Hash(eth.OutputRoot(output))
key := preimage.Keccak256Key(hash).PreimageKey()
t.Run("AlreadyKnown", func(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t)
require.NoError(t, kv.Put(key, output.Marshal()))
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.OutputByRoot(hash, rng.Uint64())
require.EqualValues(t, output, result)
})
t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectOutputByRoot(prefetcher.l2Head, output, nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.OutputByRoot(hash, rng.Uint64())
require.EqualValues(t, output, result)
})
t.Run("WithChainID", func(t *testing.T) {
chain6Output := testutils.RandomOutputV0(rng)
chain99Output := testutils.RandomOutputV0(rng)
timestamp := uint64(4567882)
superV1 := eth.SuperV1{
Timestamp: timestamp,
Chains: []eth.ChainIDAndOutput{
{ChainID: 6, Output: eth.OutputRoot(chain6Output)},
{ChainID: 78, Output: eth.OutputRoot(output)},
{ChainID: 99, Output: eth.OutputRoot(chain99Output)},
},
}
prefetcher, _, _, l2Cls, _ := createPrefetcherWithAgreedPrestate(t, superV1.Marshal(), 6, 78, 99)
l2Cl := l2Cls.sources[78]
blockNum, err := l2Cls.sources[78].RollupConfig().TargetBlockNumber(timestamp)
require.NoError(t, err)
l2Cl.ExpectOutputByNumber(blockNum, output, nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.OutputByRoot(hash, 78)
require.EqualValues(t, output, result)
})
}
func TestFetchL2BlockData(t *testing.T) {
chainID := uint64(14)
testBlockExec := func(t *testing.T, err error) {
prefetcher, _, _, l2Client, _ := createPrefetcher(t)
prefetcher, _, _, l2Clients, _ := createPrefetcher(t)
l2Client := l2Clients.sources[defaultChainID]
rng := rand.New(rand.NewSource(123))
block, _ := testutils.RandomBlock(rng, 10)
disputedBlockHash := common.Hash{0xab}
......@@ -644,21 +749,23 @@ func TestRetryWhenNotAvailableAfterPrefetching(t *testing.T) {
rng := rand.New(rand.NewSource(123))
node := testutils.RandomData(rng, 30)
hash := crypto.Keccak256Hash(node)
chainID := rng.Uint64()
_, l1Source, l1BlobSource, l2Cl, kv := createPrefetcher(t)
_, l1Source, l1BlobSource, l2Cls, kv := createPrefetcher(t)
putsToIgnore := 2
kv = &unreliableKvStore{KV: kv, putsToIgnore: putsToIgnore}
sources := &l2Clients{sources: map[uint64]hostTypes.L2Source{6: l2Cl}}
sources := &l2Clients{sources: map[uint64]*l2Client{6: l2Cls.sources[defaultChainID]}}
prefetcher := NewPrefetcher(testlog.Logger(t, log.LevelInfo), l1Source, l1BlobSource, 6, sources, kv, nil, common.Hash{}, nil)
l2Cl := sources.sources[6]
// Expect one call for each ignored put, plus one more request for when the put succeeds
for i := 0; i < putsToIgnore+1; i++ {
l2Cl.ExpectNodeByHash(hash, node, nil)
}
defer l2Cl.MockDebugClient.AssertExpectations(t)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher))
result := oracle.NodeByHash(hash)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.NodeByHash(hash, chainID)
require.EqualValues(t, node, result)
}
......@@ -677,7 +784,7 @@ func (s *unreliableKvStore) Put(k common.Hash, v []byte) error {
}
type l2Clients struct {
sources map[uint64]hostTypes.L2Source
sources map[uint64]*l2Client
}
func (l *l2Clients) ForChainID(id uint64) (hostTypes.L2Source, error) {
......@@ -695,10 +802,11 @@ func (l *l2Clients) ForChainIDWithoutRetries(id uint64) (hostTypes.L2Source, err
type l2Client struct {
*testutils.MockL2Client
*testutils.MockDebugClient
rollupCfg *rollup.Config
}
func (m *l2Client) RollupConfig() *rollup.Config {
panic("implement me")
return m.rollupCfg
}
func (m *l2Client) ExperimentalEnabled() bool {
......@@ -714,25 +822,47 @@ func (m *l2Client) ExpectOutputByRoot(blockRoot common.Hash, output eth.Output,
m.Mock.On("OutputByRoot", blockRoot).Once().Return(output, &err)
}
func createPrefetcher(t *testing.T) (*Prefetcher, *testutils.MockL1Source, *testutils.MockBlobsFetcher, *l2Client, kvstore.KV) {
return createPrefetcherWithAgreedPrestate(t, nil)
func (m *l2Client) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
out := m.Mock.MethodCalled("OutputByNumber", blockNum)
return out[0].(eth.Output), *out[1].(*error)
}
func createPrefetcherWithAgreedPrestate(t *testing.T, agreedPrestate []byte) (*Prefetcher, *testutils.MockL1Source, *testutils.MockBlobsFetcher, *l2Client, kvstore.KV) {
func (m *l2Client) ExpectOutputByNumber(blockNum uint64, output eth.Output, err error) {
m.Mock.On("OutputByNumber", blockNum).Once().Return(output, &err)
}
func createPrefetcher(t *testing.T, chainIDs ...uint64) (*Prefetcher, *testutils.MockL1Source, *testutils.MockBlobsFetcher, *l2Clients, kvstore.KV) {
return createPrefetcherWithAgreedPrestate(t, nil, chainIDs...)
}
func createPrefetcherWithAgreedPrestate(t *testing.T, agreedPrestate []byte, chainIDs ...uint64) (*Prefetcher, *testutils.MockL1Source, *testutils.MockBlobsFetcher, *l2Clients, kvstore.KV) {
logger := testlog.Logger(t, log.LevelDebug)
kv := kvstore.NewMemKV()
l1Source := new(testutils.MockL1Source)
l1BlobSource := new(testutils.MockBlobsFetcher)
// Provide a default chain if none specified.
if len(chainIDs) == 0 {
chainIDs = []uint64{defaultChainID}
}
l2Sources := &l2Clients{sources: make(map[uint64]*l2Client)}
for i, chainID := range chainIDs {
l2Source := &l2Client{
rollupCfg: &rollup.Config{
// Make the block numbers for each chain differ at each timestamp
Genesis: rollup.Genesis{L2Time: 500 + uint64(2*i)},
BlockTime: 1,
},
MockL2Client: new(testutils.MockL2Client),
MockDebugClient: new(testutils.MockDebugClient),
}
l2Sources := &l2Clients{
sources: map[uint64]hostTypes.L2Source{14: l2Source},
l2Sources.sources[chainID] = l2Source
}
prefetcher := NewPrefetcher(logger, l1Source, l1BlobSource, 14, l2Sources, kv, nil, common.Hash{0xdd}, agreedPrestate)
return prefetcher, l1Source, l1BlobSource, l2Source, kv
prefetcher := NewPrefetcher(logger, l1Source, l1BlobSource, chainIDs[0], l2Sources, kv, nil, common.Hash{0xdd}, agreedPrestate)
return prefetcher, l1Source, l1BlobSource, l2Sources, kv
}
func storeBlock(t *testing.T, kv kvstore.KV, block *types.Block, receipts types.Receipts) {
......@@ -847,3 +977,10 @@ func (m *mockExecutor) RunProgram(
m.chainID = chainID
return nil
}
func assertAllClientExpectations(t *testing.T, l2Cls *l2Clients) {
for _, source := range l2Cls.sources {
source.Mock.AssertExpectations(t)
source.MockDebugClient.Mock.AssertExpectations(t)
}
}
......@@ -152,6 +152,17 @@ func (s *RetryingL2Source) OutputByRoot(ctx context.Context, blockRoot common.Ha
})
}
func (s *RetryingL2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
return retry.Do(ctx, maxAttempts, s.strategy, func() (eth.Output, error) {
o, err := s.source.OutputByNumber(ctx, blockNum)
if err != nil {
s.logger.Warn("Failed to fetch l2 output", "block", blockNum, "err", err)
return o, err
}
return o, nil
})
}
func NewRetryingL2Source(logger log.Logger, source hosttypes.L2Source) *RetryingL2Source {
return &RetryingL2Source{
logger: logger,
......
......@@ -226,7 +226,8 @@ func createL1BlobSource(t *testing.T) (*RetryingL1BlobSource, *testutils.MockBlo
func TestRetryingL2Source(t *testing.T) {
ctx := context.Background()
hash := common.Hash{0xab}
info := &testutils.MockBlockInfo{InfoHash: hash}
blockNum := uint64(14982)
info := &testutils.MockBlockInfo{InfoHash: hash, InfoNum: blockNum}
// The mock really doesn't like returning nil for a eth.BlockInfo so return a value we expect to be ignored instead
wrongInfo := &testutils.MockBlockInfo{InfoHash: common.Hash{0x99}}
txs := types.Transactions{
......@@ -325,6 +326,28 @@ func TestRetryingL2Source(t *testing.T) {
require.NoError(t, err)
require.Equal(t, output, actualOutput)
})
t.Run("OutputByNumber Success", func(t *testing.T) {
source, mock := createL2Source(t)
defer mock.AssertExpectations(t)
mock.ExpectOutputByNumber(blockNum, output, nil)
actualOutput, err := source.OutputByNumber(ctx, blockNum)
require.NoError(t, err)
require.Equal(t, output, actualOutput)
})
t.Run("OutputByNumber Error", func(t *testing.T) {
source, mock := createL2Source(t)
defer mock.AssertExpectations(t)
expectedErr := errors.New("boom")
mock.ExpectOutputByNumber(blockNum, wrongOutput, expectedErr)
mock.ExpectOutputByNumber(blockNum, output, nil)
actualOutput, err := source.OutputByNumber(ctx, blockNum)
require.NoError(t, err)
require.Equal(t, output, actualOutput)
})
}
func createL2Source(t *testing.T) (*RetryingL2Source, *MockL2Source) {
......@@ -370,6 +393,11 @@ func (m *MockL2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash)
return out[0].(eth.Output), *out[1].(*error)
}
func (m *MockL2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
out := m.Mock.MethodCalled("OutputByNumber", blockNum)
return out[0].(eth.Output), *out[1].(*error)
}
func (m *MockL2Source) ExpectInfoAndTxsByHash(blockHash common.Hash, info eth.BlockInfo, txs types.Transactions, err error) {
m.Mock.On("InfoAndTxsByHash", blockHash).Once().Return(info, txs, &err)
}
......@@ -386,4 +414,8 @@ func (m *MockL2Source) ExpectOutputByRoot(blockHash common.Hash, output eth.Outp
m.Mock.On("OutputByRoot", blockHash).Once().Return(output, &err)
}
func (m *MockL2Source) ExpectOutputByNumber(blockNum uint64, output eth.Output, err error) {
m.Mock.On("OutputByNumber", blockNum).Once().Return(output, &err)
}
var _ hosttypes.L2Source = (*MockL2Source)(nil)
......@@ -24,6 +24,7 @@ type L2Source interface {
NodeByHash(ctx context.Context, hash common.Hash) ([]byte, error)
CodeByHash(ctx context.Context, hash common.Hash) ([]byte, error)
OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error)
OutputByNumber(ctx context.Context, blockNumber uint64) (eth.Output, error)
RollupConfig() *rollup.Config
ExperimentalEnabled() bool
}
......
......@@ -172,15 +172,26 @@ func (s *L2Client) SystemConfigByL2Hash(ctx context.Context, hash common.Hash) (
return cfg, nil
}
func (s *L2Client) OutputV0AtBlockNumber(ctx context.Context, blockNum uint64) (*eth.OutputV0, error) {
head, err := s.InfoByNumber(ctx, blockNum)
if err != nil {
return nil, fmt.Errorf("failed to get L2 block by hash: %w", err)
}
return s.outputV0(ctx, head)
}
func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) (*eth.OutputV0, error) {
head, err := s.InfoByHash(ctx, blockHash)
if err != nil {
return nil, fmt.Errorf("failed to get L2 block by hash: %w", err)
}
if head == nil {
return s.outputV0(ctx, head)
}
func (s *L2Client) outputV0(ctx context.Context, block eth.BlockInfo) (*eth.OutputV0, error) {
if block == nil {
return nil, ethereum.NotFound
}
blockHash := block.Hash()
proof, err := s.GetProof(ctx, predeploys.L2ToL1MessagePasserAddr, []common.Hash{}, blockHash.String())
if err != nil {
return nil, fmt.Errorf("failed to get contract proof at block %s: %w", blockHash, err)
......@@ -189,10 +200,10 @@ func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) (
return nil, fmt.Errorf("proof %w", ethereum.NotFound)
}
// make sure that the proof (including storage hash) that we retrieved is correct by verifying it against the state-root
if err := proof.Verify(head.Root()); err != nil {
return nil, fmt.Errorf("invalid withdrawal root hash, state root was %s: %w", head.Root(), err)
if err := proof.Verify(block.Root()); err != nil {
return nil, fmt.Errorf("invalid withdrawal root hash, state root was %s: %w", block.Root(), err)
}
stateRoot := head.Root()
stateRoot := block.Root()
return &eth.OutputV0{
StateRoot: eth.Bytes32(stateRoot),
MessagePasserStorageRoot: eth.Bytes32(proof.StorageHash),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment