Commit ca583f7f authored by Adrian Sutton's avatar Adrian Sutton Committed by GitHub

op-program: Add chain ID to hints when using interop (#13735)

* op-program: Add L2 chain ID to hints when using interop

* op-program: Add L2 chain ID to OutputRoot hints when using interop

* Fix label in error message.
Co-authored-by: default avatarInphi <mlaw2501@gmail.com>

---------
Co-authored-by: default avatarInphi <mlaw2501@gmail.com>
parent 66a67e02
...@@ -319,7 +319,6 @@ func TestInteropFaultProofs(gt *testing.T) { ...@@ -319,7 +319,6 @@ func TestInteropFaultProofs(gt *testing.T) {
agreedClaim: step1Expected, agreedClaim: step1Expected,
disputedClaim: step2Expected, disputedClaim: step2Expected,
expectValid: true, expectValid: true,
skip: true,
}, },
{ {
name: "PaddingStep", name: "PaddingStep",
......
...@@ -39,42 +39,42 @@ func NewCachingOracle(oracle Oracle) *CachingOracle { ...@@ -39,42 +39,42 @@ func NewCachingOracle(oracle Oracle) *CachingOracle {
} }
} }
func (o *CachingOracle) NodeByHash(nodeHash common.Hash) []byte { func (o *CachingOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
node, ok := o.nodes.Get(nodeHash) node, ok := o.nodes.Get(nodeHash)
if ok { if ok {
return node return node
} }
node = o.oracle.NodeByHash(nodeHash) node = o.oracle.NodeByHash(nodeHash, chainID)
o.nodes.Add(nodeHash, node) o.nodes.Add(nodeHash, node)
return node return node
} }
func (o *CachingOracle) CodeByHash(codeHash common.Hash) []byte { func (o *CachingOracle) CodeByHash(codeHash common.Hash, chainID uint64) []byte {
code, ok := o.codes.Get(codeHash) code, ok := o.codes.Get(codeHash)
if ok { if ok {
return code return code
} }
code = o.oracle.CodeByHash(codeHash) code = o.oracle.CodeByHash(codeHash, chainID)
o.codes.Add(codeHash, code) o.codes.Add(codeHash, code)
return code return code
} }
func (o *CachingOracle) BlockByHash(blockHash common.Hash) *types.Block { func (o *CachingOracle) BlockByHash(blockHash common.Hash, chainID uint64) *types.Block {
block, ok := o.blocks.Get(blockHash) block, ok := o.blocks.Get(blockHash)
if ok { if ok {
return block return block
} }
block = o.oracle.BlockByHash(blockHash) block = o.oracle.BlockByHash(blockHash, chainID)
o.blocks.Add(blockHash, block) o.blocks.Add(blockHash, block)
return block return block
} }
func (o *CachingOracle) OutputByRoot(root common.Hash) eth.Output { func (o *CachingOracle) OutputByRoot(root common.Hash, chainID uint64) eth.Output {
output, ok := o.outputs.Get(root) output, ok := o.outputs.Get(root)
if ok { if ok {
return output return output
} }
output = o.oracle.OutputByRoot(root) output = o.oracle.OutputByRoot(root, chainID)
o.outputs.Add(root, output) o.outputs.Add(root, output)
return output return output
} }
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
var _ Oracle = (*CachingOracle)(nil) var _ Oracle = (*CachingOracle)(nil)
func TestBlockByHash(t *testing.T) { func TestBlockByHash(t *testing.T) {
chainID := uint64(48294)
stub, _ := test.NewStubOracle(t) stub, _ := test.NewStubOracle(t)
oracle := NewCachingOracle(stub) oracle := NewCachingOracle(stub)
...@@ -23,12 +24,12 @@ func TestBlockByHash(t *testing.T) { ...@@ -23,12 +24,12 @@ func TestBlockByHash(t *testing.T) {
// Initial call retrieves from the stub // Initial call retrieves from the stub
stub.Blocks[block.Hash()] = block stub.Blocks[block.Hash()] = block
actual := oracle.BlockByHash(block.Hash()) actual := oracle.BlockByHash(block.Hash(), chainID)
require.Equal(t, block, actual) require.Equal(t, block, actual)
// Later calls should retrieve from cache // Later calls should retrieve from cache (even if chain ID is different)
delete(stub.Blocks, block.Hash()) delete(stub.Blocks, block.Hash())
actual = oracle.BlockByHash(block.Hash()) actual = oracle.BlockByHash(block.Hash(), 9982)
require.Equal(t, block, actual) require.Equal(t, block, actual)
} }
...@@ -41,12 +42,12 @@ func TestNodeByHash(t *testing.T) { ...@@ -41,12 +42,12 @@ func TestNodeByHash(t *testing.T) {
// Initial call retrieves from the stub // Initial call retrieves from the stub
stateStub.Data[hash] = node stateStub.Data[hash] = node
actual := oracle.NodeByHash(hash) actual := oracle.NodeByHash(hash, 1234)
require.Equal(t, node, actual) require.Equal(t, node, actual)
// Later calls should retrieve from cache // Later calls should retrieve from cache (even if chain ID is different)
delete(stateStub.Data, hash) delete(stateStub.Data, hash)
actual = oracle.NodeByHash(hash) actual = oracle.NodeByHash(hash, 997845)
require.Equal(t, node, actual) require.Equal(t, node, actual)
} }
...@@ -59,12 +60,12 @@ func TestCodeByHash(t *testing.T) { ...@@ -59,12 +60,12 @@ func TestCodeByHash(t *testing.T) {
// Initial call retrieves from the stub // Initial call retrieves from the stub
stateStub.Code[hash] = node stateStub.Code[hash] = node
actual := oracle.CodeByHash(hash) actual := oracle.CodeByHash(hash, 342)
require.Equal(t, node, actual) require.Equal(t, node, actual)
// Later calls should retrieve from cache // Later calls should retrieve from cache (even if the chain ID is different)
delete(stateStub.Code, hash) delete(stateStub.Code, hash)
actual = oracle.CodeByHash(hash) actual = oracle.CodeByHash(hash, 986776)
require.Equal(t, node, actual) require.Equal(t, node, actual)
} }
...@@ -78,11 +79,11 @@ func TestOutputByRoot(t *testing.T) { ...@@ -78,11 +79,11 @@ func TestOutputByRoot(t *testing.T) {
// Initial call retrieves from the stub // Initial call retrieves from the stub
root := common.Hash(eth.OutputRoot(output)) root := common.Hash(eth.OutputRoot(output))
stub.Outputs[root] = output stub.Outputs[root] = output
actual := oracle.OutputByRoot(root) actual := oracle.OutputByRoot(root, 59284)
require.Equal(t, output, actual) require.Equal(t, output, actual)
// Later calls should retrieve from cache // Later calls should retrieve from cache (even if the chain ID is different)
delete(stub.Outputs, root) delete(stub.Outputs, root)
actual = oracle.OutputByRoot(root) actual = oracle.OutputByRoot(root, 9193)
require.Equal(t, output, actual) require.Equal(t, output, actual)
} }
...@@ -16,14 +16,16 @@ var codePrefixedKeyLength = common.HashLength + len(rawdb.CodePrefix) ...@@ -16,14 +16,16 @@ var codePrefixedKeyLength = common.HashLength + len(rawdb.CodePrefix)
var ErrInvalidKeyLength = errors.New("pre-images must be identified by 32-byte hash keys") var ErrInvalidKeyLength = errors.New("pre-images must be identified by 32-byte hash keys")
type OracleKeyValueStore struct { type OracleKeyValueStore struct {
db ethdb.KeyValueStore db ethdb.KeyValueStore
oracle StateOracle oracle StateOracle
chainID uint64
} }
func NewOracleBackedDB(oracle StateOracle) *OracleKeyValueStore { func NewOracleBackedDB(oracle StateOracle, chainID uint64) *OracleKeyValueStore {
return &OracleKeyValueStore{ return &OracleKeyValueStore{
db: memorydb.New(), db: memorydb.New(),
oracle: oracle, oracle: oracle,
chainID: chainID,
} }
} }
...@@ -38,12 +40,12 @@ func (o *OracleKeyValueStore) Get(key []byte) ([]byte, error) { ...@@ -38,12 +40,12 @@ func (o *OracleKeyValueStore) Get(key []byte) ([]byte, error) {
if len(key) == codePrefixedKeyLength && bytes.HasPrefix(key, rawdb.CodePrefix) { if len(key) == codePrefixedKeyLength && bytes.HasPrefix(key, rawdb.CodePrefix) {
key = key[len(rawdb.CodePrefix):] key = key[len(rawdb.CodePrefix):]
return o.oracle.CodeByHash(*(*[common.HashLength]byte)(key)), nil return o.oracle.CodeByHash(*(*[common.HashLength]byte)(key), o.chainID), nil
} }
if len(key) != common.HashLength { if len(key) != common.HashLength {
return nil, ErrInvalidKeyLength return nil, ErrInvalidKeyLength
} }
return o.oracle.NodeByHash(*(*[common.HashLength]byte)(key)), nil return o.oracle.NodeByHash(*(*[common.HashLength]byte)(key), o.chainID), nil
} }
func (o *OracleKeyValueStore) NewBatch() ethdb.Batch { func (o *OracleKeyValueStore) NewBatch() ethdb.Batch {
......
...@@ -34,7 +34,7 @@ var _ ethdb.KeyValueStore = (*OracleKeyValueStore)(nil) ...@@ -34,7 +34,7 @@ var _ ethdb.KeyValueStore = (*OracleKeyValueStore)(nil)
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
t.Run("IncorrectLengthKey", func(t *testing.T) { t.Run("IncorrectLengthKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
val, err := db.Get([]byte{1, 2, 3}) val, err := db.Get([]byte{1, 2, 3})
require.ErrorIs(t, err, ErrInvalidKeyLength) require.ErrorIs(t, err, ErrInvalidKeyLength)
require.Nil(t, val) require.Nil(t, val)
...@@ -42,7 +42,7 @@ func TestGet(t *testing.T) { ...@@ -42,7 +42,7 @@ func TestGet(t *testing.T) {
t.Run("KeyWithCodePrefix", func(t *testing.T) { t.Run("KeyWithCodePrefix", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0x12345678") key := common.HexToHash("0x12345678")
prefixedKey := append(rawdb.CodePrefix, key.Bytes()...) prefixedKey := append(rawdb.CodePrefix, key.Bytes()...)
...@@ -56,7 +56,7 @@ func TestGet(t *testing.T) { ...@@ -56,7 +56,7 @@ func TestGet(t *testing.T) {
t.Run("NormalKeyThatHappensToStartWithCodePrefix", func(t *testing.T) { t.Run("NormalKeyThatHappensToStartWithCodePrefix", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
key := make([]byte, common.HashLength) key := make([]byte, common.HashLength)
copy(rawdb.CodePrefix, key) copy(rawdb.CodePrefix, key)
fmt.Println(key[0]) fmt.Println(key[0])
...@@ -73,7 +73,7 @@ func TestGet(t *testing.T) { ...@@ -73,7 +73,7 @@ func TestGet(t *testing.T) {
expected := []byte{2, 6, 3, 8} expected := []byte{2, 6, 3, 8}
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
oracle.Data[key] = expected oracle.Data[key] = expected
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
val, err := db.Get(key.Bytes()) val, err := db.Get(key.Bytes())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expected, val) require.Equal(t, expected, val)
...@@ -83,7 +83,7 @@ func TestGet(t *testing.T) { ...@@ -83,7 +83,7 @@ func TestGet(t *testing.T) {
func TestPut(t *testing.T) { func TestPut(t *testing.T) {
t.Run("NewKey", func(t *testing.T) { t.Run("NewKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0xAA4488") key := common.HexToHash("0xAA4488")
value := []byte{2, 6, 3, 8} value := []byte{2, 6, 3, 8}
err := db.Put(key.Bytes(), value) err := db.Put(key.Bytes(), value)
...@@ -95,7 +95,7 @@ func TestPut(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestPut(t *testing.T) {
}) })
t.Run("ReplaceKey", func(t *testing.T) { t.Run("ReplaceKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0xAA4488") key := common.HexToHash("0xAA4488")
value1 := []byte{2, 6, 3, 8} value1 := []byte{2, 6, 3, 8}
value2 := []byte{1, 2, 3} value2 := []byte{1, 2, 3}
...@@ -117,13 +117,13 @@ func TestSupportsStateDBOperations(t *testing.T) { ...@@ -117,13 +117,13 @@ func TestSupportsStateDBOperations(t *testing.T) {
genesisBlock := l2Genesis.MustCommit(realDb, trieDB) genesisBlock := l2Genesis.MustCommit(realDb, trieDB)
loader := test.NewKvStateOracle(t, realDb) loader := test.NewKvStateOracle(t, realDb)
assertStateDataAvailable(t, NewOracleBackedDB(loader), l2Genesis, genesisBlock) assertStateDataAvailable(t, NewOracleBackedDB(loader, 1234), l2Genesis, genesisBlock)
} }
func TestUpdateState(t *testing.T) { func TestUpdateState(t *testing.T) {
l2Genesis := createGenesis() l2Genesis := createGenesis()
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := rawdb.NewDatabase(NewOracleBackedDB(oracle)) db := rawdb.NewDatabase(NewOracleBackedDB(oracle, 1234))
trieDB := triedb.NewDatabase(db, &triedb.Config{HashDB: hashdb.Defaults}) trieDB := triedb.NewDatabase(db, &triedb.Config{HashDB: hashdb.Defaults})
genesisBlock := l2Genesis.MustCommit(db, trieDB) genesisBlock := l2Genesis.MustCommit(db, trieDB)
......
...@@ -45,12 +45,13 @@ type OracleBackedL2Chain struct { ...@@ -45,12 +45,13 @@ type OracleBackedL2Chain struct {
var _ engineapi.CachingEngineBackend = (*OracleBackedL2Chain)(nil) var _ engineapi.CachingEngineBackend = (*OracleBackedL2Chain)(nil)
func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle engineapi.PrecompileOracle, chainCfg *params.ChainConfig, l2OutputRoot common.Hash) (*OracleBackedL2Chain, error) { func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle engineapi.PrecompileOracle, chainCfg *params.ChainConfig, l2OutputRoot common.Hash) (*OracleBackedL2Chain, error) {
output := oracle.OutputByRoot(l2OutputRoot) chainID := chainCfg.ChainID.Uint64()
output := oracle.OutputByRoot(l2OutputRoot, chainID)
outputV0, ok := output.(*eth.OutputV0) outputV0, ok := output.(*eth.OutputV0)
if !ok { if !ok {
return nil, fmt.Errorf("unsupported L2 output version: %d", output.Version()) return nil, fmt.Errorf("unsupported L2 output version: %d", output.Version())
} }
head := oracle.BlockByHash(outputV0.BlockHash) head := oracle.BlockByHash(outputV0.BlockHash, chainID)
logger.Info("Loaded L2 head", "hash", head.Hash(), "number", head.Number()) logger.Info("Loaded L2 head", "hash", head.Hash(), "number", head.Number())
return &OracleBackedL2Chain{ return &OracleBackedL2Chain{
log: logger, log: logger,
...@@ -69,7 +70,7 @@ func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle e ...@@ -69,7 +70,7 @@ func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle e
finalized: head.Header(), finalized: head.Header(),
oracleHead: head.Header(), oracleHead: head.Header(),
blocks: make(map[common.Hash]*types.Block), blocks: make(map[common.Hash]*types.Block),
db: NewOracleBackedDB(oracle), db: NewOracleBackedDB(oracle, chainID),
vmCfg: vm.Config{ vmCfg: vm.Config{
PrecompileOverrides: engineapi.CreatePrecompileOverrides(precompileOracle), PrecompileOverrides: engineapi.CreatePrecompileOverrides(precompileOracle),
}, },
...@@ -122,7 +123,7 @@ func (o *OracleBackedL2Chain) GetBlockByHash(hash common.Hash) *types.Block { ...@@ -122,7 +123,7 @@ func (o *OracleBackedL2Chain) GetBlockByHash(hash common.Hash) *types.Block {
return block return block
} }
// Retrieve from the oracle // Retrieve from the oracle
return o.oracle.BlockByHash(hash) return o.oracle.BlockByHash(hash, o.chainCfg.ChainID.Uint64())
} }
func (o *OracleBackedL2Chain) GetBlock(hash common.Hash, number uint64) *types.Block { func (o *OracleBackedL2Chain) GetBlock(hash common.Hash, number uint64) *types.Block {
......
...@@ -378,7 +378,7 @@ func createBlock(t *testing.T, chain *OracleBackedL2Chain, opts ...blockCreateOp ...@@ -378,7 +378,7 @@ func createBlock(t *testing.T, chain *OracleBackedL2Chain, opts ...blockCreateOp
require.NoError(t, err) require.NoError(t, err)
nonce := parentDB.GetNonce(fundedAddress) nonce := parentDB.GetNonce(fundedAddress)
config := chain.Config() config := chain.Config()
db := rawdb.NewDatabase(NewOracleBackedDB(chain.oracle)) db := rawdb.NewDatabase(NewOracleBackedDB(chain.oracle, config.ChainID.Uint64()))
blocks, _ := core.GenerateChain(config, parent, chain.Engine(), db, 1, func(i int, gen *core.BlockGen) { blocks, _ := core.GenerateChain(config, parent, chain.Engine(), db, 1, func(i int, gen *core.BlockGen) {
rawTx := &types.DynamicFeeTx{ rawTx := &types.DynamicFeeTx{
ChainID: config.ChainID, ChainID: config.ChainID,
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
preimage "github.com/ethereum-optimism/optimism/op-preimage" preimage "github.com/ethereum-optimism/optimism/op-preimage"
) )
...@@ -19,43 +20,95 @@ const ( ...@@ -19,43 +20,95 @@ const (
HintAgreedPrestate = "agreed-pre-state" HintAgreedPrestate = "agreed-pre-state"
) )
type BlockHeaderHint common.Hash type LegacyBlockHeaderHint common.Hash
var _ preimage.Hint = LegacyBlockHeaderHint{}
func (l LegacyBlockHeaderHint) Hint() string {
return HintL2BlockHeader + " " + (common.Hash)(l).String()
}
type HashAndChainID struct {
Hash common.Hash
ChainID uint64
}
func (h HashAndChainID) Marshal() []byte {
d := make([]byte, 32+8)
copy(d[:32], h.Hash[:])
binary.BigEndian.PutUint64(d[32:], h.ChainID)
return d
}
type BlockHeaderHint HashAndChainID
var _ preimage.Hint = BlockHeaderHint{} var _ preimage.Hint = BlockHeaderHint{}
func (l BlockHeaderHint) Hint() string { func (l BlockHeaderHint) Hint() string {
return HintL2BlockHeader + " " + (common.Hash)(l).String() return HintL2BlockHeader + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyTransactionsHint common.Hash
var _ preimage.Hint = LegacyTransactionsHint{}
func (l LegacyTransactionsHint) Hint() string {
return HintL2Transactions + " " + (common.Hash)(l).String()
} }
type TransactionsHint common.Hash type TransactionsHint HashAndChainID
var _ preimage.Hint = TransactionsHint{} var _ preimage.Hint = TransactionsHint{}
func (l TransactionsHint) Hint() string { func (l TransactionsHint) Hint() string {
return HintL2Transactions + " " + (common.Hash)(l).String() return HintL2Transactions + " " + hexutil.Encode(HashAndChainID(l).Marshal())
} }
type CodeHint common.Hash type CodeHint HashAndChainID
var _ preimage.Hint = CodeHint{} var _ preimage.Hint = CodeHint{}
func (l CodeHint) Hint() string { func (l CodeHint) Hint() string {
return HintL2Code + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyCodeHint common.Hash
var _ preimage.Hint = LegacyCodeHint{}
func (l LegacyCodeHint) Hint() string {
return HintL2Code + " " + (common.Hash)(l).String() return HintL2Code + " " + (common.Hash)(l).String()
} }
type StateNodeHint common.Hash type StateNodeHint HashAndChainID
var _ preimage.Hint = StateNodeHint{} var _ preimage.Hint = StateNodeHint{}
func (l StateNodeHint) Hint() string { func (l StateNodeHint) Hint() string {
return HintL2StateNode + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyStateNodeHint common.Hash
var _ preimage.Hint = LegacyStateNodeHint{}
func (l LegacyStateNodeHint) Hint() string {
return HintL2StateNode + " " + (common.Hash)(l).String() return HintL2StateNode + " " + (common.Hash)(l).String()
} }
type L2OutputHint common.Hash type L2OutputHint HashAndChainID
var _ preimage.Hint = L2OutputHint{} var _ preimage.Hint = L2OutputHint{}
func (l L2OutputHint) Hint() string { func (l L2OutputHint) Hint() string {
return HintL2Output + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyL2OutputHint common.Hash
var _ preimage.Hint = LegacyL2OutputHint{}
func (l LegacyL2OutputHint) Hint() string {
return HintL2Output + " " + (common.Hash)(l).String() return HintL2Output + " " + (common.Hash)(l).String()
} }
......
...@@ -19,11 +19,11 @@ type StateOracle interface { ...@@ -19,11 +19,11 @@ type StateOracle interface {
// NodeByHash retrieves the merkle-patricia trie node pre-image for a given hash. // NodeByHash retrieves the merkle-patricia trie node pre-image for a given hash.
// Trie nodes may be from the world state trie or any account storage trie. // Trie nodes may be from the world state trie or any account storage trie.
// Contract code is not stored as part of the trie and must be retrieved via CodeByHash // Contract code is not stored as part of the trie and must be retrieved via CodeByHash
NodeByHash(nodeHash common.Hash) []byte NodeByHash(nodeHash common.Hash, chainID uint64) []byte
// CodeByHash retrieves the contract code pre-image for a given hash. // CodeByHash retrieves the contract code pre-image for a given hash.
// codeHash should be retrieved from the world state account for a contract. // codeHash should be retrieved from the world state account for a contract.
CodeByHash(codeHash common.Hash) []byte CodeByHash(codeHash common.Hash, chainID uint64) []byte
} }
// Oracle defines the high-level API used to retrieve L2 data. // Oracle defines the high-level API used to retrieve L2 data.
...@@ -32,9 +32,9 @@ type Oracle interface { ...@@ -32,9 +32,9 @@ type Oracle interface {
StateOracle StateOracle
// BlockByHash retrieves the block with the given hash. // BlockByHash retrieves the block with the given hash.
BlockByHash(blockHash common.Hash) *types.Block BlockByHash(blockHash common.Hash, chainID uint64) *types.Block
OutputByRoot(root common.Hash) eth.Output OutputByRoot(root common.Hash, chainID uint64) eth.Output
// BlockDataByHash retrieves the block, including all data used to construct it. // BlockDataByHash retrieves the block, including all data used to construct it.
BlockDataByHash(agreedBlockHash, blockHash common.Hash, chainID uint64) *types.Block BlockDataByHash(agreedBlockHash, blockHash common.Hash, chainID uint64) *types.Block
...@@ -45,21 +45,27 @@ type Oracle interface { ...@@ -45,21 +45,27 @@ type Oracle interface {
// PreimageOracle implements Oracle using by interfacing with the pure preimage.Oracle // PreimageOracle implements Oracle using by interfacing with the pure preimage.Oracle
// to fetch pre-images to decode into the requested data. // to fetch pre-images to decode into the requested data.
type PreimageOracle struct { type PreimageOracle struct {
oracle preimage.Oracle oracle preimage.Oracle
hint preimage.Hinter hint preimage.Hinter
hintL2ChainIDs bool
} }
var _ Oracle = (*PreimageOracle)(nil) var _ Oracle = (*PreimageOracle)(nil)
func NewPreimageOracle(raw preimage.Oracle, hint preimage.Hinter) *PreimageOracle { func NewPreimageOracle(raw preimage.Oracle, hint preimage.Hinter, hintL2ChainIDs bool) *PreimageOracle {
return &PreimageOracle{ return &PreimageOracle{
oracle: raw, oracle: raw,
hint: hint, hint: hint,
hintL2ChainIDs: hintL2ChainIDs,
} }
} }
func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash) *types.Header { func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash, chainID uint64) *types.Header {
p.hint.Hint(BlockHeaderHint(blockHash)) if p.hintL2ChainIDs {
p.hint.Hint(BlockHeaderHint{Hash: blockHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyBlockHeaderHint(blockHash))
}
headerRlp := p.oracle.Get(preimage.Keccak256Key(blockHash)) headerRlp := p.oracle.Get(preimage.Keccak256Key(blockHash))
var header types.Header var header types.Header
if err := rlp.DecodeBytes(headerRlp, &header); err != nil { if err := rlp.DecodeBytes(headerRlp, &header); err != nil {
...@@ -68,15 +74,19 @@ func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash) *types.Header ...@@ -68,15 +74,19 @@ func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash) *types.Header
return &header return &header
} }
func (p *PreimageOracle) BlockByHash(blockHash common.Hash) *types.Block { func (p *PreimageOracle) BlockByHash(blockHash common.Hash, chainID uint64) *types.Block {
header := p.headerByBlockHash(blockHash) header := p.headerByBlockHash(blockHash, chainID)
txs := p.LoadTransactions(blockHash, header.TxHash) txs := p.LoadTransactions(blockHash, header.TxHash, chainID)
return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs}) return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs})
} }
func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.Hash) []*types.Transaction { func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.Hash, chainID uint64) []*types.Transaction {
p.hint.Hint(TransactionsHint(blockHash)) if p.hintL2ChainIDs {
p.hint.Hint(TransactionsHint{Hash: blockHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyTransactionsHint(blockHash))
}
opaqueTxs := mpt.ReadTrie(txHash, func(key common.Hash) []byte { opaqueTxs := mpt.ReadTrie(txHash, func(key common.Hash) []byte {
return p.oracle.Get(preimage.Keccak256Key(key)) return p.oracle.Get(preimage.Keccak256Key(key))
...@@ -89,18 +99,30 @@ func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.H ...@@ -89,18 +99,30 @@ func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.H
return txs return txs
} }
func (p *PreimageOracle) NodeByHash(nodeHash common.Hash) []byte { func (p *PreimageOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
p.hint.Hint(StateNodeHint(nodeHash)) if p.hintL2ChainIDs {
p.hint.Hint(StateNodeHint{Hash: nodeHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyStateNodeHint(nodeHash))
}
return p.oracle.Get(preimage.Keccak256Key(nodeHash)) return p.oracle.Get(preimage.Keccak256Key(nodeHash))
} }
func (p *PreimageOracle) CodeByHash(codeHash common.Hash) []byte { func (p *PreimageOracle) CodeByHash(codeHash common.Hash, chainID uint64) []byte {
p.hint.Hint(CodeHint(codeHash)) if p.hintL2ChainIDs {
p.hint.Hint(CodeHint{Hash: codeHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyCodeHint(codeHash))
}
return p.oracle.Get(preimage.Keccak256Key(codeHash)) return p.oracle.Get(preimage.Keccak256Key(codeHash))
} }
func (p *PreimageOracle) OutputByRoot(l2OutputRoot common.Hash) eth.Output { func (p *PreimageOracle) OutputByRoot(l2OutputRoot common.Hash, chainID uint64) eth.Output {
p.hint.Hint(L2OutputHint(l2OutputRoot)) if p.hintL2ChainIDs {
p.hint.Hint(L2OutputHint{Hash: l2OutputRoot, ChainID: chainID})
} else {
p.hint.Hint(LegacyL2OutputHint(l2OutputRoot))
}
data := p.oracle.Get(preimage.Keccak256Key(l2OutputRoot)) data := p.oracle.Get(preimage.Keccak256Key(l2OutputRoot))
output, err := eth.UnmarshalOutput(data) output, err := eth.UnmarshalOutput(data)
if err != nil { if err != nil {
...@@ -116,8 +138,8 @@ func (p *PreimageOracle) BlockDataByHash(agreedBlockHash, blockHash common.Hash, ...@@ -116,8 +138,8 @@ func (p *PreimageOracle) BlockDataByHash(agreedBlockHash, blockHash common.Hash,
ChainID: chainID, ChainID: chainID,
} }
p.hint.Hint(hint) p.hint.Hint(hint)
header := p.headerByBlockHash(blockHash) header := p.headerByBlockHash(blockHash, chainID)
txs := p.LoadTransactions(blockHash, header.TxHash) txs := p.LoadTransactions(blockHash, header.TxHash, chainID)
return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs}) return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs})
} }
......
...@@ -19,29 +19,27 @@ import ( ...@@ -19,29 +19,27 @@ import (
"github.com/ethereum-optimism/optimism/op-service/testutils" "github.com/ethereum-optimism/optimism/op-service/testutils"
) )
func mockPreimageOracle(t *testing.T) (po *PreimageOracle, hintsMock *mock.Mock, preimages map[common.Hash][]byte) { func mockPreimageOracle(t *testing.T, hintL2ChainIDs bool) (po *PreimageOracle, hintsMock *mock.Mock, preimages map[common.Hash][]byte) {
// Prepare the pre-images // Prepare the pre-images
preimages = make(map[common.Hash][]byte) preimages = make(map[common.Hash][]byte)
hintsMock = new(mock.Mock) hintsMock = new(mock.Mock)
po = &PreimageOracle{ rawOracle := preimage.OracleFn(func(key preimage.Key) []byte {
oracle: preimage.OracleFn(func(key preimage.Key) []byte { v, ok := preimages[key.PreimageKey()]
v, ok := preimages[key.PreimageKey()] require.True(t, ok, "preimage must exist")
require.True(t, ok, "preimage must exist") return v
return v })
}), hinter := preimage.HinterFn(func(v preimage.Hint) {
hint: preimage.HinterFn(func(v preimage.Hint) { hintsMock.MethodCalled("hint", v.Hint())
hintsMock.MethodCalled("hint", v.Hint()) })
}), po = NewPreimageOracle(rawOracle, hinter, hintL2ChainIDs)
} return
return po, hintsMock, preimages
} }
// testBlock tests that the given block can be passed through the preimage oracle. // testBlock tests that the given block can be passed through the preimage oracle.
func testBlock(t *testing.T, block *types.Block) { func testBlock(t *testing.T, block *types.Block, hintL2ChainIDs bool) {
po, hints, preimages := mockPreimageOracle(t) po, hints, preimages := mockPreimageOracle(t, hintL2ChainIDs)
hdrBytes, err := rlp.EncodeToBytes(block.Header()) hdrBytes, err := rlp.EncodeToBytes(block.Header())
require.NoError(t, err) require.NoError(t, err)
...@@ -54,12 +52,19 @@ func testBlock(t *testing.T, block *types.Block) { ...@@ -54,12 +52,19 @@ func testBlock(t *testing.T, block *types.Block) {
preimages[preimage.Keccak256Key(crypto.Keccak256Hash(p)).PreimageKey()] = p preimages[preimage.Keccak256Key(crypto.Keccak256Hash(p)).PreimageKey()] = p
} }
chainID := uint64(4924)
// Prepare a raw mock pre-image oracle that will serve the pre-image data and handle hints // Prepare a raw mock pre-image oracle that will serve the pre-image data and handle hints
// Check if blocks with txs work // Check if blocks with txs work
hints.On("hint", BlockHeaderHint(block.Hash()).Hint()).Once().Return() if hintL2ChainIDs {
hints.On("hint", TransactionsHint(block.Hash()).Hint()).Once().Return() hints.On("hint", BlockHeaderHint{Hash: block.Hash(), ChainID: chainID}.Hint()).Once().Return()
gotBlock := po.BlockByHash(block.Hash()) hints.On("hint", TransactionsHint{Hash: block.Hash(), ChainID: chainID}.Hint()).Once().Return()
} else {
hints.On("hint", LegacyBlockHeaderHint(block.Hash()).Hint()).Once().Return()
hints.On("hint", LegacyTransactionsHint(block.Hash()).Hint()).Once().Return()
}
gotBlock := po.BlockByHash(block.Hash(), chainID)
hints.AssertExpectations(t) hints.AssertExpectations(t)
require.Equal(t, gotBlock.Hash(), block.Hash()) require.Equal(t, gotBlock.Hash(), block.Hash())
...@@ -75,8 +80,12 @@ func TestPreimageOracleBlockByHash(t *testing.T) { ...@@ -75,8 +80,12 @@ func TestPreimageOracleBlockByHash(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
block, _ := testutils.RandomBlock(rng, 10) block, _ := testutils.RandomBlock(rng, 10)
t.Run(fmt.Sprintf("legacy_block_%d", i), func(t *testing.T) {
testBlock(t, block, false)
})
t.Run(fmt.Sprintf("block_%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("block_%d", i), func(t *testing.T) {
testBlock(t, block) testBlock(t, block, true)
}) })
} }
} }
...@@ -85,8 +94,24 @@ func TestPreimageOracleNodeByHash(t *testing.T) { ...@@ -85,8 +94,24 @@ func TestPreimageOracleNodeByHash(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_node_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
node := make([]byte, 123)
rng.Read(node)
h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", LegacyStateNodeHint(h).Hint()).Once().Return()
gotNode := po.NodeByHash(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "node matches")
})
t.Run(fmt.Sprintf("node_%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("node_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t) po, hints, preimages := mockPreimageOracle(t, true)
node := make([]byte, 123) node := make([]byte, 123)
rng.Read(node) rng.Read(node)
...@@ -94,8 +119,8 @@ func TestPreimageOracleNodeByHash(t *testing.T) { ...@@ -94,8 +119,8 @@ func TestPreimageOracleNodeByHash(t *testing.T) {
h := crypto.Keccak256Hash(node) h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", StateNodeHint(h).Hint()).Once().Return() hints.On("hint", StateNodeHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotNode := po.NodeByHash(h) gotNode := po.NodeByHash(h, chainID)
hints.AssertExpectations(t) hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "node matches") require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "node matches")
}) })
...@@ -106,8 +131,24 @@ func TestPreimageOracleCodeByHash(t *testing.T) { ...@@ -106,8 +131,24 @@ func TestPreimageOracleCodeByHash(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_code_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
node := make([]byte, 123)
rng.Read(node)
h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", LegacyCodeHint(h).Hint()).Once().Return()
gotNode := po.CodeByHash(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "code matches")
})
t.Run(fmt.Sprintf("code_%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("code_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t) po, hints, preimages := mockPreimageOracle(t, true)
node := make([]byte, 123) node := make([]byte, 123)
rng.Read(node) rng.Read(node)
...@@ -115,8 +156,8 @@ func TestPreimageOracleCodeByHash(t *testing.T) { ...@@ -115,8 +156,8 @@ func TestPreimageOracleCodeByHash(t *testing.T) {
h := crypto.Keccak256Hash(node) h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", CodeHint(h).Hint()).Once().Return() hints.On("hint", CodeHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotNode := po.CodeByHash(h) gotNode := po.CodeByHash(h, chainID)
hints.AssertExpectations(t) hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "code matches") require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "code matches")
}) })
...@@ -127,14 +168,26 @@ func TestPreimageOracleOutputByRoot(t *testing.T) { ...@@ -127,14 +168,26 @@ func TestPreimageOracleOutputByRoot(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_output_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
output := testutils.RandomOutputV0(rng)
h := common.Hash(eth.OutputRoot(output))
preimages[preimage.Keccak256Key(h).PreimageKey()] = output.Marshal()
hints.On("hint", LegacyL2OutputHint(h).Hint()).Once().Return()
gotOutput := po.OutputByRoot(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(output.Marshal()), hexutil.Bytes(gotOutput.Marshal()), "output matches")
})
t.Run(fmt.Sprintf("output_%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("output_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t) po, hints, preimages := mockPreimageOracle(t, true)
output := testutils.RandomOutputV0(rng) output := testutils.RandomOutputV0(rng)
h := common.Hash(eth.OutputRoot(output)) h := common.Hash(eth.OutputRoot(output))
preimages[preimage.Keccak256Key(h).PreimageKey()] = output.Marshal() preimages[preimage.Keccak256Key(h).PreimageKey()] = output.Marshal()
hints.On("hint", L2OutputHint(h).Hint()).Once().Return() hints.On("hint", L2OutputHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotOutput := po.OutputByRoot(h) gotOutput := po.OutputByRoot(h, chainID)
hints.AssertExpectations(t) hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(output.Marshal()), hexutil.Bytes(gotOutput.Marshal()), "output matches") require.Equal(t, hexutil.Bytes(output.Marshal()), hexutil.Bytes(gotOutput.Marshal()), "output matches")
}) })
......
...@@ -15,8 +15,8 @@ import ( ...@@ -15,8 +15,8 @@ import (
// Same as l2.StateOracle but need to use our own copy to avoid dependency loops // Same as l2.StateOracle but need to use our own copy to avoid dependency loops
type stateOracle interface { type stateOracle interface {
NodeByHash(nodeHash common.Hash) []byte NodeByHash(nodeHash common.Hash, chainID uint64) []byte
CodeByHash(codeHash common.Hash) []byte CodeByHash(codeHash common.Hash, chainID uint64) []byte
} }
type StubBlockOracle struct { type StubBlockOracle struct {
...@@ -56,7 +56,7 @@ func NewStubOracleWithBlocks(t *testing.T, chain []*gethTypes.Block, outputs []e ...@@ -56,7 +56,7 @@ func NewStubOracleWithBlocks(t *testing.T, chain []*gethTypes.Block, outputs []e
} }
} }
func (o StubBlockOracle) BlockByHash(blockHash common.Hash) *gethTypes.Block { func (o StubBlockOracle) BlockByHash(blockHash common.Hash, chainID uint64) *gethTypes.Block {
block, ok := o.Blocks[blockHash] block, ok := o.Blocks[blockHash]
if !ok { if !ok {
o.t.Fatalf("requested unknown block %s", blockHash) o.t.Fatalf("requested unknown block %s", blockHash)
...@@ -64,7 +64,7 @@ func (o StubBlockOracle) BlockByHash(blockHash common.Hash) *gethTypes.Block { ...@@ -64,7 +64,7 @@ func (o StubBlockOracle) BlockByHash(blockHash common.Hash) *gethTypes.Block {
return block return block
} }
func (o StubBlockOracle) OutputByRoot(root common.Hash) eth.Output { func (o StubBlockOracle) OutputByRoot(root common.Hash, chainID uint64) eth.Output {
output, ok := o.Outputs[root] output, ok := o.Outputs[root]
if !ok { if !ok {
o.t.Fatalf("requested unknown output root %s", root) o.t.Fatalf("requested unknown output root %s", root)
...@@ -100,7 +100,7 @@ func NewKvStateOracle(t *testing.T, db ethdb.KeyValueStore) *KvStateOracle { ...@@ -100,7 +100,7 @@ func NewKvStateOracle(t *testing.T, db ethdb.KeyValueStore) *KvStateOracle {
} }
} }
func (o *KvStateOracle) NodeByHash(nodeHash common.Hash) []byte { func (o *KvStateOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
val, err := o.Source.Get(nodeHash.Bytes()) val, err := o.Source.Get(nodeHash.Bytes())
if err != nil { if err != nil {
o.t.Fatalf("error retrieving node %v: %v", nodeHash, err) o.t.Fatalf("error retrieving node %v: %v", nodeHash, err)
...@@ -108,7 +108,7 @@ func (o *KvStateOracle) NodeByHash(nodeHash common.Hash) []byte { ...@@ -108,7 +108,7 @@ func (o *KvStateOracle) NodeByHash(nodeHash common.Hash) []byte {
return val return val
} }
func (o *KvStateOracle) CodeByHash(hash common.Hash) []byte { func (o *KvStateOracle) CodeByHash(hash common.Hash, chainID uint64) []byte {
return rawdb.ReadCode(o.Source, hash) return rawdb.ReadCode(o.Source, hash)
} }
...@@ -127,7 +127,7 @@ type StubStateOracle struct { ...@@ -127,7 +127,7 @@ type StubStateOracle struct {
Code map[common.Hash][]byte Code map[common.Hash][]byte
} }
func (o *StubStateOracle) NodeByHash(nodeHash common.Hash) []byte { func (o *StubStateOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
data, ok := o.Data[nodeHash] data, ok := o.Data[nodeHash]
if !ok { if !ok {
o.t.Fatalf("no value for node %v", nodeHash) o.t.Fatalf("no value for node %v", nodeHash)
...@@ -135,7 +135,7 @@ func (o *StubStateOracle) NodeByHash(nodeHash common.Hash) []byte { ...@@ -135,7 +135,7 @@ func (o *StubStateOracle) NodeByHash(nodeHash common.Hash) []byte {
return data return data
} }
func (o *StubStateOracle) CodeByHash(hash common.Hash) []byte { func (o *StubStateOracle) CodeByHash(hash common.Hash, chainID uint64) []byte {
data, ok := o.Code[hash] data, ok := o.Code[hash]
if !ok { if !ok {
o.t.Fatalf("no value for code %v", hash) o.t.Fatalf("no value for code %v", hash)
......
...@@ -45,7 +45,7 @@ func RunProgram(logger log.Logger, preimageOracle io.ReadWriter, preimageHinter ...@@ -45,7 +45,7 @@ func RunProgram(logger log.Logger, preimageOracle io.ReadWriter, preimageHinter
pClient := preimage.NewOracleClient(preimageOracle) pClient := preimage.NewOracleClient(preimageOracle)
hClient := preimage.NewHintWriter(preimageHinter) hClient := preimage.NewHintWriter(preimageHinter)
l1PreimageOracle := l1.NewCachingOracle(l1.NewPreimageOracle(pClient, hClient)) l1PreimageOracle := l1.NewCachingOracle(l1.NewPreimageOracle(pClient, hClient))
l2PreimageOracle := l2.NewCachingOracle(l2.NewPreimageOracle(pClient, hClient)) l2PreimageOracle := l2.NewCachingOracle(l2.NewPreimageOracle(pClient, hClient, cfg.InteropEnabled))
if cfg.InteropEnabled { if cfg.InteropEnabled {
bootInfo := boot.BootstrapInterop(pClient) bootInfo := boot.BootstrapInterop(pClient)
......
...@@ -32,3 +32,7 @@ func NewL2Client(client client.RPC, log log.Logger, metrics caching.Metrics, con ...@@ -32,3 +32,7 @@ func NewL2Client(client client.RPC, log log.Logger, metrics caching.Metrics, con
func (s *L2Client) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error) { func (s *L2Client) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error) {
return s.OutputV0AtBlock(ctx, blockRoot) return s.OutputV0AtBlock(ctx, blockRoot)
} }
func (s *L2Client) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
return s.OutputV0AtBlockNumber(ctx, blockNum)
}
...@@ -117,6 +117,14 @@ func (l *L2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth ...@@ -117,6 +117,14 @@ func (l *L2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth
return l.canonicalEthClient.OutputByRoot(ctx, blockRoot) return l.canonicalEthClient.OutputByRoot(ctx, blockRoot)
} }
// OutputByBlockNumber implements prefetcher.L2Source.
func (l *L2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
if l.ExperimentalEnabled() {
return l.experimentalClient.OutputByNumber(ctx, blockNum)
}
return l.canonicalEthClient.OutputByNumber(ctx, blockNum)
}
// ExecutionWitness implements prefetcher.L2Source. // ExecutionWitness implements prefetcher.L2Source.
func (l *L2Source) ExecutionWitness(ctx context.Context, blockNum uint64) (*eth.ExecutionWitness, error) { func (l *L2Source) ExecutionWitness(ctx context.Context, blockNum uint64) (*eth.ExecutionWitness, error) {
if !l.ExperimentalEnabled() { if !l.ExperimentalEnabled() {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"strings" "strings"
preimage "github.com/ethereum-optimism/optimism/op-preimage" preimage "github.com/ethereum-optimism/optimism/op-preimage"
clientTypes "github.com/ethereum-optimism/optimism/op-program/client/interop/types"
"github.com/ethereum-optimism/optimism/op-program/client/l1" "github.com/ethereum-optimism/optimism/op-program/client/l1"
"github.com/ethereum-optimism/optimism/op-program/client/l2" "github.com/ethereum-optimism/optimism/op-program/client/l2"
"github.com/ethereum-optimism/optimism/op-program/client/mpt" "github.com/ethereum-optimism/optimism/op-program/client/mpt"
...@@ -258,11 +259,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -258,11 +259,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
} }
return p.kvStore.Put(preimage.PrecompileKey(inputHash).PreimageKey(), result) return p.kvStore.Put(preimage.PrecompileKey(inputHash).PreimageKey(), result)
case l2.HintL2BlockHeader, l2.HintL2Transactions: case l2.HintL2BlockHeader, l2.HintL2Transactions:
if len(hintBytes) != 32 { hash, chainID, err := p.parseHashAndChainID("L2 header/tx", hintBytes)
return fmt.Errorf("invalid L2 header/tx hint: %x", hint) if err != nil {
return err
} }
hash := common.Hash(hintBytes) source, err := p.l2Sources.ForChainID(chainID)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
if err != nil { if err != nil {
return err return err
} }
...@@ -280,11 +281,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -280,11 +281,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
} }
return p.storeTransactions(txs) return p.storeTransactions(txs)
case l2.HintL2StateNode: case l2.HintL2StateNode:
if len(hintBytes) != 32 { hash, chainID, err := p.parseHashAndChainID("L2 state node", hintBytes)
return fmt.Errorf("invalid L2 state node hint: %x", hint) if err != nil {
return err
} }
hash := common.Hash(hintBytes) source, err := p.l2Sources.ForChainID(chainID)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
if err != nil { if err != nil {
return err return err
} }
...@@ -294,11 +295,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -294,11 +295,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
} }
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), node) return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), node)
case l2.HintL2Code: case l2.HintL2Code:
if len(hintBytes) != 32 { hash, chainID, err := p.parseHashAndChainID("L2 code", hintBytes)
return fmt.Errorf("invalid L2 code hint: %x", hint) if err != nil {
return err
} }
hash := common.Hash(hintBytes) source, err := p.l2Sources.ForChainID(chainID)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
if err != nil { if err != nil {
return err return err
} }
...@@ -308,23 +309,47 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -308,23 +309,47 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
} }
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), code) return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), code)
case l2.HintL2Output: case l2.HintL2Output:
if len(hintBytes) != 32 { requestedHash, chainID, err := p.parseHashAndChainID("L2 output", hintBytes)
return fmt.Errorf("invalid L2 output hint: %x", hint)
}
requestedHash := common.Hash(hintBytes)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
if err != nil { if err != nil {
return err return err
} }
output, err := source.OutputByRoot(ctx, p.l2Head) source, err := p.l2Sources.ForChainID(chainID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch L2 output root for block %s: %w", p.l2Head, err) return err
} }
hash := common.Hash(eth.OutputRoot(output)) if len(p.agreedPrestate) == 0 {
if requestedHash != hash { output, err := source.OutputByRoot(ctx, p.l2Head)
return fmt.Errorf("output root %v from block %v does not match requested root: %v", hash, p.l2Head, requestedHash) if err != nil {
return fmt.Errorf("failed to fetch L2 output root for block %s: %w", p.l2Head, err)
}
hash := common.Hash(eth.OutputRoot(output))
if requestedHash != hash {
return fmt.Errorf("output root %v from block %v does not match requested root: %v", hash, p.l2Head, requestedHash)
}
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), output.Marshal())
} else {
prestate, err := clientTypes.UnmarshalTransitionState(p.agreedPrestate)
if err != nil {
return fmt.Errorf("cannot fetch output root, invalid agreed prestate: %w", err)
}
superRoot, err := eth.UnmarshalSuperRoot(prestate.SuperRoot)
if err != nil {
return fmt.Errorf("cannot fetch output root, invalid super root in prestate: %w", err)
}
superV1, ok := superRoot.(*eth.SuperV1)
if !ok {
return fmt.Errorf("cannot fetch output root, unsupported super root version in prestate: %v", superRoot.Version())
}
blockNum, err := source.RollupConfig().TargetBlockNumber(superV1.Timestamp)
if err != nil {
return fmt.Errorf("cannot fetch output root, failed to calculate block number at timestamp %v: %w", superV1.Timestamp, err)
}
output, err := source.OutputByNumber(ctx, blockNum)
if err != nil {
return fmt.Errorf("failed to fetch L2 output root for block %v: %w", blockNum, err)
}
return p.kvStore.Put(preimage.Keccak256Key(eth.OutputRoot(output)).PreimageKey(), output.Marshal())
} }
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), output.Marshal())
case l2.HintL2BlockData: case l2.HintL2BlockData:
if p.executor == nil { if p.executor == nil {
return fmt.Errorf("this prefetcher does not support native block execution") return fmt.Errorf("this prefetcher does not support native block execution")
...@@ -353,6 +378,17 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -353,6 +378,17 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
return fmt.Errorf("unknown hint type: %v", hintType) return fmt.Errorf("unknown hint type: %v", hintType)
} }
func (p *Prefetcher) parseHashAndChainID(hintType string, hintBytes []byte) (common.Hash, uint64, error) {
switch len(hintBytes) {
case 32:
return common.Hash(hintBytes), p.defaultChainID, nil
case 40:
return common.Hash(hintBytes[0:32]), binary.BigEndian.Uint64(hintBytes[32:]), nil
default:
return common.Hash{}, 0, fmt.Errorf("invalid %s hint: %x", hintType, hintBytes)
}
}
type BlockDataKey [32]byte type BlockDataKey [32]byte
func (p BlockDataKey) Key() [32]byte { func (p BlockDataKey) Key() [32]byte {
......
...@@ -34,6 +34,7 @@ import ( ...@@ -34,6 +34,7 @@ import (
var ( var (
ecRecoverInput = common.FromHex("18c547e4f7b0f325ad1e56f57e26c745b09a3e503d86e00e5255ff7f715d3d1c000000000000000000000000000000000000000000000000000000000000001c73b1693892219d736caba55bdb67216e485557ea6b6af75f37096c9aa6a5a75feeb940b1d03b21e36b0e47e79769f095fe2ab855bd91e3a38756b7d75a9c4549") ecRecoverInput = common.FromHex("18c547e4f7b0f325ad1e56f57e26c745b09a3e503d86e00e5255ff7f715d3d1c000000000000000000000000000000000000000000000000000000000000001c73b1693892219d736caba55bdb67216e485557ea6b6af75f37096c9aa6a5a75feeb940b1d03b21e36b0e47e79769f095fe2ab855bd91e3a38756b7d75a9c4549")
kzgPointEvalInput = common.FromHex("01e798154708fe7789429634053cbf9f99b619f9f084048927333fce637f549b564c0a11a0f704f4fc3e8acfe0f8245f0ad1347b378fbf96e206da11a5d3630624d25032e67a7e6a4910df5834b8fe70e6bcfeeac0352434196bdf4b2485d5a18f59a8d2a1a625a17f3fea0fe5eb8c896db3764f3185481bc22f91b4aaffcca25f26936857bc3a7c2539ea8ec3a952b7873033e038326e87ed3e1276fd140253fa08e9fc25fb2d9a98527fc22a2c9612fbeafdad446cbc7bcdbdcd780af2c16a") kzgPointEvalInput = common.FromHex("01e798154708fe7789429634053cbf9f99b619f9f084048927333fce637f549b564c0a11a0f704f4fc3e8acfe0f8245f0ad1347b378fbf96e206da11a5d3630624d25032e67a7e6a4910df5834b8fe70e6bcfeeac0352434196bdf4b2485d5a18f59a8d2a1a625a17f3fea0fe5eb8c896db3764f3185481bc22f91b4aaffcca25f26936857bc3a7c2539ea8ec3a952b7873033e038326e87ed3e1276fd140253fa08e9fc25fb2d9a98527fc22a2c9612fbeafdad446cbc7bcdbdcd780af2c16a")
defaultChainID = uint64(14)
) )
func TestNoHint(t *testing.T) { func TestNoHint(t *testing.T) {
...@@ -415,6 +416,7 @@ func TestRestrictedPrecompileContracts(t *testing.T) { ...@@ -415,6 +416,7 @@ func TestRestrictedPrecompileContracts(t *testing.T) {
func TestFetchL2Block(t *testing.T) { func TestFetchL2Block(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
chainID := uint64(482948)
block, rcpts := testutils.RandomBlock(rng, 10) block, rcpts := testutils.RandomBlock(rng, 10)
hash := block.Hash() hash := block.Hash()
...@@ -422,19 +424,32 @@ func TestFetchL2Block(t *testing.T) { ...@@ -422,19 +424,32 @@ func TestFetchL2Block(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t) prefetcher, _, _, _, kv := createPrefetcher(t)
storeBlock(t, kv, block, rcpts) storeBlock(t, kv, block, rcpts)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.BlockByHash(hash) result := oracle.BlockByHash(hash, chainID)
require.EqualValues(t, block.Header(), result.Header()) require.EqualValues(t, block.Header(), result.Header())
assertTransactionsEqual(t, block.Transactions(), result.Transactions()) assertTransactionsEqual(t, block.Transactions(), result.Transactions())
}) })
t.Run("Unknown", func(t *testing.T) { t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cl, _ := createPrefetcher(t) prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil) l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil)
defer l2Cl.MockL2Client.AssertExpectations(t) defer l2Cl.MockL2Client.AssertExpectations(t)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.BlockByHash(hash) result := oracle.BlockByHash(hash, chainID)
require.EqualValues(t, block.Header(), result.Header())
assertTransactionsEqual(t, block.Transactions(), result.Transactions())
})
t.Run("WithChainID", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t, 5, 7, 10)
l2Cl := l2Cls.sources[7]
l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.BlockByHash(hash, 7)
require.EqualValues(t, block.Header(), result.Header()) require.EqualValues(t, block.Header(), result.Header())
assertTransactionsEqual(t, block.Transactions(), result.Transactions()) assertTransactionsEqual(t, block.Transactions(), result.Transactions())
}) })
...@@ -444,23 +459,36 @@ func TestFetchL2Transactions(t *testing.T) { ...@@ -444,23 +459,36 @@ func TestFetchL2Transactions(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
block, rcpts := testutils.RandomBlock(rng, 10) block, rcpts := testutils.RandomBlock(rng, 10)
hash := block.Hash() hash := block.Hash()
chainID := rng.Uint64()
t.Run("AlreadyKnown", func(t *testing.T) { t.Run("AlreadyKnown", func(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t) prefetcher, _, _, _, kv := createPrefetcher(t)
storeBlock(t, kv, block, rcpts) storeBlock(t, kv, block, rcpts)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.LoadTransactions(hash, block.TxHash()) result := oracle.LoadTransactions(hash, block.TxHash(), chainID)
assertTransactionsEqual(t, block.Transactions(), result) assertTransactionsEqual(t, block.Transactions(), result)
}) })
t.Run("Unknown", func(t *testing.T) { t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cl, _ := createPrefetcher(t) prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil) l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil)
defer l2Cl.MockL2Client.AssertExpectations(t) defer l2Cl.MockL2Client.AssertExpectations(t)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.LoadTransactions(hash, block.TxHash()) result := oracle.LoadTransactions(hash, block.TxHash(), chainID)
assertTransactionsEqual(t, block.Transactions(), result)
})
t.Run("WithChainID", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t, 5, 7, 10)
l2Cl := l2Cls.sources[7]
l2Cl.ExpectInfoAndTxsByHash(hash, eth.BlockToInfo(block), block.Transactions(), nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.LoadTransactions(hash, block.TxHash(), 7)
assertTransactionsEqual(t, block.Transactions(), result) assertTransactionsEqual(t, block.Transactions(), result)
}) })
} }
...@@ -470,23 +498,36 @@ func TestFetchL2Node(t *testing.T) { ...@@ -470,23 +498,36 @@ func TestFetchL2Node(t *testing.T) {
node := testutils.RandomData(rng, 30) node := testutils.RandomData(rng, 30)
hash := crypto.Keccak256Hash(node) hash := crypto.Keccak256Hash(node)
key := preimage.Keccak256Key(hash).PreimageKey() key := preimage.Keccak256Key(hash).PreimageKey()
chainID := rng.Uint64()
t.Run("AlreadyKnown", func(t *testing.T) { t.Run("AlreadyKnown", func(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t) prefetcher, _, _, _, kv := createPrefetcher(t)
require.NoError(t, kv.Put(key, node)) require.NoError(t, kv.Put(key, node))
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.NodeByHash(hash) result := oracle.NodeByHash(hash, chainID)
require.EqualValues(t, node, result) require.EqualValues(t, node, result)
}) })
t.Run("Unknown", func(t *testing.T) { t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cl, _ := createPrefetcher(t) prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectNodeByHash(hash, node, nil) l2Cl.ExpectNodeByHash(hash, node, nil)
defer l2Cl.MockDebugClient.AssertExpectations(t) defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.NodeByHash(hash, chainID)
require.EqualValues(t, node, result)
})
t.Run("WithChainID", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t, 5, 9, 99)
l2Cl := l2Cls.sources[9]
l2Cl.ExpectNodeByHash(hash, node, nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.NodeByHash(hash) result := oracle.NodeByHash(hash, 9)
require.EqualValues(t, node, result) require.EqualValues(t, node, result)
}) })
} }
...@@ -496,32 +537,96 @@ func TestFetchL2Code(t *testing.T) { ...@@ -496,32 +537,96 @@ func TestFetchL2Code(t *testing.T) {
code := testutils.RandomData(rng, 30) code := testutils.RandomData(rng, 30)
hash := crypto.Keccak256Hash(code) hash := crypto.Keccak256Hash(code)
key := preimage.Keccak256Key(hash).PreimageKey() key := preimage.Keccak256Key(hash).PreimageKey()
chainID := rng.Uint64()
t.Run("AlreadyKnown", func(t *testing.T) { t.Run("AlreadyKnown", func(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t) prefetcher, _, _, _, kv := createPrefetcher(t)
require.NoError(t, kv.Put(key, code)) require.NoError(t, kv.Put(key, code))
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.CodeByHash(hash) result := oracle.CodeByHash(hash, chainID)
require.EqualValues(t, code, result) require.EqualValues(t, code, result)
}) })
t.Run("Unknown", func(t *testing.T) { t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cl, _ := createPrefetcher(t) prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectCodeByHash(hash, code, nil) l2Cl.ExpectCodeByHash(hash, code, nil)
defer l2Cl.MockDebugClient.AssertExpectations(t) defer l2Cl.MockDebugClient.AssertExpectations(t)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.CodeByHash(hash) result := oracle.CodeByHash(hash, chainID)
require.EqualValues(t, code, result)
})
t.Run("WithChainID", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t, 8, 45, 98, 55)
l2Cl := l2Cls.sources[98]
l2Cl.ExpectCodeByHash(hash, code, nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.CodeByHash(hash, 98)
require.EqualValues(t, code, result) require.EqualValues(t, code, result)
}) })
} }
func TestFetchL2Output(t *testing.T) {
rng := rand.New(rand.NewSource(123))
output := testutils.RandomOutputV0(rng)
hash := common.Hash(eth.OutputRoot(output))
key := preimage.Keccak256Key(hash).PreimageKey()
t.Run("AlreadyKnown", func(t *testing.T) {
prefetcher, _, _, _, kv := createPrefetcher(t)
require.NoError(t, kv.Put(key, output.Marshal()))
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.OutputByRoot(hash, rng.Uint64())
require.EqualValues(t, output, result)
})
t.Run("Unknown", func(t *testing.T) {
prefetcher, _, _, l2Cls, _ := createPrefetcher(t)
l2Cl := l2Cls.sources[defaultChainID]
l2Cl.ExpectOutputByRoot(prefetcher.l2Head, output, nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.OutputByRoot(hash, rng.Uint64())
require.EqualValues(t, output, result)
})
t.Run("WithChainID", func(t *testing.T) {
chain6Output := testutils.RandomOutputV0(rng)
chain99Output := testutils.RandomOutputV0(rng)
timestamp := uint64(4567882)
superV1 := eth.SuperV1{
Timestamp: timestamp,
Chains: []eth.ChainIDAndOutput{
{ChainID: 6, Output: eth.OutputRoot(chain6Output)},
{ChainID: 78, Output: eth.OutputRoot(output)},
{ChainID: 99, Output: eth.OutputRoot(chain99Output)},
},
}
prefetcher, _, _, l2Cls, _ := createPrefetcherWithAgreedPrestate(t, superV1.Marshal(), 6, 78, 99)
l2Cl := l2Cls.sources[78]
blockNum, err := l2Cls.sources[78].RollupConfig().TargetBlockNumber(timestamp)
require.NoError(t, err)
l2Cl.ExpectOutputByNumber(blockNum, output, nil)
defer assertAllClientExpectations(t, l2Cls)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), true)
result := oracle.OutputByRoot(hash, 78)
require.EqualValues(t, output, result)
})
}
func TestFetchL2BlockData(t *testing.T) { func TestFetchL2BlockData(t *testing.T) {
chainID := uint64(14) chainID := uint64(14)
testBlockExec := func(t *testing.T, err error) { testBlockExec := func(t *testing.T, err error) {
prefetcher, _, _, l2Client, _ := createPrefetcher(t) prefetcher, _, _, l2Clients, _ := createPrefetcher(t)
l2Client := l2Clients.sources[defaultChainID]
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
block, _ := testutils.RandomBlock(rng, 10) block, _ := testutils.RandomBlock(rng, 10)
disputedBlockHash := common.Hash{0xab} disputedBlockHash := common.Hash{0xab}
...@@ -644,21 +749,23 @@ func TestRetryWhenNotAvailableAfterPrefetching(t *testing.T) { ...@@ -644,21 +749,23 @@ func TestRetryWhenNotAvailableAfterPrefetching(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
node := testutils.RandomData(rng, 30) node := testutils.RandomData(rng, 30)
hash := crypto.Keccak256Hash(node) hash := crypto.Keccak256Hash(node)
chainID := rng.Uint64()
_, l1Source, l1BlobSource, l2Cl, kv := createPrefetcher(t) _, l1Source, l1BlobSource, l2Cls, kv := createPrefetcher(t)
putsToIgnore := 2 putsToIgnore := 2
kv = &unreliableKvStore{KV: kv, putsToIgnore: putsToIgnore} kv = &unreliableKvStore{KV: kv, putsToIgnore: putsToIgnore}
sources := &l2Clients{sources: map[uint64]hostTypes.L2Source{6: l2Cl}} sources := &l2Clients{sources: map[uint64]*l2Client{6: l2Cls.sources[defaultChainID]}}
prefetcher := NewPrefetcher(testlog.Logger(t, log.LevelInfo), l1Source, l1BlobSource, 6, sources, kv, nil, common.Hash{}, nil) prefetcher := NewPrefetcher(testlog.Logger(t, log.LevelInfo), l1Source, l1BlobSource, 6, sources, kv, nil, common.Hash{}, nil)
l2Cl := sources.sources[6]
// Expect one call for each ignored put, plus one more request for when the put succeeds // Expect one call for each ignored put, plus one more request for when the put succeeds
for i := 0; i < putsToIgnore+1; i++ { for i := 0; i < putsToIgnore+1; i++ {
l2Cl.ExpectNodeByHash(hash, node, nil) l2Cl.ExpectNodeByHash(hash, node, nil)
} }
defer l2Cl.MockDebugClient.AssertExpectations(t) defer l2Cl.MockDebugClient.AssertExpectations(t)
oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher)) oracle := l2.NewPreimageOracle(asOracleFn(t, prefetcher), asHinter(t, prefetcher), false)
result := oracle.NodeByHash(hash) result := oracle.NodeByHash(hash, chainID)
require.EqualValues(t, node, result) require.EqualValues(t, node, result)
} }
...@@ -677,7 +784,7 @@ func (s *unreliableKvStore) Put(k common.Hash, v []byte) error { ...@@ -677,7 +784,7 @@ func (s *unreliableKvStore) Put(k common.Hash, v []byte) error {
} }
type l2Clients struct { type l2Clients struct {
sources map[uint64]hostTypes.L2Source sources map[uint64]*l2Client
} }
func (l *l2Clients) ForChainID(id uint64) (hostTypes.L2Source, error) { func (l *l2Clients) ForChainID(id uint64) (hostTypes.L2Source, error) {
...@@ -695,10 +802,11 @@ func (l *l2Clients) ForChainIDWithoutRetries(id uint64) (hostTypes.L2Source, err ...@@ -695,10 +802,11 @@ func (l *l2Clients) ForChainIDWithoutRetries(id uint64) (hostTypes.L2Source, err
type l2Client struct { type l2Client struct {
*testutils.MockL2Client *testutils.MockL2Client
*testutils.MockDebugClient *testutils.MockDebugClient
rollupCfg *rollup.Config
} }
func (m *l2Client) RollupConfig() *rollup.Config { func (m *l2Client) RollupConfig() *rollup.Config {
panic("implement me") return m.rollupCfg
} }
func (m *l2Client) ExperimentalEnabled() bool { func (m *l2Client) ExperimentalEnabled() bool {
...@@ -714,25 +822,47 @@ func (m *l2Client) ExpectOutputByRoot(blockRoot common.Hash, output eth.Output, ...@@ -714,25 +822,47 @@ func (m *l2Client) ExpectOutputByRoot(blockRoot common.Hash, output eth.Output,
m.Mock.On("OutputByRoot", blockRoot).Once().Return(output, &err) m.Mock.On("OutputByRoot", blockRoot).Once().Return(output, &err)
} }
func createPrefetcher(t *testing.T) (*Prefetcher, *testutils.MockL1Source, *testutils.MockBlobsFetcher, *l2Client, kvstore.KV) { func (m *l2Client) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
return createPrefetcherWithAgreedPrestate(t, nil) out := m.Mock.MethodCalled("OutputByNumber", blockNum)
return out[0].(eth.Output), *out[1].(*error)
}
func (m *l2Client) ExpectOutputByNumber(blockNum uint64, output eth.Output, err error) {
m.Mock.On("OutputByNumber", blockNum).Once().Return(output, &err)
} }
func createPrefetcherWithAgreedPrestate(t *testing.T, agreedPrestate []byte) (*Prefetcher, *testutils.MockL1Source, *testutils.MockBlobsFetcher, *l2Client, kvstore.KV) {
func createPrefetcher(t *testing.T, chainIDs ...uint64) (*Prefetcher, *testutils.MockL1Source, *testutils.MockBlobsFetcher, *l2Clients, kvstore.KV) {
return createPrefetcherWithAgreedPrestate(t, nil, chainIDs...)
}
func createPrefetcherWithAgreedPrestate(t *testing.T, agreedPrestate []byte, chainIDs ...uint64) (*Prefetcher, *testutils.MockL1Source, *testutils.MockBlobsFetcher, *l2Clients, kvstore.KV) {
logger := testlog.Logger(t, log.LevelDebug) logger := testlog.Logger(t, log.LevelDebug)
kv := kvstore.NewMemKV() kv := kvstore.NewMemKV()
l1Source := new(testutils.MockL1Source) l1Source := new(testutils.MockL1Source)
l1BlobSource := new(testutils.MockBlobsFetcher) l1BlobSource := new(testutils.MockBlobsFetcher)
l2Source := &l2Client{
MockL2Client: new(testutils.MockL2Client), // Provide a default chain if none specified.
MockDebugClient: new(testutils.MockDebugClient), if len(chainIDs) == 0 {
chainIDs = []uint64{defaultChainID}
} }
l2Sources := &l2Clients{
sources: map[uint64]hostTypes.L2Source{14: l2Source}, l2Sources := &l2Clients{sources: make(map[uint64]*l2Client)}
for i, chainID := range chainIDs {
l2Source := &l2Client{
rollupCfg: &rollup.Config{
// Make the block numbers for each chain differ at each timestamp
Genesis: rollup.Genesis{L2Time: 500 + uint64(2*i)},
BlockTime: 1,
},
MockL2Client: new(testutils.MockL2Client),
MockDebugClient: new(testutils.MockDebugClient),
}
l2Sources.sources[chainID] = l2Source
} }
prefetcher := NewPrefetcher(logger, l1Source, l1BlobSource, 14, l2Sources, kv, nil, common.Hash{0xdd}, agreedPrestate) prefetcher := NewPrefetcher(logger, l1Source, l1BlobSource, chainIDs[0], l2Sources, kv, nil, common.Hash{0xdd}, agreedPrestate)
return prefetcher, l1Source, l1BlobSource, l2Source, kv return prefetcher, l1Source, l1BlobSource, l2Sources, kv
} }
func storeBlock(t *testing.T, kv kvstore.KV, block *types.Block, receipts types.Receipts) { func storeBlock(t *testing.T, kv kvstore.KV, block *types.Block, receipts types.Receipts) {
...@@ -847,3 +977,10 @@ func (m *mockExecutor) RunProgram( ...@@ -847,3 +977,10 @@ func (m *mockExecutor) RunProgram(
m.chainID = chainID m.chainID = chainID
return nil return nil
} }
func assertAllClientExpectations(t *testing.T, l2Cls *l2Clients) {
for _, source := range l2Cls.sources {
source.Mock.AssertExpectations(t)
source.MockDebugClient.Mock.AssertExpectations(t)
}
}
...@@ -152,6 +152,17 @@ func (s *RetryingL2Source) OutputByRoot(ctx context.Context, blockRoot common.Ha ...@@ -152,6 +152,17 @@ func (s *RetryingL2Source) OutputByRoot(ctx context.Context, blockRoot common.Ha
}) })
} }
func (s *RetryingL2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
return retry.Do(ctx, maxAttempts, s.strategy, func() (eth.Output, error) {
o, err := s.source.OutputByNumber(ctx, blockNum)
if err != nil {
s.logger.Warn("Failed to fetch l2 output", "block", blockNum, "err", err)
return o, err
}
return o, nil
})
}
func NewRetryingL2Source(logger log.Logger, source hosttypes.L2Source) *RetryingL2Source { func NewRetryingL2Source(logger log.Logger, source hosttypes.L2Source) *RetryingL2Source {
return &RetryingL2Source{ return &RetryingL2Source{
logger: logger, logger: logger,
......
...@@ -226,7 +226,8 @@ func createL1BlobSource(t *testing.T) (*RetryingL1BlobSource, *testutils.MockBlo ...@@ -226,7 +226,8 @@ func createL1BlobSource(t *testing.T) (*RetryingL1BlobSource, *testutils.MockBlo
func TestRetryingL2Source(t *testing.T) { func TestRetryingL2Source(t *testing.T) {
ctx := context.Background() ctx := context.Background()
hash := common.Hash{0xab} hash := common.Hash{0xab}
info := &testutils.MockBlockInfo{InfoHash: hash} blockNum := uint64(14982)
info := &testutils.MockBlockInfo{InfoHash: hash, InfoNum: blockNum}
// The mock really doesn't like returning nil for a eth.BlockInfo so return a value we expect to be ignored instead // The mock really doesn't like returning nil for a eth.BlockInfo so return a value we expect to be ignored instead
wrongInfo := &testutils.MockBlockInfo{InfoHash: common.Hash{0x99}} wrongInfo := &testutils.MockBlockInfo{InfoHash: common.Hash{0x99}}
txs := types.Transactions{ txs := types.Transactions{
...@@ -325,6 +326,28 @@ func TestRetryingL2Source(t *testing.T) { ...@@ -325,6 +326,28 @@ func TestRetryingL2Source(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, output, actualOutput) require.Equal(t, output, actualOutput)
}) })
t.Run("OutputByNumber Success", func(t *testing.T) {
source, mock := createL2Source(t)
defer mock.AssertExpectations(t)
mock.ExpectOutputByNumber(blockNum, output, nil)
actualOutput, err := source.OutputByNumber(ctx, blockNum)
require.NoError(t, err)
require.Equal(t, output, actualOutput)
})
t.Run("OutputByNumber Error", func(t *testing.T) {
source, mock := createL2Source(t)
defer mock.AssertExpectations(t)
expectedErr := errors.New("boom")
mock.ExpectOutputByNumber(blockNum, wrongOutput, expectedErr)
mock.ExpectOutputByNumber(blockNum, output, nil)
actualOutput, err := source.OutputByNumber(ctx, blockNum)
require.NoError(t, err)
require.Equal(t, output, actualOutput)
})
} }
func createL2Source(t *testing.T) (*RetryingL2Source, *MockL2Source) { func createL2Source(t *testing.T) (*RetryingL2Source, *MockL2Source) {
...@@ -370,6 +393,11 @@ func (m *MockL2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash) ...@@ -370,6 +393,11 @@ func (m *MockL2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash)
return out[0].(eth.Output), *out[1].(*error) return out[0].(eth.Output), *out[1].(*error)
} }
func (m *MockL2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
out := m.Mock.MethodCalled("OutputByNumber", blockNum)
return out[0].(eth.Output), *out[1].(*error)
}
func (m *MockL2Source) ExpectInfoAndTxsByHash(blockHash common.Hash, info eth.BlockInfo, txs types.Transactions, err error) { func (m *MockL2Source) ExpectInfoAndTxsByHash(blockHash common.Hash, info eth.BlockInfo, txs types.Transactions, err error) {
m.Mock.On("InfoAndTxsByHash", blockHash).Once().Return(info, txs, &err) m.Mock.On("InfoAndTxsByHash", blockHash).Once().Return(info, txs, &err)
} }
...@@ -386,4 +414,8 @@ func (m *MockL2Source) ExpectOutputByRoot(blockHash common.Hash, output eth.Outp ...@@ -386,4 +414,8 @@ func (m *MockL2Source) ExpectOutputByRoot(blockHash common.Hash, output eth.Outp
m.Mock.On("OutputByRoot", blockHash).Once().Return(output, &err) m.Mock.On("OutputByRoot", blockHash).Once().Return(output, &err)
} }
func (m *MockL2Source) ExpectOutputByNumber(blockNum uint64, output eth.Output, err error) {
m.Mock.On("OutputByNumber", blockNum).Once().Return(output, &err)
}
var _ hosttypes.L2Source = (*MockL2Source)(nil) var _ hosttypes.L2Source = (*MockL2Source)(nil)
...@@ -24,6 +24,7 @@ type L2Source interface { ...@@ -24,6 +24,7 @@ type L2Source interface {
NodeByHash(ctx context.Context, hash common.Hash) ([]byte, error) NodeByHash(ctx context.Context, hash common.Hash) ([]byte, error)
CodeByHash(ctx context.Context, hash common.Hash) ([]byte, error) CodeByHash(ctx context.Context, hash common.Hash) ([]byte, error)
OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error)
OutputByNumber(ctx context.Context, blockNumber uint64) (eth.Output, error)
RollupConfig() *rollup.Config RollupConfig() *rollup.Config
ExperimentalEnabled() bool ExperimentalEnabled() bool
} }
......
...@@ -172,15 +172,26 @@ func (s *L2Client) SystemConfigByL2Hash(ctx context.Context, hash common.Hash) ( ...@@ -172,15 +172,26 @@ func (s *L2Client) SystemConfigByL2Hash(ctx context.Context, hash common.Hash) (
return cfg, nil return cfg, nil
} }
func (s *L2Client) OutputV0AtBlockNumber(ctx context.Context, blockNum uint64) (*eth.OutputV0, error) {
head, err := s.InfoByNumber(ctx, blockNum)
if err != nil {
return nil, fmt.Errorf("failed to get L2 block by hash: %w", err)
}
return s.outputV0(ctx, head)
}
func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) (*eth.OutputV0, error) { func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) (*eth.OutputV0, error) {
head, err := s.InfoByHash(ctx, blockHash) head, err := s.InfoByHash(ctx, blockHash)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get L2 block by hash: %w", err) return nil, fmt.Errorf("failed to get L2 block by hash: %w", err)
} }
if head == nil { return s.outputV0(ctx, head)
}
func (s *L2Client) outputV0(ctx context.Context, block eth.BlockInfo) (*eth.OutputV0, error) {
if block == nil {
return nil, ethereum.NotFound return nil, ethereum.NotFound
} }
blockHash := block.Hash()
proof, err := s.GetProof(ctx, predeploys.L2ToL1MessagePasserAddr, []common.Hash{}, blockHash.String()) proof, err := s.GetProof(ctx, predeploys.L2ToL1MessagePasserAddr, []common.Hash{}, blockHash.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get contract proof at block %s: %w", blockHash, err) return nil, fmt.Errorf("failed to get contract proof at block %s: %w", blockHash, err)
...@@ -189,10 +200,10 @@ func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) ( ...@@ -189,10 +200,10 @@ func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) (
return nil, fmt.Errorf("proof %w", ethereum.NotFound) return nil, fmt.Errorf("proof %w", ethereum.NotFound)
} }
// make sure that the proof (including storage hash) that we retrieved is correct by verifying it against the state-root // make sure that the proof (including storage hash) that we retrieved is correct by verifying it against the state-root
if err := proof.Verify(head.Root()); err != nil { if err := proof.Verify(block.Root()); err != nil {
return nil, fmt.Errorf("invalid withdrawal root hash, state root was %s: %w", head.Root(), err) return nil, fmt.Errorf("invalid withdrawal root hash, state root was %s: %w", block.Root(), err)
} }
stateRoot := head.Root() stateRoot := block.Root()
return &eth.OutputV0{ return &eth.OutputV0{
StateRoot: eth.Bytes32(stateRoot), StateRoot: eth.Bytes32(stateRoot),
MessagePasserStorageRoot: eth.Bytes32(proof.StorageHash), MessagePasserStorageRoot: eth.Bytes32(proof.StorageHash),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment