package sources

import (
	"context"
	"errors"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum-optimism/optimism/op-service/client"
	"github.com/ethereum/go-ethereum"
	"github.com/ethereum/go-ethereum/rpc"
	"github.com/stretchr/testify/require"
)

// MockRPC is a test double whose CallContext and BatchCallContext block until
// a value is delivered on errC, and whose EthSubscribe fails the test.
type MockRPC struct {
	t              *testing.T   // used to fail the test if EthSubscribe is called
	blockedCallers atomic.Int32 // number of calls currently inside CallContext/BatchCallContext
	errC           chan error   // each value received here is returned from one blocked call
}

// Close does nothing.
func (m *MockRPC) Close() {
	// Intentionally empty.
}

// CallContext counts the caller as blocked, waits for a value on errC and
// returns that value. ctx, result, method and args are all ignored, so the
// call only returns once something is sent on (or closes) errC.
func (m *MockRPC) CallContext(ctx context.Context, result any, method string, args ...any) error {
	m.blockedCallers.Add(1)
	err := <-m.errC
	// The caller is no longer blocked once a value has been received.
	m.blockedCallers.Add(-1)
	return err
}

// BatchCallContext counts the caller as blocked, waits for a value on errC
// and returns that value. ctx and the batch elements in b are ignored.
func (m *MockRPC) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
	m.blockedCallers.Add(1)
	err := <-m.errC
	// The caller is no longer blocked once a value has been received.
	m.blockedCallers.Add(-1)
	return err
}

// EthSubscribe must never be reached by these tests. It marks the test as
// failed and returns a nil subscription together with a non-nil error.
//
// The failure is reported with t.Error rather than t.Fatal: FailNow may only
// be called from the goroutine running the test, while this mock's methods
// are invoked from other goroutines (see asyncCallContext).
func (m *MockRPC) EthSubscribe(ctx context.Context, channel any, args ...any) (ethereum.Subscription, error) {
	m.t.Error("EthSubscribe should not be called")
	return nil, errors.New("EthSubscribe should not be called")
}

// asyncCallContext invokes lc.CallContext for "fake_method" on a new
// goroutine and returns an unbuffered channel on which that call's error
// is delivered once the call returns.
func asyncCallContext(ctx context.Context, lc client.RPC) chan error {
	done := make(chan error)
	go func(out chan<- error) {
		callErr := lc.CallContext(ctx, 0, "fake_method")
		out <- callErr
	}(done)
	return done
}

func TestLimitClient(t *testing.T) {
	// The MockRPC will block all calls until errC is written to
	m := &MockRPC{
		t:    t,
		errC: make(chan error),
	}
	lc := LimitRPC(m, 2).(*limitClient)

	errC1 := asyncCallContext(context.Background(), lc)
	errC2 := asyncCallContext(context.Background(), lc)
	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 2 }, time.Second, 10*time.Millisecond)

	// Once the limit of 2 clients has been reached, we enqueue two more,
	// one with a context that will expire
	tCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	errC3 := asyncCallContext(tCtx, lc)
	errC4 := asyncCallContext(context.Background(), lc)

	select {
	case err := <-errC3:
		require.ErrorIs(t, err, context.DeadlineExceeded)
	case <-time.After(time.Second):
		t.Fatalf("context should have expired and the call returned")
	}

	// No further clients should be allowed after this block, but existing
	// clients should persist until their contexts close
	go lc.Close()
	require.Eventually(t, func() bool {
		lc.mutex.Lock()
		defer lc.mutex.Unlock()
		return lc.closed
	}, time.Second, 10*time.Millisecond)

	err := lc.CallContext(context.Background(), 0, "fake_method")
	require.ErrorIs(t, err, net.ErrClosed, "Calls after close should return immediately with error")

	// Existing clients should all remain blocked
	select {
	case err := <-errC1:
		t.Fatalf("client should not have returned: %s", err)
	case err := <-errC2:
		t.Fatalf("client should not have returned: %s", err)
	case err := <-errC4:
		t.Fatalf("client should not have returned: %s", err)
	case <-time.After(50 * time.Millisecond):
		// None of the clients should return yet
	}

100 101
	m.errC <- errors.New("fake-error")
	m.errC <- errors.New("fake-error")
102
	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 1 }, time.Second, 10*time.Millisecond)
103
	m.errC <- errors.New("fake-error")
104 105 106 107 108 109 110

	require.ErrorContains(t, <-errC1, "fake-error")
	require.ErrorContains(t, <-errC2, "fake-error")
	require.ErrorContains(t, <-errC4, "fake-error")

	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 0 }, time.Second, 10*time.Millisecond)
}