Commit 9bfb4ddb authored by Mark Tyneway, committed by GitHub

Merge pull request #7525 from bobanetwork/jyellick/op-node-fix-limits-bug

op-node: RPC Limit client does not respect context
parents 1f1c908f a77ef6dc
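
The underlying bug: acquiring a slot by sending on a buffered channel (lc.sema <- struct{}{}) blocks unconditionally, so a caller whose context.Context had already been cancelled still waited for a free slot. semaphore.Weighted.Acquire takes the context directly and fails fast. A minimal standalone sketch of the difference, not taken from the commit:

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/semaphore"
)

func main() {
	// A context that has already expired, standing in for a caller that gave up.
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()
	time.Sleep(5 * time.Millisecond)

	// Channel-based semaphore (the old approach): a plain send ignores ctx.
	// The select below is the extra wiring every call site would need.
	sema := make(chan struct{}, 1)
	sema <- struct{}{} // occupy the only slot
	select {
	case sema <- struct{}{}:
		fmt.Println("unexpectedly acquired a slot")
	case <-ctx.Done():
		fmt.Println("channel semaphore gave up only via explicit select:", ctx.Err())
	}

	// Weighted semaphore (the new approach): Acquire respects ctx on its own.
	w := semaphore.NewWeighted(1)
	_ = w.Acquire(context.Background(), 1) // occupy the only slot
	if err := w.Acquire(ctx, 1); err != nil {
		fmt.Println("weighted semaphore returned:", err) // context.DeadlineExceeded
	}
}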
op-node/sources/limit.go
@@ -2,17 +2,34 @@ package sources

 import (
 	"context"
+	"net"
 	"sync"

 	"github.com/ethereum-optimism/optimism/op-node/client"
 	"github.com/ethereum/go-ethereum"
 	"github.com/ethereum/go-ethereum/rpc"
+	"golang.org/x/sync/semaphore"
 )

 type limitClient struct {
-	c    client.RPC
-	sema chan struct{}
-	wg   sync.WaitGroup
+	mutex  sync.Mutex
+	closed bool
+	c      client.RPC
+	sema   *semaphore.Weighted
+	wg     sync.WaitGroup
 }
+
+// joinWaitGroup will add the caller to the waitgroup if the client has not
+// already been told to shutdown. If the client has shut down, false is
+// returned, otherwise true.
+func (lc *limitClient) joinWaitGroup() bool {
+	lc.mutex.Lock()
+	defer lc.mutex.Unlock()
+	if lc.closed {
+		return false
+	}
+	lc.wg.Add(1)
+	return true
+}

 // LimitRPC limits concurrent RPC requests (excluding subscriptions) to a given number by wrapping the client with a semaphore.
@@ -20,33 +37,47 @@ func LimitRPC(c client.RPC, concurrentRequests int) client.RPC {
 	return &limitClient{
 		c: c,
 		// the capacity of the channel determines how many go-routines can concurrently execute requests with the wrapped client.
-		sema: make(chan struct{}, concurrentRequests),
+		sema: semaphore.NewWeighted(int64(concurrentRequests)),
 	}
 }

 func (lc *limitClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
-	lc.wg.Add(1)
+	if !lc.joinWaitGroup() {
+		return net.ErrClosed
+	}
 	defer lc.wg.Done()
-	lc.sema <- struct{}{}
-	defer func() { <-lc.sema }()
+	if err := lc.sema.Acquire(ctx, 1); err != nil {
+		return err
+	}
+	defer lc.sema.Release(1)
 	return lc.c.BatchCallContext(ctx, b)
 }

 func (lc *limitClient) CallContext(ctx context.Context, result any, method string, args ...any) error {
-	lc.wg.Add(1)
+	if !lc.joinWaitGroup() {
+		return net.ErrClosed
+	}
 	defer lc.wg.Done()
-	lc.sema <- struct{}{}
-	defer func() { <-lc.sema }()
+	if err := lc.sema.Acquire(ctx, 1); err != nil {
+		return err
+	}
+	defer lc.sema.Release(1)
 	return lc.c.CallContext(ctx, result, method, args...)
 }

 func (lc *limitClient) EthSubscribe(ctx context.Context, channel any, args ...any) (ethereum.Subscription, error) {
+	if !lc.joinWaitGroup() {
+		return nil, net.ErrClosed
+	}
+	defer lc.wg.Done()
 	// subscription doesn't count towards request limit
 	return lc.c.EthSubscribe(ctx, channel, args...)
 }

 func (lc *limitClient) Close() {
-	lc.wg.Wait()
-	close(lc.sema)
+	lc.mutex.Lock()
+	lc.closed = true // No new waitgroup members after this is set
+	lc.mutex.Unlock()
+	lc.wg.Wait() // All waitgroup members exited, means no more dereferences of the client
 	lc.c.Close()
 }
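
As a usage sketch (not from this commit; the endpoint URL and limit are made up, and client.NewBaseRPCClient is assumed as the concrete wrapper around a raw *rpc.Client):

package main

import (
	"context"
	"log"
	"time"

	"github.com/ethereum-optimism/optimism/op-node/client"
	"github.com/ethereum-optimism/optimism/op-node/sources"
	"github.com/ethereum/go-ethereum/rpc"
)

func main() {
	// Hypothetical endpoint; any JSON-RPC server would do.
	raw, err := rpc.DialContext(context.Background(), "http://localhost:8545")
	if err != nil {
		log.Fatal(err)
	}

	// Allow at most 10 concurrent requests through the wrapped client.
	limited := sources.LimitRPC(client.NewBaseRPCClient(raw), 10)
	defer limited.Close()

	// With the fix, a caller that waits past its deadline for a free slot
	// gets context.DeadlineExceeded back instead of blocking indefinitely,
	// and calls made after Close fail immediately with net.ErrClosed.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	var head string
	if err := limited.CallContext(ctx, &head, "eth_blockNumber"); err != nil {
		log.Println("call failed:", err)
	}
}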
op-node/sources/limit_test.go (new file)

package sources

import (
	"context"
	"fmt"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum-optimism/optimism/op-node/client"
	"github.com/ethereum/go-ethereum"
	"github.com/ethereum/go-ethereum/rpc"
	"github.com/stretchr/testify/require"
)

type MockRPC struct {
	t              *testing.T
	blockedCallers atomic.Int32
	errC           chan error
}

func (m *MockRPC) Close() {}

func (m *MockRPC) CallContext(ctx context.Context, result any, method string, args ...any) error {
	m.blockedCallers.Add(1)
	defer m.blockedCallers.Add(-1)
	return <-m.errC
}

func (m *MockRPC) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
	m.blockedCallers.Add(1)
	defer m.blockedCallers.Add(-1)
	return <-m.errC
}

func (m *MockRPC) EthSubscribe(ctx context.Context, channel any, args ...any) (ethereum.Subscription, error) {
	m.t.Fatal("EthSubscribe should not be called")
	return nil, nil
}

func asyncCallContext(ctx context.Context, lc client.RPC) chan error {
	errC := make(chan error)
	go func() {
		errC <- lc.CallContext(ctx, 0, "fake_method")
	}()
	return errC
}

func TestLimitClient(t *testing.T) {
	// The MockRPC will block all calls until errC is written to
	m := &MockRPC{
		t:    t,
		errC: make(chan error),
	}
	lc := LimitRPC(m, 2).(*limitClient)

	errC1 := asyncCallContext(context.Background(), lc)
	errC2 := asyncCallContext(context.Background(), lc)
	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 2 }, time.Second, 10*time.Millisecond)

	// Once the limit of 2 clients has been reached, we enqueue two more,
	// one with a context that will expire
	tCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	errC3 := asyncCallContext(tCtx, lc)
	errC4 := asyncCallContext(context.Background(), lc)

	select {
	case err := <-errC3:
		require.ErrorIs(t, err, context.DeadlineExceeded)
	case <-time.After(time.Second):
		t.Fatalf("context should have expired and the call returned")
	}

	// No further clients should be allowed after this block, but existing
	// clients should persist until their contexts close
	go lc.Close()
	require.Eventually(t, func() bool {
		lc.mutex.Lock()
		defer lc.mutex.Unlock()
		return lc.closed
	}, time.Second, 10*time.Millisecond)

	err := lc.CallContext(context.Background(), 0, "fake_method")
	require.ErrorIs(t, err, net.ErrClosed, "Calls after close should return immediately with error")

	// Existing clients should all remain blocked
	select {
	case err := <-errC1:
		t.Fatalf("client should not have returned: %s", err)
	case err := <-errC2:
		t.Fatalf("client should not have returned: %s", err)
	case err := <-errC4:
		t.Fatalf("client should not have returned: %s", err)
	case <-time.After(50 * time.Millisecond):
		// None of the clients should return yet
	}

	m.errC <- fmt.Errorf("fake-error")
	m.errC <- fmt.Errorf("fake-error")
	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 1 }, time.Second, 10*time.Millisecond)
	m.errC <- fmt.Errorf("fake-error")

	require.ErrorContains(t, <-errC1, "fake-error")
	require.ErrorContains(t, <-errC2, "fake-error")
	require.ErrorContains(t, <-errC4, "fake-error")
	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 0 }, time.Second, 10*time.Millisecond)
}
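
Assuming the monorepo layout of the time, where op-node is its own Go module, the new test should run in isolation with `go test -run TestLimitClient ./sources/` from the op-node directory.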