Commit fbabaf8a authored by Adrian Sutton's avatar Adrian Sutton

op-program: Modify L2 oracle to not return error

parent 0c0cf54b
...@@ -40,12 +40,12 @@ func (o *OracleKeyValueStore) Get(key []byte) ([]byte, error) { ...@@ -40,12 +40,12 @@ func (o *OracleKeyValueStore) Get(key []byte) ([]byte, error) {
if len(key) == codePrefixedKeyLength && bytes.HasPrefix(key, rawdb.CodePrefix) { if len(key) == codePrefixedKeyLength && bytes.HasPrefix(key, rawdb.CodePrefix) {
key = key[len(rawdb.CodePrefix):] key = key[len(rawdb.CodePrefix):]
return o.oracle.CodeByHash(*(*[common.HashLength]byte)(key)) return o.oracle.CodeByHash(*(*[common.HashLength]byte)(key)), nil
} }
if len(key) != common.HashLength { if len(key) != common.HashLength {
return nil, ErrInvalidKeyLength return nil, ErrInvalidKeyLength
} }
return o.oracle.NodeByHash(*(*[common.HashLength]byte)(key)) return o.oracle.NodeByHash(*(*[common.HashLength]byte)(key)), nil
} }
func (o *OracleKeyValueStore) NewBatch() ethdb.Batch { func (o *OracleKeyValueStore) NewBatch() ethdb.Batch {
......
package l2 package l2
import ( import (
"fmt"
"math/big" "math/big"
"testing" "testing"
...@@ -27,16 +26,8 @@ var ( ...@@ -27,16 +26,8 @@ var (
var _ ethdb.KeyValueStore = (*OracleKeyValueStore)(nil) var _ ethdb.KeyValueStore = (*OracleKeyValueStore)(nil)
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
t.Run("UnknownKey", func(t *testing.T) {
oracle := newStubStateOracle()
db := NewOracleBackedDB(oracle)
val, err := db.Get(common.Hash{}.Bytes())
require.Error(t, err)
require.Nil(t, val)
})
t.Run("IncorrectLengthKey", func(t *testing.T) { t.Run("IncorrectLengthKey", func(t *testing.T) {
oracle := newStubStateOracle() oracle := newStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle)
val, err := db.Get([]byte{1, 2, 3}) val, err := db.Get([]byte{1, 2, 3})
require.ErrorIs(t, err, ErrInvalidKeyLength) require.ErrorIs(t, err, ErrInvalidKeyLength)
...@@ -44,7 +35,7 @@ func TestGet(t *testing.T) { ...@@ -44,7 +35,7 @@ func TestGet(t *testing.T) {
}) })
t.Run("KeyWithCodePrefix", func(t *testing.T) { t.Run("KeyWithCodePrefix", func(t *testing.T) {
oracle := newStubStateOracle() oracle := newStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle)
key := common.HexToHash("0x12345678") key := common.HexToHash("0x12345678")
prefixedKey := append(rawdb.CodePrefix, key.Bytes()...) prefixedKey := append(rawdb.CodePrefix, key.Bytes()...)
...@@ -58,7 +49,7 @@ func TestGet(t *testing.T) { ...@@ -58,7 +49,7 @@ func TestGet(t *testing.T) {
}) })
t.Run("NormalKeyThatHappensToStartWithCodePrefix", func(t *testing.T) { t.Run("NormalKeyThatHappensToStartWithCodePrefix", func(t *testing.T) {
oracle := newStubStateOracle() oracle := newStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle)
key := make([]byte, common.HashLength) key := make([]byte, common.HashLength)
copy(rawdb.CodePrefix, key) copy(rawdb.CodePrefix, key)
...@@ -74,7 +65,7 @@ func TestGet(t *testing.T) { ...@@ -74,7 +65,7 @@ func TestGet(t *testing.T) {
t.Run("KnownKey", func(t *testing.T) { t.Run("KnownKey", func(t *testing.T) {
key := common.HexToHash("0xAA4488") key := common.HexToHash("0xAA4488")
expected := []byte{2, 6, 3, 8} expected := []byte{2, 6, 3, 8}
oracle := newStubStateOracle() oracle := newStubStateOracle(t)
oracle.data[key] = expected oracle.data[key] = expected
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle)
val, err := db.Get(key.Bytes()) val, err := db.Get(key.Bytes())
...@@ -85,7 +76,7 @@ func TestGet(t *testing.T) { ...@@ -85,7 +76,7 @@ func TestGet(t *testing.T) {
func TestPut(t *testing.T) { func TestPut(t *testing.T) {
t.Run("NewKey", func(t *testing.T) { t.Run("NewKey", func(t *testing.T) {
oracle := newStubStateOracle() oracle := newStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle)
key := common.HexToHash("0xAA4488") key := common.HexToHash("0xAA4488")
value := []byte{2, 6, 3, 8} value := []byte{2, 6, 3, 8}
...@@ -97,7 +88,7 @@ func TestPut(t *testing.T) { ...@@ -97,7 +88,7 @@ func TestPut(t *testing.T) {
require.Equal(t, value, actual) require.Equal(t, value, actual)
}) })
t.Run("ReplaceKey", func(t *testing.T) { t.Run("ReplaceKey", func(t *testing.T) {
oracle := newStubStateOracle() oracle := newStubStateOracle(t)
db := NewOracleBackedDB(oracle) db := NewOracleBackedDB(oracle)
key := common.HexToHash("0xAA4488") key := common.HexToHash("0xAA4488")
value1 := []byte{2, 6, 3, 8} value1 := []byte{2, 6, 3, 8}
...@@ -119,6 +110,7 @@ func TestSupportsStateDBOperations(t *testing.T) { ...@@ -119,6 +110,7 @@ func TestSupportsStateDBOperations(t *testing.T) {
genesisBlock := l2Genesis.MustCommit(realDb) genesisBlock := l2Genesis.MustCommit(realDb)
loader := &kvStateOracle{ loader := &kvStateOracle{
t: t,
source: realDb, source: realDb,
} }
assertStateDataAvailable(t, NewOracleBackedDB(loader), l2Genesis, genesisBlock) assertStateDataAvailable(t, NewOracleBackedDB(loader), l2Genesis, genesisBlock)
...@@ -126,7 +118,7 @@ func TestSupportsStateDBOperations(t *testing.T) { ...@@ -126,7 +118,7 @@ func TestSupportsStateDBOperations(t *testing.T) {
func TestUpdateState(t *testing.T) { func TestUpdateState(t *testing.T) {
l2Genesis := createGenesis() l2Genesis := createGenesis()
oracle := newStubStateOracle() oracle := newStubStateOracle(t)
db := rawdb.NewDatabase(NewOracleBackedDB(oracle)) db := rawdb.NewDatabase(NewOracleBackedDB(oracle))
genesisBlock := l2Genesis.MustCommit(db) genesisBlock := l2Genesis.MustCommit(db)
...@@ -203,43 +195,50 @@ func assertStateDataAvailable(t *testing.T, db ethdb.KeyValueStore, l2Genesis *c ...@@ -203,43 +195,50 @@ func assertStateDataAvailable(t *testing.T, db ethdb.KeyValueStore, l2Genesis *c
require.Equal(t, common.Hash{}, statedb.GetCodeHash(unknownAccount), "unset account code hash") require.Equal(t, common.Hash{}, statedb.GetCodeHash(unknownAccount), "unset account code hash")
} }
func newStubStateOracle() *stubStateOracle { func newStubStateOracle(t *testing.T) *stubStateOracle {
return &stubStateOracle{ return &stubStateOracle{
t: t,
data: make(map[common.Hash][]byte), data: make(map[common.Hash][]byte),
code: make(map[common.Hash][]byte), code: make(map[common.Hash][]byte),
} }
} }
type stubStateOracle struct { type stubStateOracle struct {
t *testing.T
data map[common.Hash][]byte data map[common.Hash][]byte
code map[common.Hash][]byte code map[common.Hash][]byte
} }
func (o *stubStateOracle) NodeByHash(nodeHash common.Hash) ([]byte, error) { func (o *stubStateOracle) NodeByHash(nodeHash common.Hash) []byte {
data, ok := o.data[nodeHash] data, ok := o.data[nodeHash]
if !ok { if !ok {
return nil, fmt.Errorf("no value for node %v", nodeHash) o.t.Fatalf("no value for node %v", nodeHash)
} }
return data, nil return data
} }
func (o *stubStateOracle) CodeByHash(hash common.Hash) ([]byte, error) { func (o *stubStateOracle) CodeByHash(hash common.Hash) []byte {
data, ok := o.code[hash] data, ok := o.code[hash]
if !ok { if !ok {
return nil, fmt.Errorf("no value for code %v", hash) o.t.Fatalf("no value for code %v", hash)
} }
return data, nil return data
} }
// kvStateOracle loads data from a source ethdb.KeyValueStore // kvStateOracle loads data from a source ethdb.KeyValueStore
type kvStateOracle struct { type kvStateOracle struct {
t *testing.T
source ethdb.KeyValueStore source ethdb.KeyValueStore
} }
func (o *kvStateOracle) NodeByHash(nodeHash common.Hash) ([]byte, error) { func (o *kvStateOracle) NodeByHash(nodeHash common.Hash) []byte {
return o.source.Get(nodeHash.Bytes()) val, err := o.source.Get(nodeHash.Bytes())
if err != nil {
o.t.Fatalf("error retrieving node %v: %v", nodeHash, err)
}
return val
} }
func (o *kvStateOracle) CodeByHash(hash common.Hash) ([]byte, error) { func (o *kvStateOracle) CodeByHash(hash common.Hash) []byte {
return rawdb.ReadCode(o.source, hash), nil return rawdb.ReadCode(o.source, hash)
} }
...@@ -35,10 +35,7 @@ type OracleBackedL2Chain struct { ...@@ -35,10 +35,7 @@ type OracleBackedL2Chain struct {
var _ engineapi.EngineBackend = (*OracleBackedL2Chain)(nil) var _ engineapi.EngineBackend = (*OracleBackedL2Chain)(nil)
func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, chainCfg *params.ChainConfig, l2Head common.Hash) (*OracleBackedL2Chain, error) { func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, chainCfg *params.ChainConfig, l2Head common.Hash) (*OracleBackedL2Chain, error) {
head, err := oracle.BlockByHash(l2Head) head := oracle.BlockByHash(l2Head)
if err != nil {
return nil, fmt.Errorf("loading l2 head: %w", err)
}
logger.Info("Loaded L2 head", "hash", head.Hash(), "number", head.Number()) logger.Info("Loaded L2 head", "hash", head.Hash(), "number", head.Number())
return &OracleBackedL2Chain{ return &OracleBackedL2Chain{
log: logger, log: logger,
...@@ -99,10 +96,7 @@ func (o *OracleBackedL2Chain) GetBlockByHash(hash common.Hash) *types.Block { ...@@ -99,10 +96,7 @@ func (o *OracleBackedL2Chain) GetBlockByHash(hash common.Hash) *types.Block {
return block return block
} }
// Retrieve from the oracle // Retrieve from the oracle
block, err := o.oracle.BlockByHash(hash) block = o.oracle.BlockByHash(hash)
if err != nil {
handleError(err)
}
if block == nil { if block == nil {
return nil return nil
} }
...@@ -195,7 +189,3 @@ func (o *OracleBackedL2Chain) SetFinalized(header *types.Header) { ...@@ -195,7 +189,3 @@ func (o *OracleBackedL2Chain) SetFinalized(header *types.Header) {
func (o *OracleBackedL2Chain) SetSafe(header *types.Header) { func (o *OracleBackedL2Chain) SetSafe(header *types.Header) {
o.safe = header o.safe = header
} }
func handleError(err error) {
panic(err)
}
...@@ -113,8 +113,9 @@ func TestUpdateStateDatabaseWhenImportingBlock(t *testing.T) { ...@@ -113,8 +113,9 @@ func TestUpdateStateDatabaseWhenImportingBlock(t *testing.T) {
require.NotEqual(t, blocks[1].Root(), newBlock.Root(), "block should have modified world state") require.NotEqual(t, blocks[1].Root(), newBlock.Root(), "block should have modified world state")
_, err = chain.StateAt(newBlock.Root()) require.Panics(t, func() {
require.Error(t, err, "state from non-imported block should not be available") _, _ = chain.StateAt(newBlock.Root())
}, "state from non-imported block should not be available")
err = chain.InsertBlockWithoutSetHead(newBlock) err = chain.InsertBlockWithoutSetHead(newBlock)
require.NoError(t, err) require.NoError(t, err)
...@@ -223,8 +224,8 @@ func newStubBlockOracle(chain []*types.Block, db ethdb.Database) *stubBlockOracl ...@@ -223,8 +224,8 @@ func newStubBlockOracle(chain []*types.Block, db ethdb.Database) *stubBlockOracl
} }
} }
func (o stubBlockOracle) BlockByHash(blockHash common.Hash) (*types.Block, error) { func (o stubBlockOracle) BlockByHash(blockHash common.Hash) *types.Block {
return o.blocks[blockHash], nil return o.blocks[blockHash]
} }
func TestEngineAPITests(t *testing.T) { func TestEngineAPITests(t *testing.T) {
......
...@@ -11,13 +11,11 @@ type StateOracle interface { ...@@ -11,13 +11,11 @@ type StateOracle interface {
// NodeByHash retrieves the merkle-patricia trie node pre-image for a given hash. // NodeByHash retrieves the merkle-patricia trie node pre-image for a given hash.
// Trie nodes may be from the world state trie or any account storage trie. // Trie nodes may be from the world state trie or any account storage trie.
// Contract code is not stored as part of the trie and must be retrieved via CodeByHash // Contract code is not stored as part of the trie and must be retrieved via CodeByHash
// Returns an error if the pre-image is unavailable. NodeByHash(nodeHash common.Hash) []byte
NodeByHash(nodeHash common.Hash) ([]byte, error)
// CodeByHash retrieves the contract code pre-image for a given hash. // CodeByHash retrieves the contract code pre-image for a given hash.
// codeHash should be retrieved from the world state account for a contract. // codeHash should be retrieved from the world state account for a contract.
// Returns an error if the pre-image is unavailable. CodeByHash(codeHash common.Hash) []byte
CodeByHash(codeHash common.Hash) ([]byte, error)
} }
// Oracle defines the high-level API used to retrieve L2 data. // Oracle defines the high-level API used to retrieve L2 data.
...@@ -26,6 +24,5 @@ type Oracle interface { ...@@ -26,6 +24,5 @@ type Oracle interface {
StateOracle StateOracle
// BlockByHash retrieves the block with the given hash. // BlockByHash retrieves the block with the given hash.
// Returns an error if the block is not available. BlockByHash(blockHash common.Hash) *types.Block
BlockByHash(blockHash common.Hash) (*types.Block, error)
} }
...@@ -42,19 +42,26 @@ func NewFetchingL2Oracle(ctx context.Context, logger log.Logger, l2Url string) ( ...@@ -42,19 +42,26 @@ func NewFetchingL2Oracle(ctx context.Context, logger log.Logger, l2Url string) (
}, nil }, nil
} }
func (o *FetchingL2Oracle) NodeByHash(hash common.Hash) ([]byte, error) { func (o *FetchingL2Oracle) NodeByHash(hash common.Hash) []byte {
// MPT nodes are stored as the hash of the node (with no prefix) // MPT nodes are stored as the hash of the node (with no prefix)
return o.dbGet(hash.Bytes()) node, err := o.dbGet(hash.Bytes())
if err != nil {
panic(err)
}
return node
} }
func (o *FetchingL2Oracle) CodeByHash(hash common.Hash) ([]byte, error) { func (o *FetchingL2Oracle) CodeByHash(hash common.Hash) []byte {
// First try retrieving with the new code prefix // First try retrieving with the new code prefix
code, err := o.dbGet(append(rawdb.CodePrefix, hash.Bytes()...)) code, err := o.dbGet(append(rawdb.CodePrefix, hash.Bytes()...))
if err != nil { if err != nil {
// Fallback to the legacy un-prefixed version // Fallback to the legacy un-prefixed version
return o.dbGet(hash.Bytes()) code, err = o.dbGet(hash.Bytes())
if err != nil {
panic(err)
}
} }
return code, nil return code
} }
func (o *FetchingL2Oracle) dbGet(key []byte) ([]byte, error) { func (o *FetchingL2Oracle) dbGet(key []byte) ([]byte, error) {
...@@ -66,10 +73,10 @@ func (o *FetchingL2Oracle) dbGet(key []byte) ([]byte, error) { ...@@ -66,10 +73,10 @@ func (o *FetchingL2Oracle) dbGet(key []byte) ([]byte, error) {
return node, nil return node, nil
} }
func (o *FetchingL2Oracle) BlockByHash(blockHash common.Hash) (*types.Block, error) { func (o *FetchingL2Oracle) BlockByHash(blockHash common.Hash) *types.Block {
block, err := o.blockSource.BlockByHash(o.ctx, blockHash) block, err := o.blockSource.BlockByHash(o.ctx, blockHash)
if err != nil { if err != nil {
return nil, fmt.Errorf("fetch block %s: %w", blockHash.Hex(), err) panic(fmt.Errorf("fetch block %s: %w", blockHash.Hex(), err))
} }
return block, nil return block
} }
...@@ -63,9 +63,9 @@ func TestNodeByHash(t *testing.T) { ...@@ -63,9 +63,9 @@ func TestNodeByHash(t *testing.T) {
} }
fetcher := newFetcher(nil, stub) fetcher := newFetcher(nil, stub)
node, err := fetcher.NodeByHash(hash) require.Panics(t, func() {
require.ErrorIs(t, err, stub.nextErr) fetcher.NodeByHash(hash)
require.Nil(t, node) })
}) })
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
...@@ -75,8 +75,7 @@ func TestNodeByHash(t *testing.T) { ...@@ -75,8 +75,7 @@ func TestNodeByHash(t *testing.T) {
} }
fetcher := newFetcher(nil, stub) fetcher := newFetcher(nil, stub)
node, err := fetcher.NodeByHash(hash) node := fetcher.NodeByHash(hash)
require.NoError(t, err)
require.EqualValues(t, expected, node) require.EqualValues(t, expected, node)
}) })
...@@ -86,7 +85,7 @@ func TestNodeByHash(t *testing.T) { ...@@ -86,7 +85,7 @@ func TestNodeByHash(t *testing.T) {
} }
fetcher := newFetcher(nil, stub) fetcher := newFetcher(nil, stub)
_, _ = fetcher.NodeByHash(hash) fetcher.NodeByHash(hash)
require.Len(t, stub.requests, 1, "should make single request") require.Len(t, stub.requests, 1, "should make single request")
req := stub.requests[0] req := stub.requests[0]
require.Equal(t, "debug_dbGet", req.method) require.Equal(t, "debug_dbGet", req.method)
...@@ -104,9 +103,7 @@ func TestCodeByHash(t *testing.T) { ...@@ -104,9 +103,7 @@ func TestCodeByHash(t *testing.T) {
} }
fetcher := newFetcher(nil, stub) fetcher := newFetcher(nil, stub)
node, err := fetcher.CodeByHash(hash) require.Panics(t, func() { fetcher.CodeByHash(hash) })
require.ErrorIs(t, err, stub.nextErr)
require.Nil(t, node)
}) })
t.Run("Success", func(t *testing.T) { t.Run("Success", func(t *testing.T) {
...@@ -116,8 +113,7 @@ func TestCodeByHash(t *testing.T) { ...@@ -116,8 +113,7 @@ func TestCodeByHash(t *testing.T) {
} }
fetcher := newFetcher(nil, stub) fetcher := newFetcher(nil, stub)
node, err := fetcher.CodeByHash(hash) node := fetcher.CodeByHash(hash)
require.NoError(t, err)
require.EqualValues(t, expected, node) require.EqualValues(t, expected, node)
}) })
...@@ -127,7 +123,7 @@ func TestCodeByHash(t *testing.T) { ...@@ -127,7 +123,7 @@ func TestCodeByHash(t *testing.T) {
} }
fetcher := newFetcher(nil, stub) fetcher := newFetcher(nil, stub)
_, _ = fetcher.CodeByHash(hash) fetcher.CodeByHash(hash)
require.Len(t, stub.requests, 1, "should make single request") require.Len(t, stub.requests, 1, "should make single request")
req := stub.requests[0] req := stub.requests[0]
require.Equal(t, "debug_dbGet", req.method) require.Equal(t, "debug_dbGet", req.method)
...@@ -141,7 +137,8 @@ func TestCodeByHash(t *testing.T) { ...@@ -141,7 +137,8 @@ func TestCodeByHash(t *testing.T) {
} }
fetcher := newFetcher(nil, stub) fetcher := newFetcher(nil, stub)
_, _ = fetcher.CodeByHash(hash) // Panics because the code can't be found with or without the prefix
require.Panics(t, func() { fetcher.CodeByHash(hash) })
require.Len(t, stub.requests, 2, "should request with and without prefix") require.Len(t, stub.requests, 2, "should request with and without prefix")
req := stub.requests[0] req := stub.requests[0]
require.Equal(t, "debug_dbGet", req.method) require.Equal(t, "debug_dbGet", req.method)
...@@ -183,8 +180,7 @@ func TestBlockByHash(t *testing.T) { ...@@ -183,8 +180,7 @@ func TestBlockByHash(t *testing.T) {
stub := &stubBlockSource{nextResult: block} stub := &stubBlockSource{nextResult: block}
fetcher := newFetcher(stub, nil) fetcher := newFetcher(stub, nil)
res, err := fetcher.BlockByHash(hash) res := fetcher.BlockByHash(hash)
require.NoError(t, err)
require.Same(t, block, res) require.Same(t, block, res)
}) })
...@@ -192,16 +188,16 @@ func TestBlockByHash(t *testing.T) { ...@@ -192,16 +188,16 @@ func TestBlockByHash(t *testing.T) {
stub := &stubBlockSource{nextErr: errors.New("boom")} stub := &stubBlockSource{nextErr: errors.New("boom")}
fetcher := newFetcher(stub, nil) fetcher := newFetcher(stub, nil)
res, err := fetcher.BlockByHash(hash) require.Panics(t, func() {
require.ErrorIs(t, err, stub.nextErr) fetcher.BlockByHash(hash)
require.Nil(t, res) })
}) })
t.Run("RequestArgs", func(t *testing.T) { t.Run("RequestArgs", func(t *testing.T) {
stub := &stubBlockSource{} stub := &stubBlockSource{}
fetcher := newFetcher(stub, nil) fetcher := newFetcher(stub, nil)
_, _ = fetcher.BlockByHash(hash) fetcher.BlockByHash(hash)
require.Len(t, stub.requests, 1, "should make single request") require.Len(t, stub.requests, 1, "should make single request")
req := stub.requests[0] req := stub.requests[0]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment