Commit bce67b2e authored by Adrian Sutton's avatar Adrian Sutton

challenger: Fix thread safety in op-challenger tests.

parent a298160f
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"math/big" "math/big"
"sync"
"testing" "testing"
"time" "time"
...@@ -60,15 +61,16 @@ func TestMonitorGames(t *testing.T) { ...@@ -60,15 +61,16 @@ func TestMonitorGames(t *testing.T) {
go func() { go func() {
headerNotSent := true headerNotSent := true
for { for {
if len(sched.scheduled) >= 1 { if len(sched.Scheduled()) >= 1 {
break break
} }
if mockHeadSource.sub == nil { sub := mockHeadSource.Sub()
if sub == nil {
continue continue
} }
if headerNotSent { if headerNotSent {
select { select {
case mockHeadSource.sub.headers <- &ethtypes.Header{ case sub.headers <- &ethtypes.Header{
Number: big.NewInt(1), Number: big.NewInt(1),
}: }:
headerNotSent = false headerNotSent = false
...@@ -80,15 +82,15 @@ func TestMonitorGames(t *testing.T) { ...@@ -80,15 +82,15 @@ func TestMonitorGames(t *testing.T) {
// Just to avoid a tight loop // Just to avoid a tight loop
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
mockHeadSource.err = fmt.Errorf("eth subscribe test error") mockHeadSource.SetErr(fmt.Errorf("eth subscribe test error"))
cancel() cancel()
}() }()
monitor.StartMonitoring() monitor.StartMonitoring()
<-ctx.Done() <-ctx.Done()
monitor.StopMonitoring() monitor.StopMonitoring()
require.Len(t, sched.scheduled, 1) require.Len(t, sched.Scheduled(), 1)
require.Equal(t, []common.Address{addr1, addr2}, sched.scheduled[0]) require.Equal(t, []common.Address{addr1, addr2}, sched.Scheduled()[0])
}) })
t.Run("Resubscribes on error", func(t *testing.T) { t.Run("Resubscribes on error", func(t *testing.T) {
...@@ -103,19 +105,20 @@ func TestMonitorGames(t *testing.T) { ...@@ -103,19 +105,20 @@ func TestMonitorGames(t *testing.T) {
go func() { go func() {
// Wait for the subscription to be created // Wait for the subscription to be created
waitErr := wait.For(context.Background(), 5*time.Second, func() (bool, error) { waitErr := wait.For(context.Background(), 5*time.Second, func() (bool, error) {
return mockHeadSource.sub != nil, nil return mockHeadSource.Sub() != nil, nil
}) })
require.NoError(t, waitErr) require.NoError(t, waitErr)
mockHeadSource.sub.errChan <- fmt.Errorf("test error") mockHeadSource.Sub().errChan <- fmt.Errorf("test error")
for { for {
if len(sched.scheduled) >= 1 { if len(sched.Scheduled()) >= 1 {
break break
} }
if mockHeadSource.sub == nil { sub := mockHeadSource.Sub()
if sub == nil {
continue continue
} }
select { select {
case mockHeadSource.sub.headers <- &ethtypes.Header{ case sub.headers <- &ethtypes.Header{
Number: big.NewInt(1), Number: big.NewInt(1),
}: }:
case <-ctx.Done(): case <-ctx.Done():
...@@ -126,15 +129,15 @@ func TestMonitorGames(t *testing.T) { ...@@ -126,15 +129,15 @@ func TestMonitorGames(t *testing.T) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
require.NoError(t, waitErr) require.NoError(t, waitErr)
mockHeadSource.err = fmt.Errorf("eth subscribe test error") mockHeadSource.SetErr(fmt.Errorf("eth subscribe test error"))
cancel() cancel()
}() }()
monitor.StartMonitoring() monitor.StartMonitoring()
<-ctx.Done() <-ctx.Done()
monitor.StopMonitoring() monitor.StopMonitoring()
require.NotEmpty(t, sched.scheduled) // We might get more than one update scheduled. require.NotEmpty(t, sched.Scheduled()) // We might get more than one update scheduled.
require.Equal(t, []common.Address{addr1, addr2}, sched.scheduled[0]) require.Equal(t, []common.Address{addr1, addr2}, sched.Scheduled()[0])
}) })
} }
...@@ -147,8 +150,8 @@ func TestMonitorCreateAndProgressGameAgents(t *testing.T) { ...@@ -147,8 +150,8 @@ func TestMonitorCreateAndProgressGameAgents(t *testing.T) {
require.NoError(t, monitor.progressGames(context.Background(), common.Hash{0x01})) require.NoError(t, monitor.progressGames(context.Background(), common.Hash{0x01}))
require.Len(t, sched.scheduled, 1) require.Len(t, sched.Scheduled(), 1)
require.Equal(t, []common.Address{addr1, addr2}, sched.scheduled[0]) require.Equal(t, []common.Address{addr1, addr2}, sched.Scheduled()[0])
} }
func TestMonitorOnlyScheduleSpecifiedGame(t *testing.T) { func TestMonitorOnlyScheduleSpecifiedGame(t *testing.T) {
...@@ -159,8 +162,8 @@ func TestMonitorOnlyScheduleSpecifiedGame(t *testing.T) { ...@@ -159,8 +162,8 @@ func TestMonitorOnlyScheduleSpecifiedGame(t *testing.T) {
require.NoError(t, monitor.progressGames(context.Background(), common.Hash{0x01})) require.NoError(t, monitor.progressGames(context.Background(), common.Hash{0x01}))
require.Len(t, sched.scheduled, 1) require.Len(t, sched.Scheduled(), 1)
require.Equal(t, []common.Address{addr2}, sched.scheduled[0]) require.Equal(t, []common.Address{addr2}, sched.Scheduled()[0])
} }
func newFDG(proxy common.Address, timestamp uint64) types.GameMetadata { func newFDG(proxy common.Address, timestamp uint64) types.GameMetadata {
...@@ -197,15 +200,36 @@ func setupMonitorTest( ...@@ -197,15 +200,36 @@ func setupMonitorTest(
} }
type mockNewHeadSource struct { type mockNewHeadSource struct {
sync.Mutex
sub *mockSubscription sub *mockSubscription
err error err error
} }
func (m *mockNewHeadSource) Sub() *mockSubscription {
m.Lock()
defer m.Unlock()
return m.sub
}
func (m *mockNewHeadSource) SetSub(sub *mockSubscription) {
m.Lock()
defer m.Unlock()
m.sub = sub
}
func (m *mockNewHeadSource) SetErr(err error) {
m.Lock()
defer m.Unlock()
m.err = err
}
func (m *mockNewHeadSource) EthSubscribe( func (m *mockNewHeadSource) EthSubscribe(
ctx context.Context, _ context.Context,
ch any, ch any,
args ...any, _ ...any,
) (ethereum.Subscription, error) { ) (ethereum.Subscription, error) {
m.Lock()
defer m.Unlock()
errChan := make(chan error) errChan := make(chan error)
m.sub = &mockSubscription{errChan, (ch).(chan<- *ethtypes.Header)} m.sub = &mockSubscription{errChan, (ch).(chan<- *ethtypes.Header)}
if m.err != nil { if m.err != nil {
...@@ -230,18 +254,26 @@ type stubGameSource struct { ...@@ -230,18 +254,26 @@ type stubGameSource struct {
} }
func (s *stubGameSource) FetchAllGamesAtBlock( func (s *stubGameSource) FetchAllGamesAtBlock(
ctx context.Context, _ context.Context,
earliest uint64, _ uint64,
blockHash common.Hash, _ common.Hash,
) ([]types.GameMetadata, error) { ) ([]types.GameMetadata, error) {
return s.games, nil return s.games, nil
} }
type stubScheduler struct { type stubScheduler struct {
sync.Mutex
scheduled [][]common.Address scheduled [][]common.Address
} }
func (s *stubScheduler) Scheduled() [][]common.Address {
s.Lock()
defer s.Unlock()
return s.scheduled
}
func (s *stubScheduler) Schedule(games []types.GameMetadata) error { func (s *stubScheduler) Schedule(games []types.GameMetadata) error {
s.Lock()
defer s.Unlock()
var addrs []common.Address var addrs []common.Address
for _, game := range games { for _, game := range games {
addrs = append(addrs, game.Proxy) addrs = append(addrs, game.Proxy)
......
...@@ -3,6 +3,7 @@ package scheduler ...@@ -3,6 +3,7 @@ package scheduler
import ( import (
"context" "context"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
...@@ -28,21 +29,21 @@ func TestWorkerShouldProcessJobsUntilContextDone(t *testing.T) { ...@@ -28,21 +29,21 @@ func TestWorkerShouldProcessJobsUntilContextDone(t *testing.T) {
player: &test.StubGamePlayer{StatusValue: types.GameStatusInProgress}, player: &test.StubGamePlayer{StatusValue: types.GameStatusInProgress},
} }
waitErr := wait.For(context.Background(), 100*time.Millisecond, func() (bool, error) { waitErr := wait.For(context.Background(), 100*time.Millisecond, func() (bool, error) {
return ms.activeCalls >= 1, nil return ms.activeCalls.Load() >= 1, nil
}) })
require.NoError(t, waitErr) require.NoError(t, waitErr)
require.Equal(t, ms.activeCalls, 1) require.EqualValues(t, ms.activeCalls.Load(), 1)
require.Equal(t, ms.idleCalls, 1) require.EqualValues(t, ms.idleCalls.Load(), 1)
in <- job{ in <- job{
player: &test.StubGamePlayer{StatusValue: types.GameStatusDefenderWon}, player: &test.StubGamePlayer{StatusValue: types.GameStatusDefenderWon},
} }
waitErr = wait.For(context.Background(), 100*time.Millisecond, func() (bool, error) { waitErr = wait.For(context.Background(), 100*time.Millisecond, func() (bool, error) {
return ms.activeCalls >= 2, nil return ms.activeCalls.Load() >= 2, nil
}) })
require.NoError(t, waitErr) require.NoError(t, waitErr)
require.Equal(t, ms.activeCalls, 2) require.EqualValues(t, ms.activeCalls.Load(), 2)
require.Equal(t, ms.idleCalls, 2) require.EqualValues(t, ms.idleCalls.Load(), 2)
result1 := readWithTimeout(t, out) result1 := readWithTimeout(t, out)
result2 := readWithTimeout(t, out) result2 := readWithTimeout(t, out)
...@@ -56,16 +57,16 @@ func TestWorkerShouldProcessJobsUntilContextDone(t *testing.T) { ...@@ -56,16 +57,16 @@ func TestWorkerShouldProcessJobsUntilContextDone(t *testing.T) {
} }
type metricSink struct { type metricSink struct {
activeCalls int activeCalls atomic.Int32
idleCalls int idleCalls atomic.Int32
} }
func (m *metricSink) ThreadActive() { func (m *metricSink) ThreadActive() {
m.activeCalls++ m.activeCalls.Add(1)
} }
func (m *metricSink) ThreadIdle() { func (m *metricSink) ThreadIdle() {
m.idleCalls++ m.idleCalls.Add(1)
} }
func readWithTimeout[T any](t *testing.T, ch <-chan T) T { func readWithTimeout[T any](t *testing.T, ch <-chan T) T {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment