Commit ca583f7f authored by Adrian Sutton's avatar Adrian Sutton Committed by GitHub

op-program: Add chain ID to hints when using interop (#13735)

* op-program: Add L2 chain ID to hints when using interop

* op-program: Add L2 chain ID to OutputRoot hints when using interop

* Fix label in error message.
Co-authored-by: default avatarInphi <mlaw2501@gmail.com>

---------
Co-authored-by: default avatarInphi <mlaw2501@gmail.com>
parent 66a67e02
...@@ -319,7 +319,6 @@ func TestInteropFaultProofs(gt *testing.T) { ...@@ -319,7 +319,6 @@ func TestInteropFaultProofs(gt *testing.T) {
agreedClaim: step1Expected, agreedClaim: step1Expected,
disputedClaim: step2Expected, disputedClaim: step2Expected,
expectValid: true, expectValid: true,
skip: true,
}, },
{ {
name: "PaddingStep", name: "PaddingStep",
......
...@@ -39,42 +39,42 @@ func NewCachingOracle(oracle Oracle) *CachingOracle { ...@@ -39,42 +39,42 @@ func NewCachingOracle(oracle Oracle) *CachingOracle {
} }
} }
func (o *CachingOracle) NodeByHash(nodeHash common.Hash) []byte { func (o *CachingOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
node, ok := o.nodes.Get(nodeHash) node, ok := o.nodes.Get(nodeHash)
if ok { if ok {
return node return node
} }
node = o.oracle.NodeByHash(nodeHash) node = o.oracle.NodeByHash(nodeHash, chainID)
o.nodes.Add(nodeHash, node) o.nodes.Add(nodeHash, node)
return node return node
} }
func (o *CachingOracle) CodeByHash(codeHash common.Hash) []byte { func (o *CachingOracle) CodeByHash(codeHash common.Hash, chainID uint64) []byte {
code, ok := o.codes.Get(codeHash) code, ok := o.codes.Get(codeHash)
if ok { if ok {
return code return code
} }
code = o.oracle.CodeByHash(codeHash) code = o.oracle.CodeByHash(codeHash, chainID)
o.codes.Add(codeHash, code) o.codes.Add(codeHash, code)
return code return code
} }
func (o *CachingOracle) BlockByHash(blockHash common.Hash) *types.Block { func (o *CachingOracle) BlockByHash(blockHash common.Hash, chainID uint64) *types.Block {
block, ok := o.blocks.Get(blockHash) block, ok := o.blocks.Get(blockHash)
if ok { if ok {
return block return block
} }
block = o.oracle.BlockByHash(blockHash) block = o.oracle.BlockByHash(blockHash, chainID)
o.blocks.Add(blockHash, block) o.blocks.Add(blockHash, block)
return block return block
} }
func (o *CachingOracle) OutputByRoot(root common.Hash) eth.Output { func (o *CachingOracle) OutputByRoot(root common.Hash, chainID uint64) eth.Output {
output, ok := o.outputs.Get(root) output, ok := o.outputs.Get(root)
if ok { if ok {
return output return output
} }
output = o.oracle.OutputByRoot(root) output = o.oracle.OutputByRoot(root, chainID)
o.outputs.Add(root, output) o.outputs.Add(root, output)
return output return output
} }
......
...@@ -15,6 +15,7 @@ import ( ...@@ -15,6 +15,7 @@ import (
var _ Oracle = (*CachingOracle)(nil) var _ Oracle = (*CachingOracle)(nil)
func TestBlockByHash(t *testing.T) { func TestBlockByHash(t *testing.T) {
chainID := uint64(48294)
stub, _ := test.NewStubOracle(t) stub, _ := test.NewStubOracle(t)
oracle := NewCachingOracle(stub) oracle := NewCachingOracle(stub)
...@@ -23,12 +24,12 @@ func TestBlockByHash(t *testing.T) { ...@@ -23,12 +24,12 @@ func TestBlockByHash(t *testing.T) {
// Initial call retrieves from the stub // Initial call retrieves from the stub
stub.Blocks[block.Hash()] = block stub.Blocks[block.Hash()] = block
actual := oracle.BlockByHash(block.Hash()) actual := oracle.BlockByHash(block.Hash(), chainID)
require.Equal(t, block, actual) require.Equal(t, block, actual)
// Later calls should retrieve from cache // Later calls should retrieve from cache (even if chain ID is different)
delete(stub.Blocks, block.Hash()) delete(stub.Blocks, block.Hash())
actual = oracle.BlockByHash(block.Hash()) actual = oracle.BlockByHash(block.Hash(), 9982)
require.Equal(t, block, actual) require.Equal(t, block, actual)
} }
...@@ -41,12 +42,12 @@ func TestNodeByHash(t *testing.T) { ...@@ -41,12 +42,12 @@ func TestNodeByHash(t *testing.T) {
// Initial call retrieves from the stub // Initial call retrieves from the stub
stateStub.Data[hash] = node stateStub.Data[hash] = node
actual := oracle.NodeByHash(hash) actual := oracle.NodeByHash(hash, 1234)
require.Equal(t, node, actual) require.Equal(t, node, actual)
// Later calls should retrieve from cache // Later calls should retrieve from cache (even if chain ID is different)
delete(stateStub.Data, hash) delete(stateStub.Data, hash)
actual = oracle.NodeByHash(hash) actual = oracle.NodeByHash(hash, 997845)
require.Equal(t, node, actual) require.Equal(t, node, actual)
} }
...@@ -59,12 +60,12 @@ func TestCodeByHash(t *testing.T) { ...@@ -59,12 +60,12 @@ func TestCodeByHash(t *testing.T) {
// Initial call retrieves from the stub // Initial call retrieves from the stub
stateStub.Code[hash] = node stateStub.Code[hash] = node
actual := oracle.CodeByHash(hash) actual := oracle.CodeByHash(hash, 342)
require.Equal(t, node, actual) require.Equal(t, node, actual)
// Later calls should retrieve from cache // Later calls should retrieve from cache (even if the chain ID is different)
delete(stateStub.Code, hash) delete(stateStub.Code, hash)
actual = oracle.CodeByHash(hash) actual = oracle.CodeByHash(hash, 986776)
require.Equal(t, node, actual) require.Equal(t, node, actual)
} }
...@@ -78,11 +79,11 @@ func TestOutputByRoot(t *testing.T) { ...@@ -78,11 +79,11 @@ func TestOutputByRoot(t *testing.T) {
// Initial call retrieves from the stub // Initial call retrieves from the stub
root := common.Hash(eth.OutputRoot(output)) root := common.Hash(eth.OutputRoot(output))
stub.Outputs[root] = output stub.Outputs[root] = output
actual := oracle.OutputByRoot(root) actual := oracle.OutputByRoot(root, 59284)
require.Equal(t, output, actual) require.Equal(t, output, actual)
// Later calls should retrieve from cache // Later calls should retrieve from cache (even if the chain ID is different)
delete(stub.Outputs, root) delete(stub.Outputs, root)
actual = oracle.OutputByRoot(root) actual = oracle.OutputByRoot(root, 9193)
require.Equal(t, output, actual) require.Equal(t, output, actual)
} }
...@@ -16,14 +16,16 @@ var codePrefixedKeyLength = common.HashLength + len(rawdb.CodePrefix) ...@@ -16,14 +16,16 @@ var codePrefixedKeyLength = common.HashLength + len(rawdb.CodePrefix)
var ErrInvalidKeyLength = errors.New("pre-images must be identified by 32-byte hash keys") var ErrInvalidKeyLength = errors.New("pre-images must be identified by 32-byte hash keys")
type OracleKeyValueStore struct { type OracleKeyValueStore struct {
db ethdb.KeyValueStore db ethdb.KeyValueStore
oracle StateOracle oracle StateOracle
chainID uint64
} }
func NewOracleBackedDB(oracle StateOracle) *OracleKeyValueStore { func NewOracleBackedDB(oracle StateOracle, chainID uint64) *OracleKeyValueStore {
return &OracleKeyValueStore{ return &OracleKeyValueStore{
db: memorydb.New(), db: memorydb.New(),
oracle: oracle, oracle: oracle,
chainID: chainID,
} }
} }
...@@ -38,12 +40,12 @@ func (o *OracleKeyValueStore) Get(key []byte) ([]byte, error) { ...@@ -38,12 +40,12 @@ func (o *OracleKeyValueStore) Get(key []byte) ([]byte, error) {
if len(key) == codePrefixedKeyLength && bytes.HasPrefix(key, rawdb.CodePrefix) { if len(key) == codePrefixedKeyLength && bytes.HasPrefix(key, rawdb.CodePrefix) {
key = key[len(rawdb.CodePrefix):] key = key[len(rawdb.CodePrefix):]
return o.oracle.CodeByHash(*(*[common.HashLength]byte)(key)), nil return o.oracle.CodeByHash(*(*[common.HashLength]byte)(key), o.chainID), nil
} }
if len(key) != common.HashLength { if len(key) != common.HashLength {
return nil, ErrInvalidKeyLength return nil, ErrInvalidKeyLength
} }
return o.oracle.NodeByHash(*(*[common.HashLength]byte)(key)), nil return o.oracle.NodeByHash(*(*[common.HashLength]byte)(key), o.chainID), nil
} }
func (o *OracleKeyValueStore) NewBatch() ethdb.Batch { func (o *OracleKeyValueStore) NewBatch() ethdb.Batch {
......
...@@ -34,7 +34,7 @@ var _ ethdb.KeyValueStore = (*OracleKeyValueStore)(nil) ...@@ -34,7 +34,7 @@ var _ ethdb.KeyValueStore = (*OracleKeyValueStore)(nil)
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
t.Run("IncorrectLengthKey", func(t *testing.T) { t.Run("IncorrectLengthKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
val, err := db.Get([]byte{1, 2, 3}) val, err := db.Get([]byte{1, 2, 3})
require.ErrorIs(t, err, ErrInvalidKeyLength) require.ErrorIs(t, err, ErrInvalidKeyLength)
require.Nil(t, val) require.Nil(t, val)
...@@ -42,7 +42,7 @@ func TestGet(t *testing.T) { ...@@ -42,7 +42,7 @@ func TestGet(t *testing.T) {
t.Run("KeyWithCodePrefix", func(t *testing.T) { t.Run("KeyWithCodePrefix", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0x12345678") key := common.HexToHash("0x12345678")
prefixedKey := append(rawdb.CodePrefix, key.Bytes()...) prefixedKey := append(rawdb.CodePrefix, key.Bytes()...)
...@@ -56,7 +56,7 @@ func TestGet(t *testing.T) { ...@@ -56,7 +56,7 @@ func TestGet(t *testing.T) {
t.Run("NormalKeyThatHappensToStartWithCodePrefix", func(t *testing.T) { t.Run("NormalKeyThatHappensToStartWithCodePrefix", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
key := make([]byte, common.HashLength) key := make([]byte, common.HashLength)
copy(rawdb.CodePrefix, key) copy(rawdb.CodePrefix, key)
fmt.Println(key[0]) fmt.Println(key[0])
...@@ -73,7 +73,7 @@ func TestGet(t *testing.T) { ...@@ -73,7 +73,7 @@ func TestGet(t *testing.T) {
expected := []byte{2, 6, 3, 8} expected := []byte{2, 6, 3, 8}
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
oracle.Data[key] = expected oracle.Data[key] = expected
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
val, err := db.Get(key.Bytes()) val, err := db.Get(key.Bytes())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expected, val) require.Equal(t, expected, val)
...@@ -83,7 +83,7 @@ func TestGet(t *testing.T) { ...@@ -83,7 +83,7 @@ func TestGet(t *testing.T) {
func TestPut(t *testing.T) { func TestPut(t *testing.T) {
t.Run("NewKey", func(t *testing.T) { t.Run("NewKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0xAA4488") key := common.HexToHash("0xAA4488")
value := []byte{2, 6, 3, 8} value := []byte{2, 6, 3, 8}
err := db.Put(key.Bytes(), value) err := db.Put(key.Bytes(), value)
...@@ -95,7 +95,7 @@ func TestPut(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestPut(t *testing.T) {
}) })
t.Run("ReplaceKey", func(t *testing.T) { t.Run("ReplaceKey", func(t *testing.T) {
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle, 1234)
key := common.HexToHash("0xAA4488") key := common.HexToHash("0xAA4488")
value1 := []byte{2, 6, 3, 8} value1 := []byte{2, 6, 3, 8}
value2 := []byte{1, 2, 3} value2 := []byte{1, 2, 3}
...@@ -117,13 +117,13 @@ func TestSupportsStateDBOperations(t *testing.T) { ...@@ -117,13 +117,13 @@ func TestSupportsStateDBOperations(t *testing.T) {
genesisBlock := l2Genesis.MustCommit(realDb, trieDB) genesisBlock := l2Genesis.MustCommit(realDb, trieDB)
loader := test.NewKvStateOracle(t, realDb) loader := test.NewKvStateOracle(t, realDb)
assertStateDataAvailable(t, NewOracleBackedDB(loader), l2Genesis, genesisBlock) assertStateDataAvailable(t, NewOracleBackedDB(loader, 1234), l2Genesis, genesisBlock)
} }
func TestUpdateState(t *testing.T) { func TestUpdateState(t *testing.T) {
l2Genesis := createGenesis() l2Genesis := createGenesis()
oracle := test.NewStubStateOracle(t) oracle := test.NewStubStateOracle(t)
db := rawdb.NewDatabase(NewOracleBackedDB(oracle)) db := rawdb.NewDatabase(NewOracleBackedDB(oracle, 1234))
trieDB := triedb.NewDatabase(db, &triedb.Config{HashDB: hashdb.Defaults}) trieDB := triedb.NewDatabase(db, &triedb.Config{HashDB: hashdb.Defaults})
genesisBlock := l2Genesis.MustCommit(db, trieDB) genesisBlock := l2Genesis.MustCommit(db, trieDB)
......
...@@ -45,12 +45,13 @@ type OracleBackedL2Chain struct { ...@@ -45,12 +45,13 @@ type OracleBackedL2Chain struct {
var _ engineapi.CachingEngineBackend = (*OracleBackedL2Chain)(nil) var _ engineapi.CachingEngineBackend = (*OracleBackedL2Chain)(nil)
func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle engineapi.PrecompileOracle, chainCfg *params.ChainConfig, l2OutputRoot common.Hash) (*OracleBackedL2Chain, error) { func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle engineapi.PrecompileOracle, chainCfg *params.ChainConfig, l2OutputRoot common.Hash) (*OracleBackedL2Chain, error) {
output := oracle.OutputByRoot(l2OutputRoot) chainID := chainCfg.ChainID.Uint64()
output := oracle.OutputByRoot(l2OutputRoot, chainID)
outputV0, ok := output.(*eth.OutputV0) outputV0, ok := output.(*eth.OutputV0)
if !ok { if !ok {
return nil, fmt.Errorf("unsupported L2 output version: %d", output.Version()) return nil, fmt.Errorf("unsupported L2 output version: %d", output.Version())
} }
head := oracle.BlockByHash(outputV0.BlockHash) head := oracle.BlockByHash(outputV0.BlockHash, chainID)
logger.Info("Loaded L2 head", "hash", head.Hash(), "number", head.Number()) logger.Info("Loaded L2 head", "hash", head.Hash(), "number", head.Number())
return &OracleBackedL2Chain{ return &OracleBackedL2Chain{
log: logger, log: logger,
...@@ -69,7 +70,7 @@ func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle e ...@@ -69,7 +70,7 @@ func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, precompileOracle e
finalized: head.Header(), finalized: head.Header(),
oracleHead: head.Header(), oracleHead: head.Header(),
blocks: make(map[common.Hash]*types.Block), blocks: make(map[common.Hash]*types.Block),
db: NewOracleBackedDB(oracle), db: NewOracleBackedDB(oracle, chainID),
vmCfg: vm.Config{ vmCfg: vm.Config{
PrecompileOverrides: engineapi.CreatePrecompileOverrides(precompileOracle), PrecompileOverrides: engineapi.CreatePrecompileOverrides(precompileOracle),
}, },
...@@ -122,7 +123,7 @@ func (o *OracleBackedL2Chain) GetBlockByHash(hash common.Hash) *types.Block { ...@@ -122,7 +123,7 @@ func (o *OracleBackedL2Chain) GetBlockByHash(hash common.Hash) *types.Block {
return block return block
} }
// Retrieve from the oracle // Retrieve from the oracle
return o.oracle.BlockByHash(hash) return o.oracle.BlockByHash(hash, o.chainCfg.ChainID.Uint64())
} }
func (o *OracleBackedL2Chain) GetBlock(hash common.Hash, number uint64) *types.Block { func (o *OracleBackedL2Chain) GetBlock(hash common.Hash, number uint64) *types.Block {
......
...@@ -378,7 +378,7 @@ func createBlock(t *testing.T, chain *OracleBackedL2Chain, opts ...blockCreateOp ...@@ -378,7 +378,7 @@ func createBlock(t *testing.T, chain *OracleBackedL2Chain, opts ...blockCreateOp
require.NoError(t, err) require.NoError(t, err)
nonce := parentDB.GetNonce(fundedAddress) nonce := parentDB.GetNonce(fundedAddress)
config := chain.Config() config := chain.Config()
db := rawdb.NewDatabase(NewOracleBackedDB(chain.oracle)) db := rawdb.NewDatabase(NewOracleBackedDB(chain.oracle, config.ChainID.Uint64()))
blocks, _ := core.GenerateChain(config, parent, chain.Engine(), db, 1, func(i int, gen *core.BlockGen) { blocks, _ := core.GenerateChain(config, parent, chain.Engine(), db, 1, func(i int, gen *core.BlockGen) {
rawTx := &types.DynamicFeeTx{ rawTx := &types.DynamicFeeTx{
ChainID: config.ChainID, ChainID: config.ChainID,
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
preimage "github.com/ethereum-optimism/optimism/op-preimage" preimage "github.com/ethereum-optimism/optimism/op-preimage"
) )
...@@ -19,43 +20,95 @@ const ( ...@@ -19,43 +20,95 @@ const (
HintAgreedPrestate = "agreed-pre-state" HintAgreedPrestate = "agreed-pre-state"
) )
type BlockHeaderHint common.Hash type LegacyBlockHeaderHint common.Hash
var _ preimage.Hint = LegacyBlockHeaderHint{}
func (l LegacyBlockHeaderHint) Hint() string {
return HintL2BlockHeader + " " + (common.Hash)(l).String()
}
type HashAndChainID struct {
Hash common.Hash
ChainID uint64
}
func (h HashAndChainID) Marshal() []byte {
d := make([]byte, 32+8)
copy(d[:32], h.Hash[:])
binary.BigEndian.PutUint64(d[32:], h.ChainID)
return d
}
type BlockHeaderHint HashAndChainID
var _ preimage.Hint = BlockHeaderHint{} var _ preimage.Hint = BlockHeaderHint{}
func (l BlockHeaderHint) Hint() string { func (l BlockHeaderHint) Hint() string {
return HintL2BlockHeader + " " + (common.Hash)(l).String() return HintL2BlockHeader + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyTransactionsHint common.Hash
var _ preimage.Hint = LegacyTransactionsHint{}
func (l LegacyTransactionsHint) Hint() string {
return HintL2Transactions + " " + (common.Hash)(l).String()
} }
type TransactionsHint common.Hash type TransactionsHint HashAndChainID
var _ preimage.Hint = TransactionsHint{} var _ preimage.Hint = TransactionsHint{}
func (l TransactionsHint) Hint() string { func (l TransactionsHint) Hint() string {
return HintL2Transactions + " " + (common.Hash)(l).String() return HintL2Transactions + " " + hexutil.Encode(HashAndChainID(l).Marshal())
} }
type CodeHint common.Hash type CodeHint HashAndChainID
var _ preimage.Hint = CodeHint{} var _ preimage.Hint = CodeHint{}
func (l CodeHint) Hint() string { func (l CodeHint) Hint() string {
return HintL2Code + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyCodeHint common.Hash
var _ preimage.Hint = LegacyCodeHint{}
func (l LegacyCodeHint) Hint() string {
return HintL2Code + " " + (common.Hash)(l).String() return HintL2Code + " " + (common.Hash)(l).String()
} }
type StateNodeHint common.Hash type StateNodeHint HashAndChainID
var _ preimage.Hint = StateNodeHint{} var _ preimage.Hint = StateNodeHint{}
func (l StateNodeHint) Hint() string { func (l StateNodeHint) Hint() string {
return HintL2StateNode + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyStateNodeHint common.Hash
var _ preimage.Hint = LegacyStateNodeHint{}
func (l LegacyStateNodeHint) Hint() string {
return HintL2StateNode + " " + (common.Hash)(l).String() return HintL2StateNode + " " + (common.Hash)(l).String()
} }
type L2OutputHint common.Hash type L2OutputHint HashAndChainID
var _ preimage.Hint = L2OutputHint{} var _ preimage.Hint = L2OutputHint{}
func (l L2OutputHint) Hint() string { func (l L2OutputHint) Hint() string {
return HintL2Output + " " + hexutil.Encode(HashAndChainID(l).Marshal())
}
type LegacyL2OutputHint common.Hash
var _ preimage.Hint = LegacyL2OutputHint{}
func (l LegacyL2OutputHint) Hint() string {
return HintL2Output + " " + (common.Hash)(l).String() return HintL2Output + " " + (common.Hash)(l).String()
} }
......
...@@ -19,11 +19,11 @@ type StateOracle interface { ...@@ -19,11 +19,11 @@ type StateOracle interface {
// NodeByHash retrieves the merkle-patricia trie node pre-image for a given hash. // NodeByHash retrieves the merkle-patricia trie node pre-image for a given hash.
// Trie nodes may be from the world state trie or any account storage trie. // Trie nodes may be from the world state trie or any account storage trie.
// Contract code is not stored as part of the trie and must be retrieved via CodeByHash // Contract code is not stored as part of the trie and must be retrieved via CodeByHash
NodeByHash(nodeHash common.Hash) []byte NodeByHash(nodeHash common.Hash, chainID uint64) []byte
// CodeByHash retrieves the contract code pre-image for a given hash. // CodeByHash retrieves the contract code pre-image for a given hash.
// codeHash should be retrieved from the world state account for a contract. // codeHash should be retrieved from the world state account for a contract.
CodeByHash(codeHash common.Hash) []byte CodeByHash(codeHash common.Hash, chainID uint64) []byte
} }
// Oracle defines the high-level API used to retrieve L2 data. // Oracle defines the high-level API used to retrieve L2 data.
...@@ -32,9 +32,9 @@ type Oracle interface { ...@@ -32,9 +32,9 @@ type Oracle interface {
StateOracle StateOracle
// BlockByHash retrieves the block with the given hash. // BlockByHash retrieves the block with the given hash.
BlockByHash(blockHash common.Hash) *types.Block BlockByHash(blockHash common.Hash, chainID uint64) *types.Block
OutputByRoot(root common.Hash) eth.Output OutputByRoot(root common.Hash, chainID uint64) eth.Output
// BlockDataByHash retrieves the block, including all data used to construct it. // BlockDataByHash retrieves the block, including all data used to construct it.
BlockDataByHash(agreedBlockHash, blockHash common.Hash, chainID uint64) *types.Block BlockDataByHash(agreedBlockHash, blockHash common.Hash, chainID uint64) *types.Block
...@@ -45,21 +45,27 @@ type Oracle interface { ...@@ -45,21 +45,27 @@ type Oracle interface {
// PreimageOracle implements Oracle using by interfacing with the pure preimage.Oracle // PreimageOracle implements Oracle using by interfacing with the pure preimage.Oracle
// to fetch pre-images to decode into the requested data. // to fetch pre-images to decode into the requested data.
type PreimageOracle struct { type PreimageOracle struct {
oracle preimage.Oracle oracle preimage.Oracle
hint preimage.Hinter hint preimage.Hinter
hintL2ChainIDs bool
} }
var _ Oracle = (*PreimageOracle)(nil) var _ Oracle = (*PreimageOracle)(nil)
func NewPreimageOracle(raw preimage.Oracle, hint preimage.Hinter) *PreimageOracle { func NewPreimageOracle(raw preimage.Oracle, hint preimage.Hinter, hintL2ChainIDs bool) *PreimageOracle {
return &PreimageOracle{ return &PreimageOracle{
oracle: raw, oracle: raw,
hint: hint, hint: hint,
hintL2ChainIDs: hintL2ChainIDs,
} }
} }
func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash) *types.Header { func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash, chainID uint64) *types.Header {
p.hint.Hint(BlockHeaderHint(blockHash)) if p.hintL2ChainIDs {
p.hint.Hint(BlockHeaderHint{Hash: blockHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyBlockHeaderHint(blockHash))
}
headerRlp := p.oracle.Get(preimage.Keccak256Key(blockHash)) headerRlp := p.oracle.Get(preimage.Keccak256Key(blockHash))
var header types.Header var header types.Header
if err := rlp.DecodeBytes(headerRlp, &header); err != nil { if err := rlp.DecodeBytes(headerRlp, &header); err != nil {
...@@ -68,15 +74,19 @@ func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash) *types.Header ...@@ -68,15 +74,19 @@ func (p *PreimageOracle) headerByBlockHash(blockHash common.Hash) *types.Header
return &header return &header
} }
func (p *PreimageOracle) BlockByHash(blockHash common.Hash) *types.Block { func (p *PreimageOracle) BlockByHash(blockHash common.Hash, chainID uint64) *types.Block {
header := p.headerByBlockHash(blockHash) header := p.headerByBlockHash(blockHash, chainID)
txs := p.LoadTransactions(blockHash, header.TxHash) txs := p.LoadTransactions(blockHash, header.TxHash, chainID)
return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs}) return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs})
} }
func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.Hash) []*types.Transaction { func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.Hash, chainID uint64) []*types.Transaction {
p.hint.Hint(TransactionsHint(blockHash)) if p.hintL2ChainIDs {
p.hint.Hint(TransactionsHint{Hash: blockHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyTransactionsHint(blockHash))
}
opaqueTxs := mpt.ReadTrie(txHash, func(key common.Hash) []byte { opaqueTxs := mpt.ReadTrie(txHash, func(key common.Hash) []byte {
return p.oracle.Get(preimage.Keccak256Key(key)) return p.oracle.Get(preimage.Keccak256Key(key))
...@@ -89,18 +99,30 @@ func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.H ...@@ -89,18 +99,30 @@ func (p *PreimageOracle) LoadTransactions(blockHash common.Hash, txHash common.H
return txs return txs
} }
func (p *PreimageOracle) NodeByHash(nodeHash common.Hash) []byte { func (p *PreimageOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
p.hint.Hint(StateNodeHint(nodeHash)) if p.hintL2ChainIDs {
p.hint.Hint(StateNodeHint{Hash: nodeHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyStateNodeHint(nodeHash))
}
return p.oracle.Get(preimage.Keccak256Key(nodeHash)) return p.oracle.Get(preimage.Keccak256Key(nodeHash))
} }
func (p *PreimageOracle) CodeByHash(codeHash common.Hash) []byte { func (p *PreimageOracle) CodeByHash(codeHash common.Hash, chainID uint64) []byte {
p.hint.Hint(CodeHint(codeHash)) if p.hintL2ChainIDs {
p.hint.Hint(CodeHint{Hash: codeHash, ChainID: chainID})
} else {
p.hint.Hint(LegacyCodeHint(codeHash))
}
return p.oracle.Get(preimage.Keccak256Key(codeHash)) return p.oracle.Get(preimage.Keccak256Key(codeHash))
} }
func (p *PreimageOracle) OutputByRoot(l2OutputRoot common.Hash) eth.Output { func (p *PreimageOracle) OutputByRoot(l2OutputRoot common.Hash, chainID uint64) eth.Output {
p.hint.Hint(L2OutputHint(l2OutputRoot)) if p.hintL2ChainIDs {
p.hint.Hint(L2OutputHint{Hash: l2OutputRoot, ChainID: chainID})
} else {
p.hint.Hint(LegacyL2OutputHint(l2OutputRoot))
}
data := p.oracle.Get(preimage.Keccak256Key(l2OutputRoot)) data := p.oracle.Get(preimage.Keccak256Key(l2OutputRoot))
output, err := eth.UnmarshalOutput(data) output, err := eth.UnmarshalOutput(data)
if err != nil { if err != nil {
...@@ -116,8 +138,8 @@ func (p *PreimageOracle) BlockDataByHash(agreedBlockHash, blockHash common.Hash, ...@@ -116,8 +138,8 @@ func (p *PreimageOracle) BlockDataByHash(agreedBlockHash, blockHash common.Hash,
ChainID: chainID, ChainID: chainID,
} }
p.hint.Hint(hint) p.hint.Hint(hint)
header := p.headerByBlockHash(blockHash) header := p.headerByBlockHash(blockHash, chainID)
txs := p.LoadTransactions(blockHash, header.TxHash) txs := p.LoadTransactions(blockHash, header.TxHash, chainID)
return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs}) return types.NewBlockWithHeader(header).WithBody(types.Body{Transactions: txs})
} }
......
...@@ -19,29 +19,27 @@ import ( ...@@ -19,29 +19,27 @@ import (
"github.com/ethereum-optimism/optimism/op-service/testutils" "github.com/ethereum-optimism/optimism/op-service/testutils"
) )
func mockPreimageOracle(t *testing.T) (po *PreimageOracle, hintsMock *mock.Mock, preimages map[common.Hash][]byte) { func mockPreimageOracle(t *testing.T, hintL2ChainIDs bool) (po *PreimageOracle, hintsMock *mock.Mock, preimages map[common.Hash][]byte) {
// Prepare the pre-images // Prepare the pre-images
preimages = make(map[common.Hash][]byte) preimages = make(map[common.Hash][]byte)
hintsMock = new(mock.Mock) hintsMock = new(mock.Mock)
po = &PreimageOracle{ rawOracle := preimage.OracleFn(func(key preimage.Key) []byte {
oracle: preimage.OracleFn(func(key preimage.Key) []byte { v, ok := preimages[key.PreimageKey()]
v, ok := preimages[key.PreimageKey()] require.True(t, ok, "preimage must exist")
require.True(t, ok, "preimage must exist") return v
return v })
}), hinter := preimage.HinterFn(func(v preimage.Hint) {
hint: preimage.HinterFn(func(v preimage.Hint) { hintsMock.MethodCalled("hint", v.Hint())
hintsMock.MethodCalled("hint", v.Hint()) })
}), po = NewPreimageOracle(rawOracle, hinter, hintL2ChainIDs)
} return
return po, hintsMock, preimages
} }
// testBlock tests that the given block can be passed through the preimage oracle. // testBlock tests that the given block can be passed through the preimage oracle.
func testBlock(t *testing.T, block *types.Block) { func testBlock(t *testing.T, block *types.Block, hintL2ChainIDs bool) {
po, hints, preimages := mockPreimageOracle(t) po, hints, preimages := mockPreimageOracle(t, hintL2ChainIDs)
hdrBytes, err := rlp.EncodeToBytes(block.Header()) hdrBytes, err := rlp.EncodeToBytes(block.Header())
require.NoError(t, err) require.NoError(t, err)
...@@ -54,12 +52,19 @@ func testBlock(t *testing.T, block *types.Block) { ...@@ -54,12 +52,19 @@ func testBlock(t *testing.T, block *types.Block) {
preimages[preimage.Keccak256Key(crypto.Keccak256Hash(p)).PreimageKey()] = p preimages[preimage.Keccak256Key(crypto.Keccak256Hash(p)).PreimageKey()] = p
} }
chainID := uint64(4924)
// Prepare a raw mock pre-image oracle that will serve the pre-image data and handle hints // Prepare a raw mock pre-image oracle that will serve the pre-image data and handle hints
// Check if blocks with txs work // Check if blocks with txs work
hints.On("hint", BlockHeaderHint(block.Hash()).Hint()).Once().Return() if hintL2ChainIDs {
hints.On("hint", TransactionsHint(block.Hash()).Hint()).Once().Return() hints.On("hint", BlockHeaderHint{Hash: block.Hash(), ChainID: chainID}.Hint()).Once().Return()
gotBlock := po.BlockByHash(block.Hash()) hints.On("hint", TransactionsHint{Hash: block.Hash(), ChainID: chainID}.Hint()).Once().Return()
} else {
hints.On("hint", LegacyBlockHeaderHint(block.Hash()).Hint()).Once().Return()
hints.On("hint", LegacyTransactionsHint(block.Hash()).Hint()).Once().Return()
}
gotBlock := po.BlockByHash(block.Hash(), chainID)
hints.AssertExpectations(t) hints.AssertExpectations(t)
require.Equal(t, gotBlock.Hash(), block.Hash()) require.Equal(t, gotBlock.Hash(), block.Hash())
...@@ -75,8 +80,12 @@ func TestPreimageOracleBlockByHash(t *testing.T) { ...@@ -75,8 +80,12 @@ func TestPreimageOracleBlockByHash(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
block, _ := testutils.RandomBlock(rng, 10) block, _ := testutils.RandomBlock(rng, 10)
t.Run(fmt.Sprintf("legacy_block_%d", i), func(t *testing.T) {
testBlock(t, block, false)
})
t.Run(fmt.Sprintf("block_%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("block_%d", i), func(t *testing.T) {
testBlock(t, block) testBlock(t, block, true)
}) })
} }
} }
...@@ -85,8 +94,24 @@ func TestPreimageOracleNodeByHash(t *testing.T) { ...@@ -85,8 +94,24 @@ func TestPreimageOracleNodeByHash(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_node_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
node := make([]byte, 123)
rng.Read(node)
h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", LegacyStateNodeHint(h).Hint()).Once().Return()
gotNode := po.NodeByHash(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "node matches")
})
t.Run(fmt.Sprintf("node_%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("node_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t) po, hints, preimages := mockPreimageOracle(t, true)
node := make([]byte, 123) node := make([]byte, 123)
rng.Read(node) rng.Read(node)
...@@ -94,8 +119,8 @@ func TestPreimageOracleNodeByHash(t *testing.T) { ...@@ -94,8 +119,8 @@ func TestPreimageOracleNodeByHash(t *testing.T) {
h := crypto.Keccak256Hash(node) h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", StateNodeHint(h).Hint()).Once().Return() hints.On("hint", StateNodeHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotNode := po.NodeByHash(h) gotNode := po.NodeByHash(h, chainID)
hints.AssertExpectations(t) hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "node matches") require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "node matches")
}) })
...@@ -106,8 +131,24 @@ func TestPreimageOracleCodeByHash(t *testing.T) { ...@@ -106,8 +131,24 @@ func TestPreimageOracleCodeByHash(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_code_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
node := make([]byte, 123)
rng.Read(node)
h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", LegacyCodeHint(h).Hint()).Once().Return()
gotNode := po.CodeByHash(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "code matches")
})
t.Run(fmt.Sprintf("code_%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("code_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t) po, hints, preimages := mockPreimageOracle(t, true)
node := make([]byte, 123) node := make([]byte, 123)
rng.Read(node) rng.Read(node)
...@@ -115,8 +156,8 @@ func TestPreimageOracleCodeByHash(t *testing.T) { ...@@ -115,8 +156,8 @@ func TestPreimageOracleCodeByHash(t *testing.T) {
h := crypto.Keccak256Hash(node) h := crypto.Keccak256Hash(node)
preimages[preimage.Keccak256Key(h).PreimageKey()] = node preimages[preimage.Keccak256Key(h).PreimageKey()] = node
hints.On("hint", CodeHint(h).Hint()).Once().Return() hints.On("hint", CodeHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotNode := po.CodeByHash(h) gotNode := po.CodeByHash(h, chainID)
hints.AssertExpectations(t) hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "code matches") require.Equal(t, hexutil.Bytes(node), hexutil.Bytes(gotNode), "code matches")
}) })
...@@ -127,14 +168,26 @@ func TestPreimageOracleOutputByRoot(t *testing.T) { ...@@ -127,14 +168,26 @@ func TestPreimageOracleOutputByRoot(t *testing.T) {
rng := rand.New(rand.NewSource(123)) rng := rand.New(rand.NewSource(123))
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
chainID := rng.Uint64()
t.Run(fmt.Sprintf("legacy_output_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t, false)
output := testutils.RandomOutputV0(rng)
h := common.Hash(eth.OutputRoot(output))
preimages[preimage.Keccak256Key(h).PreimageKey()] = output.Marshal()
hints.On("hint", LegacyL2OutputHint(h).Hint()).Once().Return()
gotOutput := po.OutputByRoot(h, chainID)
hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(output.Marshal()), hexutil.Bytes(gotOutput.Marshal()), "output matches")
})
t.Run(fmt.Sprintf("output_%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("output_%d", i), func(t *testing.T) {
po, hints, preimages := mockPreimageOracle(t) po, hints, preimages := mockPreimageOracle(t, true)
output := testutils.RandomOutputV0(rng) output := testutils.RandomOutputV0(rng)
h := common.Hash(eth.OutputRoot(output)) h := common.Hash(eth.OutputRoot(output))
preimages[preimage.Keccak256Key(h).PreimageKey()] = output.Marshal() preimages[preimage.Keccak256Key(h).PreimageKey()] = output.Marshal()
hints.On("hint", L2OutputHint(h).Hint()).Once().Return() hints.On("hint", L2OutputHint{Hash: h, ChainID: chainID}.Hint()).Once().Return()
gotOutput := po.OutputByRoot(h) gotOutput := po.OutputByRoot(h, chainID)
hints.AssertExpectations(t) hints.AssertExpectations(t)
require.Equal(t, hexutil.Bytes(output.Marshal()), hexutil.Bytes(gotOutput.Marshal()), "output matches") require.Equal(t, hexutil.Bytes(output.Marshal()), hexutil.Bytes(gotOutput.Marshal()), "output matches")
}) })
......
...@@ -15,8 +15,8 @@ import ( ...@@ -15,8 +15,8 @@ import (
// Same as l2.StateOracle but need to use our own copy to avoid dependency loops // Same as l2.StateOracle but need to use our own copy to avoid dependency loops
type stateOracle interface { type stateOracle interface {
NodeByHash(nodeHash common.Hash) []byte NodeByHash(nodeHash common.Hash, chainID uint64) []byte
CodeByHash(codeHash common.Hash) []byte CodeByHash(codeHash common.Hash, chainID uint64) []byte
} }
type StubBlockOracle struct { type StubBlockOracle struct {
...@@ -56,7 +56,7 @@ func NewStubOracleWithBlocks(t *testing.T, chain []*gethTypes.Block, outputs []e ...@@ -56,7 +56,7 @@ func NewStubOracleWithBlocks(t *testing.T, chain []*gethTypes.Block, outputs []e
} }
} }
func (o StubBlockOracle) BlockByHash(blockHash common.Hash) *gethTypes.Block { func (o StubBlockOracle) BlockByHash(blockHash common.Hash, chainID uint64) *gethTypes.Block {
block, ok := o.Blocks[blockHash] block, ok := o.Blocks[blockHash]
if !ok { if !ok {
o.t.Fatalf("requested unknown block %s", blockHash) o.t.Fatalf("requested unknown block %s", blockHash)
...@@ -64,7 +64,7 @@ func (o StubBlockOracle) BlockByHash(blockHash common.Hash) *gethTypes.Block { ...@@ -64,7 +64,7 @@ func (o StubBlockOracle) BlockByHash(blockHash common.Hash) *gethTypes.Block {
return block return block
} }
func (o StubBlockOracle) OutputByRoot(root common.Hash) eth.Output { func (o StubBlockOracle) OutputByRoot(root common.Hash, chainID uint64) eth.Output {
output, ok := o.Outputs[root] output, ok := o.Outputs[root]
if !ok { if !ok {
o.t.Fatalf("requested unknown output root %s", root) o.t.Fatalf("requested unknown output root %s", root)
...@@ -100,7 +100,7 @@ func NewKvStateOracle(t *testing.T, db ethdb.KeyValueStore) *KvStateOracle { ...@@ -100,7 +100,7 @@ func NewKvStateOracle(t *testing.T, db ethdb.KeyValueStore) *KvStateOracle {
} }
} }
func (o *KvStateOracle) NodeByHash(nodeHash common.Hash) []byte { func (o *KvStateOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
val, err := o.Source.Get(nodeHash.Bytes()) val, err := o.Source.Get(nodeHash.Bytes())
if err != nil { if err != nil {
o.t.Fatalf("error retrieving node %v: %v", nodeHash, err) o.t.Fatalf("error retrieving node %v: %v", nodeHash, err)
...@@ -108,7 +108,7 @@ func (o *KvStateOracle) NodeByHash(nodeHash common.Hash) []byte { ...@@ -108,7 +108,7 @@ func (o *KvStateOracle) NodeByHash(nodeHash common.Hash) []byte {
return val return val
} }
func (o *KvStateOracle) CodeByHash(hash common.Hash) []byte { func (o *KvStateOracle) CodeByHash(hash common.Hash, chainID uint64) []byte {
return rawdb.ReadCode(o.Source, hash) return rawdb.ReadCode(o.Source, hash)
} }
...@@ -127,7 +127,7 @@ type StubStateOracle struct { ...@@ -127,7 +127,7 @@ type StubStateOracle struct {
Code map[common.Hash][]byte Code map[common.Hash][]byte
} }
func (o *StubStateOracle) NodeByHash(nodeHash common.Hash) []byte { func (o *StubStateOracle) NodeByHash(nodeHash common.Hash, chainID uint64) []byte {
data, ok := o.Data[nodeHash] data, ok := o.Data[nodeHash]
if !ok { if !ok {
o.t.Fatalf("no value for node %v", nodeHash) o.t.Fatalf("no value for node %v", nodeHash)
...@@ -135,7 +135,7 @@ func (o *StubStateOracle) NodeByHash(nodeHash common.Hash) []byte { ...@@ -135,7 +135,7 @@ func (o *StubStateOracle) NodeByHash(nodeHash common.Hash) []byte {
return data return data
} }
func (o *StubStateOracle) CodeByHash(hash common.Hash) []byte { func (o *StubStateOracle) CodeByHash(hash common.Hash, chainID uint64) []byte {
data, ok := o.Code[hash] data, ok := o.Code[hash]
if !ok { if !ok {
o.t.Fatalf("no value for code %v", hash) o.t.Fatalf("no value for code %v", hash)
......
...@@ -45,7 +45,7 @@ func RunProgram(logger log.Logger, preimageOracle io.ReadWriter, preimageHinter ...@@ -45,7 +45,7 @@ func RunProgram(logger log.Logger, preimageOracle io.ReadWriter, preimageHinter
pClient := preimage.NewOracleClient(preimageOracle) pClient := preimage.NewOracleClient(preimageOracle)
hClient := preimage.NewHintWriter(preimageHinter) hClient := preimage.NewHintWriter(preimageHinter)
l1PreimageOracle := l1.NewCachingOracle(l1.NewPreimageOracle(pClient, hClient)) l1PreimageOracle := l1.NewCachingOracle(l1.NewPreimageOracle(pClient, hClient))
l2PreimageOracle := l2.NewCachingOracle(l2.NewPreimageOracle(pClient, hClient)) l2PreimageOracle := l2.NewCachingOracle(l2.NewPreimageOracle(pClient, hClient, cfg.InteropEnabled))
if cfg.InteropEnabled { if cfg.InteropEnabled {
bootInfo := boot.BootstrapInterop(pClient) bootInfo := boot.BootstrapInterop(pClient)
......
...@@ -32,3 +32,7 @@ func NewL2Client(client client.RPC, log log.Logger, metrics caching.Metrics, con ...@@ -32,3 +32,7 @@ func NewL2Client(client client.RPC, log log.Logger, metrics caching.Metrics, con
func (s *L2Client) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error) { func (s *L2Client) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error) {
return s.OutputV0AtBlock(ctx, blockRoot) return s.OutputV0AtBlock(ctx, blockRoot)
} }
func (s *L2Client) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
return s.OutputV0AtBlockNumber(ctx, blockNum)
}
...@@ -117,6 +117,14 @@ func (l *L2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth ...@@ -117,6 +117,14 @@ func (l *L2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth
return l.canonicalEthClient.OutputByRoot(ctx, blockRoot) return l.canonicalEthClient.OutputByRoot(ctx, blockRoot)
} }
// OutputByBlockNumber implements prefetcher.L2Source.
func (l *L2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
if l.ExperimentalEnabled() {
return l.experimentalClient.OutputByNumber(ctx, blockNum)
}
return l.canonicalEthClient.OutputByNumber(ctx, blockNum)
}
// ExecutionWitness implements prefetcher.L2Source. // ExecutionWitness implements prefetcher.L2Source.
func (l *L2Source) ExecutionWitness(ctx context.Context, blockNum uint64) (*eth.ExecutionWitness, error) { func (l *L2Source) ExecutionWitness(ctx context.Context, blockNum uint64) (*eth.ExecutionWitness, error) {
if !l.ExperimentalEnabled() { if !l.ExperimentalEnabled() {
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"strings" "strings"
preimage "github.com/ethereum-optimism/optimism/op-preimage" preimage "github.com/ethereum-optimism/optimism/op-preimage"
clientTypes "github.com/ethereum-optimism/optimism/op-program/client/interop/types"
"github.com/ethereum-optimism/optimism/op-program/client/l1" "github.com/ethereum-optimism/optimism/op-program/client/l1"
"github.com/ethereum-optimism/optimism/op-program/client/l2" "github.com/ethereum-optimism/optimism/op-program/client/l2"
"github.com/ethereum-optimism/optimism/op-program/client/mpt" "github.com/ethereum-optimism/optimism/op-program/client/mpt"
...@@ -258,11 +259,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -258,11 +259,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
} }
return p.kvStore.Put(preimage.PrecompileKey(inputHash).PreimageKey(), result) return p.kvStore.Put(preimage.PrecompileKey(inputHash).PreimageKey(), result)
case l2.HintL2BlockHeader, l2.HintL2Transactions: case l2.HintL2BlockHeader, l2.HintL2Transactions:
if len(hintBytes) != 32 { hash, chainID, err := p.parseHashAndChainID("L2 header/tx", hintBytes)
return fmt.Errorf("invalid L2 header/tx hint: %x", hint) if err != nil {
return err
} }
hash := common.Hash(hintBytes) source, err := p.l2Sources.ForChainID(chainID)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
if err != nil { if err != nil {
return err return err
} }
...@@ -280,11 +281,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -280,11 +281,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
} }
return p.storeTransactions(txs) return p.storeTransactions(txs)
case l2.HintL2StateNode: case l2.HintL2StateNode:
if len(hintBytes) != 32 { hash, chainID, err := p.parseHashAndChainID("L2 state node", hintBytes)
return fmt.Errorf("invalid L2 state node hint: %x", hint) if err != nil {
return err
} }
hash := common.Hash(hintBytes) source, err := p.l2Sources.ForChainID(chainID)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
if err != nil { if err != nil {
return err return err
} }
...@@ -294,11 +295,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -294,11 +295,11 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
} }
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), node) return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), node)
case l2.HintL2Code: case l2.HintL2Code:
if len(hintBytes) != 32 { hash, chainID, err := p.parseHashAndChainID("L2 code", hintBytes)
return fmt.Errorf("invalid L2 code hint: %x", hint) if err != nil {
return err
} }
hash := common.Hash(hintBytes) source, err := p.l2Sources.ForChainID(chainID)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
if err != nil { if err != nil {
return err return err
} }
...@@ -308,23 +309,47 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -308,23 +309,47 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
} }
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), code) return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), code)
case l2.HintL2Output: case l2.HintL2Output:
if len(hintBytes) != 32 { requestedHash, chainID, err := p.parseHashAndChainID("L2 output", hintBytes)
return fmt.Errorf("invalid L2 output hint: %x", hint)
}
requestedHash := common.Hash(hintBytes)
source, err := p.l2Sources.ForChainID(p.defaultChainID)
if err != nil { if err != nil {
return err return err
} }
output, err := source.OutputByRoot(ctx, p.l2Head) source, err := p.l2Sources.ForChainID(chainID)
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch L2 output root for block %s: %w", p.l2Head, err) return err
} }
hash := common.Hash(eth.OutputRoot(output)) if len(p.agreedPrestate) == 0 {
if requestedHash != hash { output, err := source.OutputByRoot(ctx, p.l2Head)
return fmt.Errorf("output root %v from block %v does not match requested root: %v", hash, p.l2Head, requestedHash) if err != nil {
return fmt.Errorf("failed to fetch L2 output root for block %s: %w", p.l2Head, err)
}
hash := common.Hash(eth.OutputRoot(output))
if requestedHash != hash {
return fmt.Errorf("output root %v from block %v does not match requested root: %v", hash, p.l2Head, requestedHash)
}
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), output.Marshal())
} else {
prestate, err := clientTypes.UnmarshalTransitionState(p.agreedPrestate)
if err != nil {
return fmt.Errorf("cannot fetch output root, invalid agreed prestate: %w", err)
}
superRoot, err := eth.UnmarshalSuperRoot(prestate.SuperRoot)
if err != nil {
return fmt.Errorf("cannot fetch output root, invalid super root in prestate: %w", err)
}
superV1, ok := superRoot.(*eth.SuperV1)
if !ok {
return fmt.Errorf("cannot fetch output root, unsupported super root version in prestate: %v", superRoot.Version())
}
blockNum, err := source.RollupConfig().TargetBlockNumber(superV1.Timestamp)
if err != nil {
return fmt.Errorf("cannot fetch output root, failed to calculate block number at timestamp %v: %w", superV1.Timestamp, err)
}
output, err := source.OutputByNumber(ctx, blockNum)
if err != nil {
return fmt.Errorf("failed to fetch L2 output root for block %v: %w", blockNum, err)
}
return p.kvStore.Put(preimage.Keccak256Key(eth.OutputRoot(output)).PreimageKey(), output.Marshal())
} }
return p.kvStore.Put(preimage.Keccak256Key(hash).PreimageKey(), output.Marshal())
case l2.HintL2BlockData: case l2.HintL2BlockData:
if p.executor == nil { if p.executor == nil {
return fmt.Errorf("this prefetcher does not support native block execution") return fmt.Errorf("this prefetcher does not support native block execution")
...@@ -353,6 +378,17 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error { ...@@ -353,6 +378,17 @@ func (p *Prefetcher) prefetch(ctx context.Context, hint string) error {
return fmt.Errorf("unknown hint type: %v", hintType) return fmt.Errorf("unknown hint type: %v", hintType)
} }
func (p *Prefetcher) parseHashAndChainID(hintType string, hintBytes []byte) (common.Hash, uint64, error) {
switch len(hintBytes) {
case 32:
return common.Hash(hintBytes), p.defaultChainID, nil
case 40:
return common.Hash(hintBytes[0:32]), binary.BigEndian.Uint64(hintBytes[32:]), nil
default:
return common.Hash{}, 0, fmt.Errorf("invalid %s hint: %x", hintType, hintBytes)
}
}
type BlockDataKey [32]byte type BlockDataKey [32]byte
func (p BlockDataKey) Key() [32]byte { func (p BlockDataKey) Key() [32]byte {
......
...@@ -152,6 +152,17 @@ func (s *RetryingL2Source) OutputByRoot(ctx context.Context, blockRoot common.Ha ...@@ -152,6 +152,17 @@ func (s *RetryingL2Source) OutputByRoot(ctx context.Context, blockRoot common.Ha
}) })
} }
func (s *RetryingL2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
return retry.Do(ctx, maxAttempts, s.strategy, func() (eth.Output, error) {
o, err := s.source.OutputByNumber(ctx, blockNum)
if err != nil {
s.logger.Warn("Failed to fetch l2 output", "block", blockNum, "err", err)
return o, err
}
return o, nil
})
}
func NewRetryingL2Source(logger log.Logger, source hosttypes.L2Source) *RetryingL2Source { func NewRetryingL2Source(logger log.Logger, source hosttypes.L2Source) *RetryingL2Source {
return &RetryingL2Source{ return &RetryingL2Source{
logger: logger, logger: logger,
......
...@@ -226,7 +226,8 @@ func createL1BlobSource(t *testing.T) (*RetryingL1BlobSource, *testutils.MockBlo ...@@ -226,7 +226,8 @@ func createL1BlobSource(t *testing.T) (*RetryingL1BlobSource, *testutils.MockBlo
func TestRetryingL2Source(t *testing.T) { func TestRetryingL2Source(t *testing.T) {
ctx := context.Background() ctx := context.Background()
hash := common.Hash{0xab} hash := common.Hash{0xab}
info := &testutils.MockBlockInfo{InfoHash: hash} blockNum := uint64(14982)
info := &testutils.MockBlockInfo{InfoHash: hash, InfoNum: blockNum}
// The mock really doesn't like returning nil for a eth.BlockInfo so return a value we expect to be ignored instead // The mock really doesn't like returning nil for a eth.BlockInfo so return a value we expect to be ignored instead
wrongInfo := &testutils.MockBlockInfo{InfoHash: common.Hash{0x99}} wrongInfo := &testutils.MockBlockInfo{InfoHash: common.Hash{0x99}}
txs := types.Transactions{ txs := types.Transactions{
...@@ -325,6 +326,28 @@ func TestRetryingL2Source(t *testing.T) { ...@@ -325,6 +326,28 @@ func TestRetryingL2Source(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, output, actualOutput) require.Equal(t, output, actualOutput)
}) })
t.Run("OutputByNumber Success", func(t *testing.T) {
source, mock := createL2Source(t)
defer mock.AssertExpectations(t)
mock.ExpectOutputByNumber(blockNum, output, nil)
actualOutput, err := source.OutputByNumber(ctx, blockNum)
require.NoError(t, err)
require.Equal(t, output, actualOutput)
})
t.Run("OutputByNumber Error", func(t *testing.T) {
source, mock := createL2Source(t)
defer mock.AssertExpectations(t)
expectedErr := errors.New("boom")
mock.ExpectOutputByNumber(blockNum, wrongOutput, expectedErr)
mock.ExpectOutputByNumber(blockNum, output, nil)
actualOutput, err := source.OutputByNumber(ctx, blockNum)
require.NoError(t, err)
require.Equal(t, output, actualOutput)
})
} }
func createL2Source(t *testing.T) (*RetryingL2Source, *MockL2Source) { func createL2Source(t *testing.T) (*RetryingL2Source, *MockL2Source) {
...@@ -370,6 +393,11 @@ func (m *MockL2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash) ...@@ -370,6 +393,11 @@ func (m *MockL2Source) OutputByRoot(ctx context.Context, blockRoot common.Hash)
return out[0].(eth.Output), *out[1].(*error) return out[0].(eth.Output), *out[1].(*error)
} }
func (m *MockL2Source) OutputByNumber(ctx context.Context, blockNum uint64) (eth.Output, error) {
out := m.Mock.MethodCalled("OutputByNumber", blockNum)
return out[0].(eth.Output), *out[1].(*error)
}
func (m *MockL2Source) ExpectInfoAndTxsByHash(blockHash common.Hash, info eth.BlockInfo, txs types.Transactions, err error) { func (m *MockL2Source) ExpectInfoAndTxsByHash(blockHash common.Hash, info eth.BlockInfo, txs types.Transactions, err error) {
m.Mock.On("InfoAndTxsByHash", blockHash).Once().Return(info, txs, &err) m.Mock.On("InfoAndTxsByHash", blockHash).Once().Return(info, txs, &err)
} }
...@@ -386,4 +414,8 @@ func (m *MockL2Source) ExpectOutputByRoot(blockHash common.Hash, output eth.Outp ...@@ -386,4 +414,8 @@ func (m *MockL2Source) ExpectOutputByRoot(blockHash common.Hash, output eth.Outp
m.Mock.On("OutputByRoot", blockHash).Once().Return(output, &err) m.Mock.On("OutputByRoot", blockHash).Once().Return(output, &err)
} }
func (m *MockL2Source) ExpectOutputByNumber(blockNum uint64, output eth.Output, err error) {
m.Mock.On("OutputByNumber", blockNum).Once().Return(output, &err)
}
var _ hosttypes.L2Source = (*MockL2Source)(nil) var _ hosttypes.L2Source = (*MockL2Source)(nil)
...@@ -24,6 +24,7 @@ type L2Source interface { ...@@ -24,6 +24,7 @@ type L2Source interface {
NodeByHash(ctx context.Context, hash common.Hash) ([]byte, error) NodeByHash(ctx context.Context, hash common.Hash) ([]byte, error)
CodeByHash(ctx context.Context, hash common.Hash) ([]byte, error) CodeByHash(ctx context.Context, hash common.Hash) ([]byte, error)
OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error) OutputByRoot(ctx context.Context, blockRoot common.Hash) (eth.Output, error)
OutputByNumber(ctx context.Context, blockNumber uint64) (eth.Output, error)
RollupConfig() *rollup.Config RollupConfig() *rollup.Config
ExperimentalEnabled() bool ExperimentalEnabled() bool
} }
......
...@@ -172,15 +172,26 @@ func (s *L2Client) SystemConfigByL2Hash(ctx context.Context, hash common.Hash) ( ...@@ -172,15 +172,26 @@ func (s *L2Client) SystemConfigByL2Hash(ctx context.Context, hash common.Hash) (
return cfg, nil return cfg, nil
} }
func (s *L2Client) OutputV0AtBlockNumber(ctx context.Context, blockNum uint64) (*eth.OutputV0, error) {
head, err := s.InfoByNumber(ctx, blockNum)
if err != nil {
return nil, fmt.Errorf("failed to get L2 block by hash: %w", err)
}
return s.outputV0(ctx, head)
}
func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) (*eth.OutputV0, error) { func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) (*eth.OutputV0, error) {
head, err := s.InfoByHash(ctx, blockHash) head, err := s.InfoByHash(ctx, blockHash)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get L2 block by hash: %w", err) return nil, fmt.Errorf("failed to get L2 block by hash: %w", err)
} }
if head == nil { return s.outputV0(ctx, head)
}
func (s *L2Client) outputV0(ctx context.Context, block eth.BlockInfo) (*eth.OutputV0, error) {
if block == nil {
return nil, ethereum.NotFound return nil, ethereum.NotFound
} }
blockHash := block.Hash()
proof, err := s.GetProof(ctx, predeploys.L2ToL1MessagePasserAddr, []common.Hash{}, blockHash.String()) proof, err := s.GetProof(ctx, predeploys.L2ToL1MessagePasserAddr, []common.Hash{}, blockHash.String())
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get contract proof at block %s: %w", blockHash, err) return nil, fmt.Errorf("failed to get contract proof at block %s: %w", blockHash, err)
...@@ -189,10 +200,10 @@ func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) ( ...@@ -189,10 +200,10 @@ func (s *L2Client) OutputV0AtBlock(ctx context.Context, blockHash common.Hash) (
return nil, fmt.Errorf("proof %w", ethereum.NotFound) return nil, fmt.Errorf("proof %w", ethereum.NotFound)
} }
// make sure that the proof (including storage hash) that we retrieved is correct by verifying it against the state-root // make sure that the proof (including storage hash) that we retrieved is correct by verifying it against the state-root
if err := proof.Verify(head.Root()); err != nil { if err := proof.Verify(block.Root()); err != nil {
return nil, fmt.Errorf("invalid withdrawal root hash, state root was %s: %w", head.Root(), err) return nil, fmt.Errorf("invalid withdrawal root hash, state root was %s: %w", block.Root(), err)
} }
stateRoot := head.Root() stateRoot := block.Root()
return &eth.OutputV0{ return &eth.OutputV0{
StateRoot: eth.Bytes32(stateRoot), StateRoot: eth.Bytes32(stateRoot),
MessagePasserStorageRoot: eth.Bytes32(proof.StorageHash), MessagePasserStorageRoot: eth.Bytes32(proof.StorageHash),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment