Commit a77ef6dc authored by Jason Yellick

op-node: RPC Limit client does not respect context

This change addresses two bugs.

1. The existing limit.go implements its own semaphore which ignores the
   passed-in context. This means the semaphore can block indefinitely,
   even when the context of the request has already expired. (A minimal
   illustration of both failure modes follows this list.)

2. The existing implementation does not guard against clients which
   invoke RPC methods after the channel underlying the custom semaphore
   has been closed. This leads to panics when the closed channel is
   written to during shutdown, and causes test flakiness, most evident
   in the op-e2e-http-tests suite.
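
To make the two failure modes concrete, here is a minimal, self-contained
sketch. It is not part of the commit; the variable names and the 50ms/200ms
timings are illustrative. It contrasts a bare channel used as a semaphore
(the old approach) with golang.org/x/sync/semaphore.Weighted (the new one),
and shows the send-on-closed-channel panic the old Close() path allowed.

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/sync/semaphore"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	// Bug 1, old behavior: a bare channel send cannot observe ctx, so a
	// caller whose context has expired still waits for a slot to free up.
	sema := make(chan struct{}, 1)
	sema <- struct{}{} // occupy the only slot
	acquired := make(chan struct{})
	go func() {
		sema <- struct{}{} // blocks until a slot frees; ctx is irrelevant here
		close(acquired)
	}()
	select {
	case <-acquired:
		fmt.Println("acquired (does not happen here)")
	case <-time.After(200 * time.Millisecond):
		fmt.Println("still blocked long after the 50ms context expired")
	}

	// New behavior: semaphore.Weighted threads the context through, so
	// Acquire returns ctx.Err() once the deadline passes.
	weighted := semaphore.NewWeighted(1)
	_ = weighted.Acquire(context.Background(), 1) // occupy the only slot
	if err := weighted.Acquire(ctx, 1); err != nil {
		fmt.Println("weighted Acquire returned:", err) // context.DeadlineExceeded
	}

	// Bug 2, old behavior: Close() ran close(sema), so a late caller's send
	// panicked with "send on closed channel"; the new closed flag plus
	// waitgroup rejects late callers with net.ErrClosed instead.
	func() {
		defer func() { fmt.Println("recovered:", recover()) }()
		closedSema := make(chan struct{}, 1)
		close(closedSema)
		closedSema <- struct{}{} // panics: send on closed channel
	}()
}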

Along with these fixes comes a test that attempts to demonstrate the
previous bad behavior. Because these bugs are inherently tied to the
interaction of multiple goroutines, the test ends up being a bit
complex, but it is well commented and hopefully remains readable.
parent 7c49bb3f
limit.go

@@ -2,17 +2,34 @@ package sources

 import (
 	"context"
+	"net"
 	"sync"

 	"github.com/ethereum-optimism/optimism/op-node/client"
 	"github.com/ethereum/go-ethereum"
 	"github.com/ethereum/go-ethereum/rpc"
+	"golang.org/x/sync/semaphore"
 )

 type limitClient struct {
-	c    client.RPC
-	sema chan struct{}
-	wg   sync.WaitGroup
+	mutex  sync.Mutex
+	closed bool
+	c      client.RPC
+	sema   *semaphore.Weighted
+	wg     sync.WaitGroup
 }

+// joinWaitGroup will add the caller to the waitgroup if the client has not
+// already been told to shutdown. If the client has shut down, false is
+// returned, otherwise true.
+func (lc *limitClient) joinWaitGroup() bool {
+	lc.mutex.Lock()
+	defer lc.mutex.Unlock()
+	if lc.closed {
+		return false
+	}
+	lc.wg.Add(1)
+	return true
+}
+
 // LimitRPC limits concurrent RPC requests (excluding subscriptions) to a given number by wrapping the client with a semaphore.
@@ -20,33 +37,47 @@ func LimitRPC(c client.RPC, concurrentRequests int) client.RPC {
 	return &limitClient{
 		c: c,
 		// the capacity of the channel determines how many go-routines can concurrently execute requests with the wrapped client.
-		sema: make(chan struct{}, concurrentRequests),
+		sema: semaphore.NewWeighted(int64(concurrentRequests)),
 	}
 }

 func (lc *limitClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
-	lc.wg.Add(1)
+	if !lc.joinWaitGroup() {
+		return net.ErrClosed
+	}
 	defer lc.wg.Done()
-	lc.sema <- struct{}{}
-	defer func() { <-lc.sema }()
+	if err := lc.sema.Acquire(ctx, 1); err != nil {
+		return err
+	}
+	defer lc.sema.Release(1)
 	return lc.c.BatchCallContext(ctx, b)
 }

 func (lc *limitClient) CallContext(ctx context.Context, result any, method string, args ...any) error {
-	lc.wg.Add(1)
+	if !lc.joinWaitGroup() {
+		return net.ErrClosed
+	}
 	defer lc.wg.Done()
-	lc.sema <- struct{}{}
-	defer func() { <-lc.sema }()
+	if err := lc.sema.Acquire(ctx, 1); err != nil {
+		return err
+	}
+	defer lc.sema.Release(1)
 	return lc.c.CallContext(ctx, result, method, args...)
 }

 func (lc *limitClient) EthSubscribe(ctx context.Context, channel any, args ...any) (ethereum.Subscription, error) {
+	if !lc.joinWaitGroup() {
+		return nil, net.ErrClosed
+	}
+	defer lc.wg.Done()
 	// subscription doesn't count towards request limit
 	return lc.c.EthSubscribe(ctx, channel, args...)
 }

 func (lc *limitClient) Close() {
-	lc.wg.Wait()
-	close(lc.sema)
+	lc.mutex.Lock()
+	lc.closed = true // No new waitgroup members after this is set
+	lc.mutex.Unlock()
+	lc.wg.Wait() // All waitgroup members exited, means no more dereferences of the client
 	lc.c.Close()
 }
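
For orientation, a hypothetical caller-side sketch of how the wrapper behaves
after this change; fetchLatestBlock, the limit of 10, the 5s timeout, and the
eth_getBlockByNumber call are illustrative and not taken from the commit. The
point is the contract: an expired context now surfaces as an error from the
semaphore instead of hanging, and calls made after Close return net.ErrClosed.

package sources

import (
	"context"
	"time"

	"github.com/ethereum-optimism/optimism/op-node/client"
)

// fetchLatestBlock is a hypothetical helper showing caller-side usage; raw
// stands in for any existing client.RPC implementation.
func fetchLatestBlock(ctx context.Context, raw client.RPC) (map[string]any, error) {
	limited := LimitRPC(raw, 10) // at most 10 concurrent requests
	defer limited.Close()        // waits for in-flight requests, then closes the wrapped client

	ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	var head map[string]any
	// An expired ctx now surfaces here as an error from the semaphore instead
	// of blocking forever; calls made after Close return net.ErrClosed.
	err := limited.CallContext(ctx, &head, "eth_getBlockByNumber", "latest", false)
	return head, err
}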
limit_test.go (new file)

package sources

import (
	"context"
	"fmt"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum-optimism/optimism/op-node/client"
	"github.com/ethereum/go-ethereum"
	"github.com/ethereum/go-ethereum/rpc"
	"github.com/stretchr/testify/require"
)

type MockRPC struct {
	t              *testing.T
	blockedCallers atomic.Int32
	errC           chan error
}

func (m *MockRPC) Close() {}

func (m *MockRPC) CallContext(ctx context.Context, result any, method string, args ...any) error {
	m.blockedCallers.Add(1)
	defer m.blockedCallers.Add(-1)
	return <-m.errC
}

func (m *MockRPC) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error {
	m.blockedCallers.Add(1)
	defer m.blockedCallers.Add(-1)
	return <-m.errC
}

func (m *MockRPC) EthSubscribe(ctx context.Context, channel any, args ...any) (ethereum.Subscription, error) {
	m.t.Fatal("EthSubscribe should not be called")
	return nil, nil
}

func asyncCallContext(ctx context.Context, lc client.RPC) chan error {
	errC := make(chan error)
	go func() {
		errC <- lc.CallContext(ctx, 0, "fake_method")
	}()
	return errC
}

func TestLimitClient(t *testing.T) {
	// The MockRPC will block all calls until errC is written to
	m := &MockRPC{
		t:    t,
		errC: make(chan error),
	}

	lc := LimitRPC(m, 2).(*limitClient)

	errC1 := asyncCallContext(context.Background(), lc)
	errC2 := asyncCallContext(context.Background(), lc)
	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 2 }, time.Second, 10*time.Millisecond)

	// Once the limit of 2 clients has been reached, we enqueue two more,
	// one with a context that will expire
	tCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	errC3 := asyncCallContext(tCtx, lc)
	errC4 := asyncCallContext(context.Background(), lc)

	select {
	case err := <-errC3:
		require.ErrorIs(t, err, context.DeadlineExceeded)
	case <-time.After(time.Second):
		t.Fatalf("context should have expired and the call returned")
	}

	// No further clients should be allowed after this block, but existing
	// clients should persist until their contexts close
	go lc.Close()
	require.Eventually(t, func() bool {
		lc.mutex.Lock()
		defer lc.mutex.Unlock()
		return lc.closed
	}, time.Second, 10*time.Millisecond)

	err := lc.CallContext(context.Background(), 0, "fake_method")
	require.ErrorIs(t, err, net.ErrClosed, "Calls after close should return immediately with error")

	// Existing clients should all remain blocked
	select {
	case err := <-errC1:
		t.Fatalf("client should not have returned: %s", err)
	case err := <-errC2:
		t.Fatalf("client should not have returned: %s", err)
	case err := <-errC4:
		t.Fatalf("client should not have returned: %s", err)
	case <-time.After(50 * time.Millisecond):
		// None of the clients should return yet
	}

	m.errC <- fmt.Errorf("fake-error")
	m.errC <- fmt.Errorf("fake-error")
	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 1 }, time.Second, 10*time.Millisecond)
	m.errC <- fmt.Errorf("fake-error")

	require.ErrorContains(t, <-errC1, "fake-error")
	require.ErrorContains(t, <-errC2, "fake-error")
	require.ErrorContains(t, <-errC4, "fake-error")
	require.Eventually(t, func() bool { return m.blockedCallers.Load() == 0 }, time.Second, 10*time.Millisecond)
}