Commit f19edc9f authored by Matthew Slipper

op-chain-ops: Fix data race in error handling

The error channel in the OVM_ETH migration is buffered. In rare cases, a worker could write an error to this channel that the collector goroutine never read before it was cancelled. As a result, the error was never caught by the `lastErr != nil` check, and execution fell through to the total supply check below.

The bug reproduces reliably when running the new TestMigrateBalancesRandomMissing test. This test generates a random state and randomly removes an address or allowance from the witness data, repeating this 100 times per run. We didn't catch the bug with the existing random test because it doesn't exercise the error-handling codepath.
parent 7d0c54cd
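To make the failure mode concrete, here is a minimal, self-contained Go sketch of the pattern (an illustration only, not the actual doMigration code; the single failing worker and the channel capacity are assumptions made for the example). A worker reports an error on a buffered channel while a collector goroutine selects on that channel and on a cancel channel; once wg.Wait() returns and the collector is cancelled, any error still sitting in the buffer is lost unless the collector drains the channel before exiting:

package main

import (
	"errors"
	"fmt"
	"sync"
)

func main() {
	// Buffered, like errCh in doMigration: a worker can write an error and
	// exit without the collector ever having read the value.
	errCh := make(chan error, 1)
	collectorCancelCh := make(chan struct{})
	collectorClosedCh := make(chan struct{})

	var lastErr error
	var wg sync.WaitGroup

	// A single worker that fails immediately (illustrative).
	wg.Add(1)
	go func() {
		defer wg.Done()
		errCh <- errors.New("unknown storage slot")
	}()

	// Collector goroutine.
	go func() {
		defer close(collectorClosedCh)
		for {
			select {
			case err := <-errCh:
				if lastErr == nil {
					lastErr = err
				}
			case <-collectorCancelCh:
				// Without this drain, the cancel case can be selected before
				// the buffered error is ever read, and lastErr stays nil.
				for len(errCh) > 0 {
					err := <-errCh
					if lastErr == nil {
						lastErr = err
					}
				}
				return
			}
		}
	}()

	// Workers finish (possibly before the collector has read errCh), then the
	// collector is cancelled and we wait for it to exit before reading lastErr.
	wg.Wait()
	close(collectorCancelCh)
	<-collectorClosedCh

	fmt.Println("lastErr:", lastErr)
}

The explicit drain mirrors the fix in the diff below. Keeping errCh buffered lets workers report an error and exit without blocking on the collector; the trade-off is that shutdown must drain whatever remains in the buffer, otherwise the first error can be silently dropped.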
@@ -109,7 +109,7 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
 	outCh := make(chan accountData)
 	// Channel to receive errors from each iteration job.
 	errCh := make(chan error, checkJobs)
-	// Channel to cancel all iteration jobs as well as the collector.
+	// Channel to cancel all iteration jobs.
 	cancelCh := make(chan struct{})
 	// Define a worker function to iterate over each partition.
@@ -244,16 +244,17 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
 		go worker(start, end)
 	}
 
-	// Make a channel to make sure that the collector process completes.
-	collectorCloseCh := make(chan struct{})
+	// Make a channel to track when collector process completes.
+	collectorClosedCh := make(chan struct{})
+
+	// Make a channel to cancel the collector process.
+	collectorCancelCh := make(chan struct{})
 
 	// Keep track of the last error seen.
 	var lastErr error
 
-	// There are multiple ways that the cancel channel can be closed:
-	// - if we receive an error from the errCh
-	// - if the collector process completes
-	// To prevent panics, we wrap the close in a sync.Once.
+	// The cancel channel can be closed if any of the workers returns an error.
+	// We wrap the close in a sync.Once to ensure that it's only closed once.
 	var cancelOnce sync.Once
 
 	// Create a map of accounts we've seen so that we can filter out duplicates.
@@ -268,7 +269,7 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
 	progress := util.ProgressLogger(1000, "Migrated OVM_ETH storage slot")
 	go func() {
 		defer func() {
-			collectorCloseCh <- struct{}{}
+			collectorClosedCh <- struct{}{}
 		}()
 		for {
 			select {
@@ -291,10 +292,20 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
 				seenAccounts[account.address] = true
 			case err := <-errCh:
 				cancelOnce.Do(func() {
+					lastErr = err
 					close(cancelCh)
-					lastErr = err
 				})
-			case <-cancelCh:
+			case <-collectorCancelCh:
+				// Explicitly drain the error channel. Since the error channel is buffered, it's possible
+				// for the wg.Wait() call below to unblock and cancel this goroutine before the error gets
+				// processed by the case statement above.
+				for len(errCh) > 0 {
+					err := <-errCh
+					if lastErr == nil {
+						lastErr = err
+					}
+				}
 				return
 			}
 		}
@@ -302,13 +313,10 @@ func doMigration(mutableDB *state.StateDB, dbFactory DBFactory, addresses []comm
 	// Wait for the workers to finish.
 	wg.Wait()
 
-	// Close the cancel channel to signal the collector process to stop.
-	cancelOnce.Do(func() {
-		close(cancelCh)
-	})
-
-	// Wait for the collector process to finish.
-	<-collectorCloseCh
+	// Close the collector, and wait for it to finish.
+	close(collectorCancelCh)
+	<-collectorClosedCh
 
 	// If we saw an error, return it.
 	if lastErr != nil {
...
@@ -228,44 +228,58 @@ func makeLegacyETH(t *testing.T, totalSupply *big.Int, balances map[common.Addre
 	}
 }
 
-// TestMigrateBalancesRandom tests that the pre-check balances function works
+// TestMigrateBalancesRandomOK tests that the pre-check balances function works
 // with random addresses. This test makes sure that the partition logic doesn't
-// miss anything.
-func TestMigrateBalancesRandom(t *testing.T) {
+// miss anything, and helps detect concurrency errors.
+func TestMigrateBalancesRandomOK(t *testing.T) {
 	for i := 0; i < 100; i++ {
-		addresses := make([]common.Address, 0)
-		stateBalances := make(map[common.Address]*big.Int)
-		allowances := make([]*crossdomain.Allowance, 0)
-		stateAllowances := make(map[common.Address]common.Address)
-
-		totalSupply := big.NewInt(0)
-
-		for j := 0; j < rand.Intn(10000); j++ {
-			addr := randAddr(t)
-			addresses = append(addresses, addr)
-			stateBalances[addr] = big.NewInt(int64(rand.Intn(1_000_000)))
-			totalSupply = new(big.Int).Add(totalSupply, stateBalances[addr])
-		}
-
-		for j := 0; j < rand.Intn(1000); j++ {
-			addr := randAddr(t)
-			to := randAddr(t)
-			allowances = append(allowances, &crossdomain.Allowance{
-				From: addr,
-				To:   to,
-			})
-			stateAllowances[addr] = to
-		}
+		addresses, stateBalances, allowances, stateAllowances, totalSupply := setupRandTest(t)
 
 		db, factory := makeLegacyETH(t, totalSupply, stateBalances, stateAllowances)
 		err := doMigration(db, factory, addresses, allowances, big.NewInt(0), false)
 		require.NoError(t, err)
 
 		for addr, expBal := range stateBalances {
 			actBal := db.GetBalance(addr)
 			require.EqualValues(t, expBal, actBal)
 		}
 	}
 }
+
+// TestMigrateBalancesRandomMissing tests that the pre-check balances function works
+// with random addresses when some of them are missing. This helps make sure that the
+// partition logic doesn't miss anything, and helps detect concurrency errors.
+func TestMigrateBalancesRandomMissing(t *testing.T) {
+	for i := 0; i < 100; i++ {
+		addresses, stateBalances, allowances, stateAllowances, totalSupply := setupRandTest(t)
+
+		if len(addresses) == 0 {
+			continue
+		}
+
+		// Remove a random address from the list of witnesses
+		idx := rand.Intn(len(addresses))
+		addresses = append(addresses[:idx], addresses[idx+1:]...)
+
+		db, factory := makeLegacyETH(t, totalSupply, stateBalances, stateAllowances)
+		err := doMigration(db, factory, addresses, allowances, big.NewInt(0), false)
+		require.ErrorContains(t, err, "unknown storage slot")
+	}
+
+	for i := 0; i < 100; i++ {
+		addresses, stateBalances, allowances, stateAllowances, totalSupply := setupRandTest(t)
+
+		if len(allowances) == 0 {
+			continue
+		}
+
+		// Remove a random allowance from the list of witnesses
+		idx := rand.Intn(len(allowances))
+		allowances = append(allowances[:idx], allowances[idx+1:]...)
+
+		db, factory := makeLegacyETH(t, totalSupply, stateBalances, stateAllowances)
+		err := doMigration(db, factory, addresses, allowances, big.NewInt(0), false)
+		require.ErrorContains(t, err, "unknown storage slot")
+	}
+}
@@ -354,3 +368,32 @@ func randAddr(t *testing.T) common.Address {
 	require.NoError(t, err)
 	return addr
 }
+
+func setupRandTest(t *testing.T) ([]common.Address, map[common.Address]*big.Int, []*crossdomain.Allowance, map[common.Address]common.Address, *big.Int) {
+	addresses := make([]common.Address, 0)
+	stateBalances := make(map[common.Address]*big.Int)
+	allowances := make([]*crossdomain.Allowance, 0)
+	stateAllowances := make(map[common.Address]common.Address)
+
+	totalSupply := big.NewInt(0)
+
+	for j := 0; j < rand.Intn(10000); j++ {
+		addr := randAddr(t)
+		addresses = append(addresses, addr)
+		stateBalances[addr] = big.NewInt(int64(rand.Intn(1_000_000)))
+		totalSupply = new(big.Int).Add(totalSupply, stateBalances[addr])
+	}
+
+	for j := 0; j < rand.Intn(1000); j++ {
+		addr := randAddr(t)
+		to := randAddr(t)
+		allowances = append(allowances, &crossdomain.Allowance{
+			From: addr,
+			To:   to,
+		})
+		stateAllowances[addr] = to
+	}
+
+	return addresses, stateBalances, allowances, stateAllowances, totalSupply
+}
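As a usage note, both random tests can be run repeatedly under the Go race detector, e.g. go test -race -run TestMigrateBalancesRandom ./... from the containing module (exact package path not shown here); since -run matches by regular expression, this prefix selects both TestMigrateBalancesRandomOK and TestMigrateBalancesRandomMissing.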