Commit 7ddf9417 authored by Adrian Sutton's avatar Adrian Sutton

op-program: Cache blocks by number

Avoids repeatedly iterating from the head block back to the requested block number.
parent 3d990967
...@@ -18,16 +18,20 @@ var ( ...@@ -18,16 +18,20 @@ var (
) )
type OracleL1Client struct { type OracleL1Client struct {
oracle Oracle oracle Oracle
head eth.L1BlockRef head eth.L1BlockRef
hashByNum map[uint64]common.Hash
earliestIndexedBlock eth.L1BlockRef
} }
func NewOracleL1Client(logger log.Logger, oracle Oracle, l1Head common.Hash) *OracleL1Client { func NewOracleL1Client(logger log.Logger, oracle Oracle, l1Head common.Hash) *OracleL1Client {
head := eth.InfoToL1BlockRef(oracle.HeaderByBlockHash(l1Head)) head := eth.InfoToL1BlockRef(oracle.HeaderByBlockHash(l1Head))
logger.Info("L1 head loaded", "hash", head.Hash, "number", head.Number) logger.Info("L1 head loaded", "hash", head.Hash, "number", head.Number)
return &OracleL1Client{ return &OracleL1Client{
oracle: oracle, oracle: oracle,
head: head, head: head,
hashByNum: map[uint64]common.Hash{head.Number: head.Hash},
earliestIndexedBlock: head,
} }
} }
...@@ -43,9 +47,15 @@ func (o *OracleL1Client) L1BlockRefByNumber(ctx context.Context, number uint64) ...@@ -43,9 +47,15 @@ func (o *OracleL1Client) L1BlockRefByNumber(ctx context.Context, number uint64)
if number > o.head.Number { if number > o.head.Number {
return eth.L1BlockRef{}, fmt.Errorf("%w: block number %d", ErrNotFound, number) return eth.L1BlockRef{}, fmt.Errorf("%w: block number %d", ErrNotFound, number)
} }
block := o.head hash, ok := o.hashByNum[number]
if ok {
return o.L1BlockRefByHash(ctx, hash)
}
block := o.earliestIndexedBlock
for block.Number > number { for block.Number > number {
block = eth.InfoToL1BlockRef(o.oracle.HeaderByBlockHash(block.ParentHash)) block = eth.InfoToL1BlockRef(o.oracle.HeaderByBlockHash(block.ParentHash))
o.hashByNum[block.Number] = block.Hash
o.earliestIndexedBlock = block
} }
return block, nil return block, nil
} }
......
...@@ -126,8 +126,7 @@ func TestL1BlockRefByNumber(t *testing.T) { ...@@ -126,8 +126,7 @@ func TestL1BlockRefByNumber(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, eth.InfoToL1BlockRef(parent), ref) require.Equal(t, eth.InfoToL1BlockRef(parent), ref)
}) })
t.Run("AncestorOfHead", func(t *testing.T) { createBlocks := func(oracle *test.StubOracle) []eth.BlockInfo {
client, oracle := newClient(t)
block := head block := head
blocks := []eth.BlockInfo{block} blocks := []eth.BlockInfo{block}
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
...@@ -135,6 +134,11 @@ func TestL1BlockRefByNumber(t *testing.T) { ...@@ -135,6 +134,11 @@ func TestL1BlockRefByNumber(t *testing.T) {
oracle.Blocks[block.Hash()] = block oracle.Blocks[block.Hash()] = block
blocks = append(blocks, block) blocks = append(blocks, block)
} }
return blocks
}
t.Run("AncestorsAccessForwards", func(t *testing.T) {
client, oracle := newClient(t)
blocks := createBlocks(oracle)
for _, block := range blocks { for _, block := range blocks {
ref, err := client.L1BlockRefByNumber(context.Background(), block.NumberU64()) ref, err := client.L1BlockRefByNumber(context.Background(), block.NumberU64())
...@@ -142,6 +146,17 @@ func TestL1BlockRefByNumber(t *testing.T) { ...@@ -142,6 +146,17 @@ func TestL1BlockRefByNumber(t *testing.T) {
require.Equal(t, eth.InfoToL1BlockRef(block), ref) require.Equal(t, eth.InfoToL1BlockRef(block), ref)
} }
}) })
t.Run("AncestorsAccessReverse", func(t *testing.T) {
client, oracle := newClient(t)
blocks := createBlocks(oracle)
for i := len(blocks) - 1; i >= 0; i-- {
block := blocks[i]
ref, err := client.L1BlockRefByNumber(context.Background(), block.NumberU64())
require.NoError(t, err)
require.Equal(t, eth.InfoToL1BlockRef(block), ref)
}
})
} }
func newClient(t *testing.T) (*OracleL1Client, *test.StubOracle) { func newClient(t *testing.T) (*OracleL1Client, *test.StubOracle) {
......
...@@ -28,6 +28,10 @@ type OracleBackedL2Chain struct { ...@@ -28,6 +28,10 @@ type OracleBackedL2Chain struct {
finalized *types.Header finalized *types.Header
vmCfg vm.Config vmCfg vm.Config
// Block by number cache
hashByNum map[uint64]common.Hash
earliestIndexedBlock *types.Header
// Inserted blocks // Inserted blocks
blocks map[common.Hash]*types.Block blocks map[common.Hash]*types.Block
db ethdb.KeyValueStore db ethdb.KeyValueStore
...@@ -44,6 +48,11 @@ func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, chainCfg *params.C ...@@ -44,6 +48,11 @@ func NewOracleBackedL2Chain(logger log.Logger, oracle Oracle, chainCfg *params.C
chainCfg: chainCfg, chainCfg: chainCfg,
engine: beacon.New(nil), engine: beacon.New(nil),
hashByNum: map[uint64]common.Hash{
head.NumberU64(): head.Hash(),
},
earliestIndexedBlock: head.Header(),
// Treat the agreed starting head as finalized - nothing before it can be disputed // Treat the agreed starting head as finalized - nothing before it can be disputed
head: head.Header(), head: head.Header(),
safe: head.Header(), safe: head.Header(),
...@@ -59,14 +68,20 @@ func (o *OracleBackedL2Chain) CurrentHeader() *types.Header { ...@@ -59,14 +68,20 @@ func (o *OracleBackedL2Chain) CurrentHeader() *types.Header {
} }
func (o *OracleBackedL2Chain) GetHeaderByNumber(n uint64) *types.Header { func (o *OracleBackedL2Chain) GetHeaderByNumber(n uint64) *types.Header {
// Walk back from current head to the requested block number if o.head.Number.Uint64() < n {
h := o.head
if h.Number.Uint64() < n {
return nil return nil
} }
hash, ok := o.hashByNum[n]
if ok {
return o.GetHeaderByHash(hash)
}
// Walk back from current head to the requested block number
h := o.head
for h.Number.Uint64() > n { for h.Number.Uint64() > n {
h = o.GetHeaderByHash(h.ParentHash) h = o.GetHeaderByHash(h.ParentHash)
o.hashByNum[h.Number.Uint64()] = h.Hash()
} }
o.earliestIndexedBlock = h
return h return h
} }
...@@ -176,7 +191,28 @@ func (o *OracleBackedL2Chain) InsertBlockWithoutSetHead(block *types.Block) erro ...@@ -176,7 +191,28 @@ func (o *OracleBackedL2Chain) InsertBlockWithoutSetHead(block *types.Block) erro
} }
func (o *OracleBackedL2Chain) SetCanonical(head *types.Block) (common.Hash, error) { func (o *OracleBackedL2Chain) SetCanonical(head *types.Block) (common.Hash, error) {
oldHead := o.head
o.head = head.Header() o.head = head.Header()
// Remove canonical hashes after the new header
for n := head.NumberU64() + 1; n <= oldHead.Number.Uint64(); n++ {
delete(o.hashByNum, n)
}
// Add new canonical blocks to the block by number cache
// Since the original head is added to the block number cache and acts as the finalized block,
// at some point we must reach the existing canonical chain and can stop updating.
h := o.head
for {
newHash := h.Hash()
prevHash, ok := o.hashByNum[h.Number.Uint64()]
if ok && prevHash == newHash {
// Connected with the existing canonical chain so stop updating
break
}
o.hashByNum[h.Number.Uint64()] = newHash
h = o.GetHeaderByHash(h.ParentHash)
}
return head.Hash(), nil return head.Hash(), nil
} }
......
...@@ -123,6 +123,66 @@ func TestRejectBlockWithStateRootMismatch(t *testing.T) { ...@@ -123,6 +123,66 @@ func TestRejectBlockWithStateRootMismatch(t *testing.T) {
require.ErrorContains(t, err, "block root mismatch") require.ErrorContains(t, err, "block root mismatch")
} }
func TestGetHeaderByNumber(t *testing.T) {
t.Run("Forwards", func(t *testing.T) {
blocks, chain := setupOracleBackedChain(t, 10)
for _, block := range blocks {
result := chain.GetHeaderByNumber(block.NumberU64())
require.Equal(t, block.Header(), result)
}
})
t.Run("Reverse", func(t *testing.T) {
blocks, chain := setupOracleBackedChain(t, 10)
for i := len(blocks) - 1; i >= 0; i-- {
block := blocks[i]
result := chain.GetHeaderByNumber(block.NumberU64())
require.Equal(t, block.Header(), result)
}
})
t.Run("AppendedBlock", func(t *testing.T) {
_, chain := setupOracleBackedChain(t, 10)
// Append a block
newBlock := createBlock(t, chain)
require.NoError(t, chain.InsertBlockWithoutSetHead(newBlock))
_, err := chain.SetCanonical(newBlock)
require.NoError(t, err)
require.Equal(t, newBlock.Header(), chain.GetHeaderByNumber(newBlock.NumberU64()))
})
t.Run("AppendedBlockAfterLookup", func(t *testing.T) {
blocks, chain := setupOracleBackedChain(t, 10)
// Look up an early block to prime the block cache
require.Equal(t, blocks[0].Header(), chain.GetHeaderByNumber(blocks[0].NumberU64()))
// Append a block
newBlock := createBlock(t, chain)
require.NoError(t, chain.InsertBlockWithoutSetHead(newBlock))
_, err := chain.SetCanonical(newBlock)
require.NoError(t, err)
require.Equal(t, newBlock.Header(), chain.GetHeaderByNumber(newBlock.NumberU64()))
})
t.Run("AppendedMultipleBlocks", func(t *testing.T) {
blocks, chain := setupOracleBackedChainWithLowerHead(t, 5, 2)
// Append a few blocks
newBlock1 := blocks[3]
newBlock2 := blocks[4]
newBlock3 := blocks[5]
require.NoError(t, chain.InsertBlockWithoutSetHead(newBlock1))
require.NoError(t, chain.InsertBlockWithoutSetHead(newBlock2))
require.NoError(t, chain.InsertBlockWithoutSetHead(newBlock3))
_, err := chain.SetCanonical(newBlock3)
require.NoError(t, err)
require.Equal(t, newBlock3.Header(), chain.GetHeaderByNumber(newBlock3.NumberU64()), "Lookup block3")
require.Equal(t, newBlock2.Header(), chain.GetHeaderByNumber(newBlock2.NumberU64()), "Lookup block2")
require.Equal(t, newBlock1.Header(), chain.GetHeaderByNumber(newBlock1.NumberU64()), "Lookup block1")
})
}
func assertBlockDataAvailable(t *testing.T, chain *OracleBackedL2Chain, block *types.Block, blockNumber uint64) { func assertBlockDataAvailable(t *testing.T, chain *OracleBackedL2Chain, block *types.Block, blockNumber uint64) {
require.Equal(t, block, chain.GetBlockByHash(block.Hash()), "get block %v by hash", blockNumber) require.Equal(t, block, chain.GetBlockByHash(block.Hash()), "get block %v by hash", blockNumber)
require.Equal(t, block.Header(), chain.GetHeaderByHash(block.Hash()), "get header %v by hash", blockNumber) require.Equal(t, block.Header(), chain.GetHeaderByHash(block.Hash()), "get header %v by hash", blockNumber)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment