Commit 95801790 authored by Axel Kingsley's avatar Axel Kingsley Committed by GitHub

Refactor Derivation DB Reset Functions (#13874)

parent 57c7d168
...@@ -209,6 +209,7 @@ const ( ...@@ -209,6 +209,7 @@ const (
func (m *ManagedMode) Reset(ctx context.Context, unsafe, safe, finalized eth.BlockID) error { func (m *ManagedMode) Reset(ctx context.Context, unsafe, safe, finalized eth.BlockID) error {
logger := m.log.New("unsafe", unsafe, "safe", safe, "finalized", finalized) logger := m.log.New("unsafe", unsafe, "safe", safe, "finalized", finalized)
logger.Debug("Received reset request", "unsafe", unsafe, "safe", safe, "finalized", finalized)
verify := func(ref eth.BlockID, name string) (eth.L2BlockRef, error) { verify := func(ref eth.BlockID, name string) (eth.L2BlockRef, error) {
result, err := m.l2.L2BlockRefByNumber(ctx, ref.Number) result, err := m.l2.L2BlockRefByNumber(ctx, ref.Number)
......
...@@ -567,3 +567,8 @@ func (su *SupervisorBackend) PullFinalizedL1() error { ...@@ -567,3 +567,8 @@ func (su *SupervisorBackend) PullFinalizedL1() error {
func (su *SupervisorBackend) SetConfDepthL1(depth uint64) { func (su *SupervisorBackend) SetConfDepthL1(depth uint64) {
su.l1Accessor.SetConfDepth(depth) su.l1Accessor.SetConfDepth(depth)
} }
// Rewind rolls back the state of the supervisor for the given chain.
func (su *SupervisorBackend) Rewind(chain eth.ChainID, block eth.BlockID) error {
return su.chainDBs.Rewind(chain, block)
}
...@@ -61,6 +61,7 @@ type LocalDerivedFromStorage interface { ...@@ -61,6 +61,7 @@ type LocalDerivedFromStorage interface {
NextDerived(derived eth.BlockID) (next types.DerivedBlockSealPair, err error) NextDerived(derived eth.BlockID) (next types.DerivedBlockSealPair, err error)
PreviousDerivedFrom(derivedFrom eth.BlockID) (prevDerivedFrom types.BlockSeal, err error) PreviousDerivedFrom(derivedFrom eth.BlockID) (prevDerivedFrom types.BlockSeal, err error)
PreviousDerived(derived eth.BlockID) (prevDerived types.BlockSeal, err error) PreviousDerived(derived eth.BlockID) (prevDerived types.BlockSeal, err error)
RewindToL2(derived uint64) error
} }
var _ LocalDerivedFromStorage = (*fromda.DB)(nil) var _ LocalDerivedFromStorage = (*fromda.DB)(nil)
......
...@@ -51,22 +51,6 @@ func NewFromEntryStore(logger log.Logger, m Metrics, store EntryStore) (*DB, err ...@@ -51,22 +51,6 @@ func NewFromEntryStore(logger log.Logger, m Metrics, store EntryStore) (*DB, err
return db, nil return db, nil
} }
// Rewind to the last entry that was derived from a L1 block with the given block number.
func (db *DB) Rewind(derivedFrom uint64) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
index, _, err := db.lastDerivedAt(derivedFrom)
if err != nil {
return fmt.Errorf("failed to find point to rewind to: %w", err)
}
err = db.store.Truncate(index)
if err != nil {
return err
}
db.m.RecordDBDerivedEntryCount(int64(index) + 1)
return nil
}
// First returns the first known values, alike to Latest. // First returns the first known values, alike to Latest.
func (db *DB) First() (pair types.DerivedBlockSealPair, err error) { func (db *DB) First() (pair types.DerivedBlockSealPair, err error) {
db.rwLock.RLock() db.rwLock.RLock()
......
...@@ -679,17 +679,17 @@ func TestRewind(t *testing.T) { ...@@ -679,17 +679,17 @@ func TestRewind(t *testing.T) {
require.Equal(t, l2Block2, pair.Derived) require.Equal(t, l2Block2, pair.Derived)
// Rewind to the future // Rewind to the future
require.ErrorIs(t, db.Rewind(6), types.ErrFuture) require.ErrorIs(t, db.RewindToL1(6), types.ErrFuture)
// Rewind to the exact block we're at // Rewind to the exact block we're at
require.NoError(t, db.Rewind(l1Block5.Number)) require.NoError(t, db.RewindToL1(l1Block5.Number))
pair, err = db.Latest() pair, err = db.Latest()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, l1Block5, pair.DerivedFrom) require.Equal(t, l1Block5, pair.DerivedFrom)
require.Equal(t, l2Block2, pair.Derived) require.Equal(t, l2Block2, pair.Derived)
// Now rewind to L1 block 3 (inclusive). // Now rewind to L1 block 3 (inclusive).
require.NoError(t, db.Rewind(l1Block3.Number)) require.NoError(t, db.RewindToL1(l1Block3.Number))
// See if we find consistent data // See if we find consistent data
pair, err = db.Latest() pair, err = db.Latest()
...@@ -698,14 +698,14 @@ func TestRewind(t *testing.T) { ...@@ -698,14 +698,14 @@ func TestRewind(t *testing.T) {
require.Equal(t, l2Block1, pair.Derived) require.Equal(t, l2Block1, pair.Derived)
// Rewind further to L1 block 1 (inclusive). // Rewind further to L1 block 1 (inclusive).
require.NoError(t, db.Rewind(l1Block1.Number)) require.NoError(t, db.RewindToL1(l1Block1.Number))
pair, err = db.Latest() pair, err = db.Latest()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, l1Block1, pair.DerivedFrom) require.Equal(t, l1Block1, pair.DerivedFrom)
require.Equal(t, l2Block1, pair.Derived) require.Equal(t, l2Block1, pair.Derived)
// Rewind further to L1 block 0 (inclusive). // Rewind further to L1 block 0 (inclusive).
require.NoError(t, db.Rewind(l1Block0.Number)) require.NoError(t, db.RewindToL1(l1Block0.Number))
pair, err = db.Latest() pair, err = db.Latest()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, l1Block0, pair.DerivedFrom) require.Equal(t, l1Block0, pair.DerivedFrom)
......
...@@ -66,25 +66,83 @@ func (db *DB) ReplaceInvalidatedBlock(replacementDerived eth.BlockRef, invalidat ...@@ -66,25 +66,83 @@ func (db *DB) ReplaceInvalidatedBlock(replacementDerived eth.BlockRef, invalidat
func (db *DB) RewindAndInvalidate(invalidated types.DerivedBlockRefPair) error { func (db *DB) RewindAndInvalidate(invalidated types.DerivedBlockRefPair) error {
db.rwLock.Lock() db.rwLock.Lock()
defer db.rwLock.Unlock() defer db.rwLock.Unlock()
i, link, err := db.lookup(invalidated.DerivedFrom.Number, invalidated.Derived.Number)
invalidatedSeals := types.DerivedBlockSealPair{
DerivedFrom: types.BlockSealFromRef(invalidated.DerivedFrom),
Derived: types.BlockSealFromRef(invalidated.Derived),
}
if err := db.rewindLocked(invalidatedSeals, true); err != nil {
return err
}
if err := db.addLink(invalidated.DerivedFrom, invalidated.Derived, invalidated.Derived.Hash); err != nil {
return fmt.Errorf("failed to add invalidation entry %s: %w", invalidated, err)
}
return nil
}
// Rewind rolls back the database to the target, including the target if the including flag is set.
// it locks the DB and calls rewindLocked.
func (db *DB) Rewind(target types.DerivedBlockSealPair, including bool) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
return db.rewindLocked(target, including)
}
// RewindToL2 rewinds to the first entry where the L2 block with the given number was derived.
func (db *DB) RewindToL2(derived uint64) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
_, link, err := db.firstDerivedFrom(derived)
if err != nil {
return fmt.Errorf("failed to find last derived-from %d: %w", derived, err)
}
return db.rewindLocked(types.DerivedBlockSealPair{
DerivedFrom: link.derivedFrom,
Derived: link.derived,
}, false)
}
// RewindToL1 rewinds to the last entry that was derived from a L1 block with the given block number.
func (db *DB) RewindToL1(derivedFrom uint64) error {
db.rwLock.Lock()
defer db.rwLock.Unlock()
_, link, err := db.lastDerivedAt(derivedFrom)
if err != nil {
return fmt.Errorf("failed to find last derived-from %d: %w", derivedFrom, err)
}
return db.rewindLocked(types.DerivedBlockSealPair{
DerivedFrom: link.derivedFrom,
Derived: link.derived,
}, false)
}
// rewindLocked performs the truncate operation to a specified block seal pair.
// data beyond the specified block seal pair is truncated from the database.
// if including is true, the block seal pair itself is removed as well.
// Note: This function must be called with the rwLock held.
// Callers are responsible for locking and unlocking the Database.
func (db *DB) rewindLocked(t types.DerivedBlockSealPair, including bool) error {
i, link, err := db.lookup(t.DerivedFrom.Number, t.Derived.Number)
if err != nil { if err != nil {
return err return err
} }
if link.derivedFrom.Hash != invalidated.DerivedFrom.Hash { if link.derivedFrom.Hash != t.DerivedFrom.Hash {
return fmt.Errorf("found derived-from %s, but expected %s: %w", return fmt.Errorf("found derived-from %s, but expected %s: %w",
link.derivedFrom, invalidated.DerivedFrom, types.ErrConflict) link.derivedFrom, t.DerivedFrom, types.ErrConflict)
} }
if link.derived.Hash != invalidated.Derived.Hash { if link.derived.Hash != t.Derived.Hash {
return fmt.Errorf("found derived %s, but expected %s: %w", return fmt.Errorf("found derived %s, but expected %s: %w",
link.derived, invalidated.Derived, types.ErrConflict) link.derived, t.Derived, types.ErrConflict)
} }
if err := db.store.Truncate(i - 1); err != nil { // adjust the target index to include the block seal pair itself if requested
return fmt.Errorf("failed to rewind upon block invalidation of %s: %w", invalidated, err) target := i
if including {
target = i - 1
} }
db.m.RecordDBDerivedEntryCount(int64(i)) if err := db.store.Truncate(target); err != nil {
if err := db.addLink(invalidated.DerivedFrom, invalidated.Derived, invalidated.Derived.Hash); err != nil { return fmt.Errorf("failed to rewind upon block invalidation of %s: %w", t, err)
return fmt.Errorf("failed to add invalidation entry %s: %w", invalidated, err)
} }
db.m.RecordDBDerivedEntryCount(int64(target) + 1)
return nil return nil
} }
......
...@@ -41,11 +41,33 @@ func (db *ChainsDB) SealBlock(chain eth.ChainID, block eth.BlockRef) error { ...@@ -41,11 +41,33 @@ func (db *ChainsDB) SealBlock(chain eth.ChainID, block eth.BlockRef) error {
} }
func (db *ChainsDB) Rewind(chain eth.ChainID, headBlock eth.BlockID) error { func (db *ChainsDB) Rewind(chain eth.ChainID, headBlock eth.BlockID) error {
// Rewind the logDB
logDB, ok := db.logDBs.Get(chain) logDB, ok := db.logDBs.Get(chain)
if !ok { if !ok {
return fmt.Errorf("cannot Rewind: %w: %s", types.ErrUnknownChain, chain) return fmt.Errorf("cannot Rewind: %w: %s", types.ErrUnknownChain, chain)
} }
return logDB.Rewind(headBlock) if err := logDB.Rewind(headBlock); err != nil {
return fmt.Errorf("failed to rewind to block %v: %w", headBlock, err)
}
// Rewind the localDB
localDB, ok := db.localDBs.Get(chain)
if !ok {
return fmt.Errorf("cannot Rewind (localDB not found): %w: %s", types.ErrUnknownChain, chain)
}
if err := localDB.RewindToL2(headBlock.Number); err != nil {
return fmt.Errorf("failed to rewind localDB to block %v: %w", headBlock, err)
}
// Rewind the crossDB
crossDB, ok := db.crossDBs.Get(chain)
if !ok {
return fmt.Errorf("cannot Rewind (crossDB not found): %w: %s", types.ErrUnknownChain, chain)
}
if err := crossDB.RewindToL2(headBlock.Number); err != nil {
return fmt.Errorf("failed to rewind crossDB to block %v: %w", headBlock, err)
}
return nil
} }
func (db *ChainsDB) UpdateLocalSafe(chain eth.ChainID, derivedFrom eth.BlockRef, lastDerived eth.BlockRef) { func (db *ChainsDB) UpdateLocalSafe(chain eth.ChainID, derivedFrom eth.BlockRef, lastDerived eth.BlockRef) {
...@@ -57,7 +79,7 @@ func (db *ChainsDB) UpdateLocalSafe(chain eth.ChainID, derivedFrom eth.BlockRef, ...@@ -57,7 +79,7 @@ func (db *ChainsDB) UpdateLocalSafe(chain eth.ChainID, derivedFrom eth.BlockRef,
} }
logger.Debug("Updating local safe DB") logger.Debug("Updating local safe DB")
if err := localDB.AddDerived(derivedFrom, lastDerived); err != nil { if err := localDB.AddDerived(derivedFrom, lastDerived); err != nil {
db.logger.Warn("Failed to update local safe") db.logger.Warn("Failed to update local safe", "err", err)
db.emitter.Emit(superevents.LocalSafeOutOfSyncEvent{ db.emitter.Emit(superevents.LocalSafeOutOfSyncEvent{
ChainID: chain, ChainID: chain,
L1Ref: derivedFrom, L1Ref: derivedFrom,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment