Commit 13e82ea5 authored by mergify[bot]'s avatar mergify[bot] Committed by GitHub

Merge pull request #5211 from ethereum-optimism/bugfix/migration-error-race

op-chain-ops: Fix data race in error handling
parents 7d0c54cd f19edc9f
...@@ -109,7 +109,7 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm ...@@ -109,7 +109,7 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
outCh := make(chan accountData) outCh := make(chan accountData)
// Channel to receive errors from each iteration job. // Channel to receive errors from each iteration job.
errCh := make(chan error, checkJobs) errCh := make(chan error, checkJobs)
// Channel to cancel all iteration jobs as well as the collector. // Channel to cancel all iteration jobs.
cancelCh := make(chan struct{}) cancelCh := make(chan struct{})
// Define a worker function to iterate over each partition. // Define a worker function to iterate over each partition.
...@@ -244,16 +244,17 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm ...@@ -244,16 +244,17 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
go worker(start, end) go worker(start, end)
} }
// Make a channel to make sure that the collector process completes. // Make a channel to track when collector process completes.
collectorCloseCh := make(chan struct{}) collectorClosedCh := make(chan struct{})
// Make a channel to cancel the collector process.
collectorCancelCh := make(chan struct{})
// Keep track of the last error seen. // Keep track of the last error seen.
var lastErr error var lastErr error
// There are multiple ways that the cancel channel can be closed: // The cancel channel can be closed if any of the workers returns an error.
// - if we receive an error from the errCh // We wrap the close in a sync.Once to ensure that it's only closed once.
// - if the collector process completes
// To prevent panics, we wrap the close in a sync.Once.
var cancelOnce sync.Once var cancelOnce sync.Once
// Create a map of accounts we've seen so that we can filter out duplicates. // Create a map of accounts we've seen so that we can filter out duplicates.
...@@ -268,7 +269,7 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm ...@@ -268,7 +269,7 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
progress := util.ProgressLogger(1000, "Migrated OVM_ETH storage slot") progress := util.ProgressLogger(1000, "Migrated OVM_ETH storage slot")
go func() { go func() {
defer func() { defer func() {
collectorCloseCh <- struct{}{} collectorClosedCh <- struct{}{}
}() }()
for { for {
select { select {
...@@ -291,10 +292,20 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm ...@@ -291,10 +292,20 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
seenAccounts[account.address] = true seenAccounts[account.address] = true
case err := <-errCh: case err := <-errCh:
cancelOnce.Do(func() { cancelOnce.Do(func() {
lastErr = err
close(cancelCh) close(cancelCh)
lastErr = err
}) })
case <-cancelCh: case <-collectorCancelCh:
// Explicitly drain the error channel. Since the error channel is buffered, it's possible
// for the wg.Wait() call below to unblock and cancel this goroutine before the error gets
// processed by the case statement above.
for len(errCh) > 0 {
err := <-errCh
if lastErr == nil {
lastErr = err
}
}
return return
} }
} }
...@@ -302,13 +313,10 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm ...@@ -302,13 +313,10 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
// Wait for the workers to finish. // Wait for the workers to finish.
wg.Wait() wg.Wait()
// Close the cancel channel to signal the collector process to stop.
cancelOnce.Do(func() {
close(cancelCh)
})
// Wait for the collector process to finish. // Close the collector, and wait for it to finish.
<-collectorCloseCh close(collectorCancelCh)
<-collectorClosedCh
// If we saw an error, return it. // If we saw an error, return it.
if lastErr != nil { if lastErr != nil {
......
...@@ -228,44 +228,58 @@ func makeLegacyETH(t *testing.T, totalSupply *big.Int, balances map[common.Addre ...@@ -228,44 +228,58 @@ func makeLegacyETH(t *testing.T, totalSupply *big.Int, balances map[common.Addre
} }
} }
// TestMigrateBalancesRandom tests that the pre-check balances function works // TestMigrateBalancesRandomOK tests that the pre-check balances function works
// with random addresses. This test makes sure that the partition logic doesn't // with random addresses. This test makes sure that the partition logic doesn't
// miss anything. // miss anything, and helps detect concurrency errors.
func TestMigrateBalancesRandom(t *testing.T) { func TestMigrateBalancesRandomOK(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
addresses := make([]common.Address, 0) addresses, stateBalances, allowances, stateAllowances, totalSupply := setupRandTest(t)
stateBalances := make(map[common.Address]*big.Int)
allowances := make([]*crossdomain.Allowance, 0)
stateAllowances := make(map[common.Address]common.Address)
totalSupply := big.NewInt(0) db, factory := makeLegacyETH(t, totalSupply, stateBalances, stateAllowances)
err := doMigration(db, factory, addresses, allowances, big.NewInt(0), false)
require.NoError(t, err)
for j := 0; j < rand.Intn(10000); j++ { for addr, expBal := range stateBalances {
addr := randAddr(t) actBal := db.GetBalance(addr)
addresses = append(addresses, addr) require.EqualValues(t, expBal, actBal)
stateBalances[addr] = big.NewInt(int64(rand.Intn(1_000_000))) }
totalSupply = new(big.Int).Add(totalSupply, stateBalances[addr])
} }
}
for j := 0; j < rand.Intn(1000); j++ { // TestMigrateBalancesRandomMissing tests that the pre-check balances function works
addr := randAddr(t) // with random addresses when some of them are missing. This helps make sure that the
to := randAddr(t) // partition logic doesn't miss anything, and helps detect concurrency errors.
allowances = append(allowances, &crossdomain.Allowance{ func TestMigrateBalancesRandomMissing(t *testing.T) {
From: addr, for i := 0; i < 100; i++ {
To: to, addresses, stateBalances, allowances, stateAllowances, totalSupply := setupRandTest(t)
})
stateAllowances[addr] = to if len(addresses) == 0 {
continue
} }
// Remove a random address from the list of witnesses
idx := rand.Intn(len(addresses))
addresses = append(addresses[:idx], addresses[idx+1:]...)
db, factory := makeLegacyETH(t, totalSupply, stateBalances, stateAllowances) db, factory := makeLegacyETH(t, totalSupply, stateBalances, stateAllowances)
err := doMigration(db, factory, addresses, allowances, big.NewInt(0), false) err := doMigration(db, factory, addresses, allowances, big.NewInt(0), false)
require.NoError(t, err) require.ErrorContains(t, err, "unknown storage slot")
}
for addr, expBal := range stateBalances { for i := 0; i < 100; i++ {
actBal := db.GetBalance(addr) addresses, stateBalances, allowances, stateAllowances, totalSupply := setupRandTest(t)
require.EqualValues(t, expBal, actBal)
if len(allowances) == 0 {
continue
} }
// Remove a random allowance from the list of witnesses
idx := rand.Intn(len(allowances))
allowances = append(allowances[:idx], allowances[idx+1:]...)
db, factory := makeLegacyETH(t, totalSupply, stateBalances, stateAllowances)
err := doMigration(db, factory, addresses, allowances, big.NewInt(0), false)
require.ErrorContains(t, err, "unknown storage slot")
} }
} }
...@@ -354,3 +368,32 @@ func randAddr(t *testing.T) common.Address { ...@@ -354,3 +368,32 @@ func randAddr(t *testing.T) common.Address {
require.NoError(t, err) require.NoError(t, err)
return addr return addr
} }
func setupRandTest(t *testing.T) ([]common.Address, map[common.Address]*big.Int, []*crossdomain.Allowance, map[common.Address]common.Address, *big.Int) {
addresses := make([]common.Address, 0)
stateBalances := make(map[common.Address]*big.Int)
allowances := make([]*crossdomain.Allowance, 0)
stateAllowances := make(map[common.Address]common.Address)
totalSupply := big.NewInt(0)
for j := 0; j < rand.Intn(10000); j++ {
addr := randAddr(t)
addresses = append(addresses, addr)
stateBalances[addr] = big.NewInt(int64(rand.Intn(1_000_000)))
totalSupply = new(big.Int).Add(totalSupply, stateBalances[addr])
}
for j := 0; j < rand.Intn(1000); j++ {
addr := randAddr(t)
to := randAddr(t)
allowances = append(allowances, &crossdomain.Allowance{
From: addr,
To: to,
})
stateAllowances[addr] = to
}
return addresses, stateBalances, allowances, stateAllowances, totalSupply
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment