Commit 95801790 authored by Axel Kingsley's avatar Axel Kingsley Committed by GitHub

Refactor Derivation DB Reset Functions (#13874)

parent 57c7d168
......@@ -209,6 +209,7 @@ const (
func (m *ManagedMode) Reset(ctx context.Context, unsafe, safe, finalized eth.BlockID) error {
logger := m.log.New("unsafe", unsafe, "safe", safe, "finalized", finalized)
logger.Debug("Received reset request", "unsafe", unsafe, "safe", safe, "finalized", finalized)
verify := func(ref eth.BlockID, name string) (eth.L2BlockRef, error) {
result, err := m.l2.L2BlockRefByNumber(ctx, ref.Number)
......
......@@ -567,3 +567,8 @@ func (su *SupervisorBackend) PullFinalizedL1() error {
func (su *SupervisorBackend) SetConfDepthL1(depth uint64) {
su.l1Accessor.SetConfDepth(depth)
}
// Rewind rolls back the state of the supervisor for the given chain.
func (su *SupervisorBackend) Rewind(chain eth.ChainID, block eth.BlockID) error {
return su.chainDBs.Rewind(chain, block)
}
......@@ -61,6 +61,7 @@ type LocalDerivedFromStorage interface {
NextDerived(derived eth.BlockID) (next types.DerivedBlockSealPair, err error)
PreviousDerivedFrom(derivedFrom eth.BlockID) (prevDerivedFrom types.BlockSeal, err error)
PreviousDerived(derived eth.BlockID) (prevDerived types.BlockSeal, err error)
RewindToL2(derived uint64) error
}
var _ LocalDerivedFromStorage = (*fromda.DB)(nil)
......
......@@ -51,22 +51,6 @@ func NewFromEntryStore(logger log.Logger, m Metrics, store EntryStore) (*DB, err
return db, nil
}
// Rewind to the last entry that was derived from a L1 block with the given block number.
func (db *DB) Rewind(derivedFrom uint64) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
index, _, err := db.lastDerivedAt(derivedFrom)
if err != nil {
return fmt.Errorf("failed to find point to rewind to: %w", err)
}
err = db.store.Truncate(index)
if err != nil {
return err
}
db.m.RecordDBDerivedEntryCount(int64(index) + 1)
return nil
}
// First returns the first known values, alike to Latest.
func (db *DB) First() (pair types.DerivedBlockSealPair, err error) {
db.rwLock.RLock()
......
......@@ -679,17 +679,17 @@ func TestRewind(t *testing.T) {
require.Equal(t, l2Block2, pair.Derived)
// Rewind to the future
require.ErrorIs(t, db.Rewind(6), types.ErrFuture)
require.ErrorIs(t, db.RewindToL1(6), types.ErrFuture)
// Rewind to the exact block we're at
require.NoError(t, db.Rewind(l1Block5.Number))
require.NoError(t, db.RewindToL1(l1Block5.Number))
pair, err = db.Latest()
require.NoError(t, err)
require.Equal(t, l1Block5, pair.DerivedFrom)
require.Equal(t, l2Block2, pair.Derived)
// Now rewind to L1 block 3 (inclusive).
require.NoError(t, db.Rewind(l1Block3.Number))
require.NoError(t, db.RewindToL1(l1Block3.Number))
// See if we find consistent data
pair, err = db.Latest()
......@@ -698,14 +698,14 @@ func TestRewind(t *testing.T) {
require.Equal(t, l2Block1, pair.Derived)
// Rewind further to L1 block 1 (inclusive).
require.NoError(t, db.Rewind(l1Block1.Number))
require.NoError(t, db.RewindToL1(l1Block1.Number))
pair, err = db.Latest()
require.NoError(t, err)
require.Equal(t, l1Block1, pair.DerivedFrom)
require.Equal(t, l2Block1, pair.Derived)
// Rewind further to L1 block 0 (inclusive).
require.NoError(t, db.Rewind(l1Block0.Number))
require.NoError(t, db.RewindToL1(l1Block0.Number))
pair, err = db.Latest()
require.NoError(t, err)
require.Equal(t, l1Block0, pair.DerivedFrom)
......
......@@ -66,25 +66,83 @@ func (db *DB) ReplaceInvalidatedBlock(replacementDerived eth.BlockRef, invalidat
func (db *DB) RewindAndInvalidate(invalidated types.DerivedBlockRefPair) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
i, link, err := db.lookup(invalidated.DerivedFrom.Number, invalidated.Derived.Number)
invalidatedSeals := types.DerivedBlockSealPair{
DerivedFrom: types.BlockSealFromRef(invalidated.DerivedFrom),
Derived: types.BlockSealFromRef(invalidated.Derived),
}
if err := db.rewindLocked(invalidatedSeals, true); err != nil {
return err
}
if err := db.addLink(invalidated.DerivedFrom, invalidated.Derived, invalidated.Derived.Hash); err != nil {
return fmt.Errorf("failed to add invalidation entry %s: %w", invalidated, err)
}
return nil
}
// Rewind rolls back the database to the target, including the target if the including flag is set.
// it locks the DB and calls rewindLocked.
func (db *DB) Rewind(target types.DerivedBlockSealPair, including bool) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
return db.rewindLocked(target, including)
}
// RewindToL2 rewinds to the first entry where the L2 block with the given number was derived.
func (db *DB) RewindToL2(derived uint64) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
_, link, err := db.firstDerivedFrom(derived)
if err != nil {
return fmt.Errorf("failed to find last derived-from %d: %w", derived, err)
}
return db.rewindLocked(types.DerivedBlockSealPair{
DerivedFrom: link.derivedFrom,
Derived: link.derived,
}, false)
}
// RewindToL1 rewinds to the last entry that was derived from a L1 block with the given block number.
func (db *DB) RewindToL1(derivedFrom uint64) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
_, link, err := db.lastDerivedAt(derivedFrom)
if err != nil {
return fmt.Errorf("failed to find last derived-from %d: %w", derivedFrom, err)
}
return db.rewindLocked(types.DerivedBlockSealPair{
DerivedFrom: link.derivedFrom,
Derived: link.derived,
}, false)
}
// rewindLocked performs the truncate operation to a specified block seal pair.
// data beyond the specified block seal pair is truncated from the database.
// if including is true, the block seal pair itself is removed as well.
// Note: This function must be called with the rwLock held.
// Callers are responsible for locking and unlocking the Database.
func (db *DB) rewindLocked(t types.DerivedBlockSealPair, including bool) error {
i, link, err := db.lookup(t.DerivedFrom.Number, t.Derived.Number)
if err != nil {
return err
}
if link.derivedFrom.Hash != invalidated.DerivedFrom.Hash {
if link.derivedFrom.Hash != t.DerivedFrom.Hash {
return fmt.Errorf("found derived-from %s, but expected %s: %w",
link.derivedFrom, invalidated.DerivedFrom, types.ErrConflict)
link.derivedFrom, t.DerivedFrom, types.ErrConflict)
}
if link.derived.Hash != invalidated.Derived.Hash {
if link.derived.Hash != t.Derived.Hash {
return fmt.Errorf("found derived %s, but expected %s: %w",
link.derived, invalidated.Derived, types.ErrConflict)
link.derived, t.Derived, types.ErrConflict)
}
if err := db.store.Truncate(i - 1); err != nil {
return fmt.Errorf("failed to rewind upon block invalidation of %s: %w", invalidated, err)
// adjust the target index to include the block seal pair itself if requested
target := i
if including {
target = i - 1
}
db.m.RecordDBDerivedEntryCount(int64(i))
if err := db.addLink(invalidated.DerivedFrom, invalidated.Derived, invalidated.Derived.Hash); err != nil {
return fmt.Errorf("failed to add invalidation entry %s: %w", invalidated, err)
if err := db.store.Truncate(target); err != nil {
return fmt.Errorf("failed to rewind upon block invalidation of %s: %w", t, err)
}
db.m.RecordDBDerivedEntryCount(int64(target) + 1)
return nil
}
......
......@@ -41,11 +41,33 @@ func (db *ChainsDB) SealBlock(chain eth.ChainID, block eth.BlockRef) error {
}
func (db *ChainsDB) Rewind(chain eth.ChainID, headBlock eth.BlockID) error {
// Rewind the logDB
logDB, ok := db.logDBs.Get(chain)
if !ok {
return fmt.Errorf("cannot Rewind: %w: %s", types.ErrUnknownChain, chain)
}
return logDB.Rewind(headBlock)
if err := logDB.Rewind(headBlock); err != nil {
return fmt.Errorf("failed to rewind to block %v: %w", headBlock, err)
}
// Rewind the localDB
localDB, ok := db.localDBs.Get(chain)
if !ok {
return fmt.Errorf("cannot Rewind (localDB not found): %w: %s", types.ErrUnknownChain, chain)
}
if err := localDB.RewindToL2(headBlock.Number); err != nil {
return fmt.Errorf("failed to rewind localDB to block %v: %w", headBlock, err)
}
// Rewind the crossDB
crossDB, ok := db.crossDBs.Get(chain)
if !ok {
return fmt.Errorf("cannot Rewind (crossDB not found): %w: %s", types.ErrUnknownChain, chain)
}
if err := crossDB.RewindToL2(headBlock.Number); err != nil {
return fmt.Errorf("failed to rewind crossDB to block %v: %w", headBlock, err)
}
return nil
}
func (db *ChainsDB) UpdateLocalSafe(chain eth.ChainID, derivedFrom eth.BlockRef, lastDerived eth.BlockRef) {
......@@ -57,7 +79,7 @@ func (db *ChainsDB) UpdateLocalSafe(chain eth.ChainID, derivedFrom eth.BlockRef,
}
logger.Debug("Updating local safe DB")
if err := localDB.AddDerived(derivedFrom, lastDerived); err != nil {
db.logger.Warn("Failed to update local safe")
db.logger.Warn("Failed to update local safe", "err", err)
db.emitter.Emit(superevents.LocalSafeOutOfSyncEvent{
ChainID: chain,
L1Ref: derivedFrom,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment