Commit 14719116 authored by Matthew Slipper's avatar Matthew Slipper Committed by GitHub

Merge pull request #3681 from ethereum-optimism/10-09-proxyd_Integrate_custom_rate_limiter

proxyd: Integrate custom rate limiter
parents ce033510 01ae6625
---
'@eth-optimism/proxyd': minor
---
Adds new Redis rate limiter
......@@ -57,22 +57,14 @@ type RedisBackendRateLimiter struct {
tkMtx sync.Mutex
}
func NewRedisRateLimiter(url string) (BackendRateLimiter, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
rdb := redis.NewClient(opts)
if err := rdb.Ping(context.Background()).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
func NewRedisRateLimiter(rdb *redis.Client) BackendRateLimiter {
out := &RedisBackendRateLimiter{
rdb: rdb,
randID: randStr(20),
touchKeys: make(map[string]time.Duration),
}
go out.touch()
return out, nil
return out
}
func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
......
......@@ -46,16 +46,8 @@ type redisCache struct {
rdb *redis.Client
}
func newRedisCache(url string) (*redisCache, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
rdb := redis.NewClient(opts)
if err := rdb.Ping(context.Background()).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
return &redisCache{rdb}, nil
func newRedisCache(rdb *redis.Client) *redisCache {
return &redisCache{rdb}
}
func (c *redisCache) Get(ctx context.Context, key string) (string, error) {
......
......@@ -42,7 +42,8 @@ type MetricsConfig struct {
type RateLimitConfig struct {
UseRedis bool `toml:"use_redis"`
RatePerSecond int `toml:"rate_per_second"`
BaseRate int `toml:"base_rate"`
BaseInterval TOMLDuration `toml:"base_interval"`
ExemptOrigins []string `toml:"exempt_origins"`
ExemptUserAgents []string `toml:"exempt_user_agents"`
ErrorMessage string `toml:"error_message"`
......
......@@ -17,7 +17,7 @@ type FrontendRateLimiter interface {
//
// No error will be returned if the limit could not be taken
// as a result of the requestor being over the limit.
Take(ctx context.Context, key string, max int) (bool, error)
Take(ctx context.Context, key string) (bool, error)
}
// limitedKeys is a wrapper around a map that stores a truncated
......@@ -58,16 +58,18 @@ func (l *limitedKeys) Take(key string, max int) bool {
type MemoryFrontendRateLimiter struct {
currGeneration *limitedKeys
dur time.Duration
max int
mtx sync.Mutex
}
func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter {
func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter {
return &MemoryFrontendRateLimiter{
dur: dur,
max: max,
}
}
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
m.mtx.Lock()
// Create truncated timestamp
truncTS := truncateNow(m.dur)
......@@ -83,35 +85,51 @@ func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max in
m.mtx.Unlock()
return limiter.Take(key, max), nil
return limiter.Take(key, m.max), nil
}
// RedisFrontendRateLimiter is a rate limiter that stores data in Redis.
// It uses the basic rate limiter pattern described on the Redis best
// practices website: https://redis.com/redis-best-practices/basic-rate-limiting/.
type RedisFrontendRateLimiter struct {
r *redis.Client
dur time.Duration
r *redis.Client
dur time.Duration
max int
prefix string
}
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter {
return &RedisFrontendRateLimiter{r: r, dur: dur}
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter {
return &RedisFrontendRateLimiter{
r: r,
dur: dur,
max: max,
prefix: prefix,
}
}
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
var incr *redis.IntCmd
truncTS := truncateNow(r.dur)
fullKey := fmt.Sprintf("%s:%d", key, truncTS)
fullKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncTS)
_, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error {
incr = pipe.Incr(ctx, fullKey)
pipe.Expire(ctx, fullKey, r.dur-time.Second)
pipe.PExpire(ctx, fullKey, r.dur-time.Millisecond)
return nil
})
if err != nil {
frontendRateLimitTakeErrors.Inc()
return false, err
}
return incr.Val()-1 < int64(max), nil
return incr.Val()-1 < int64(r.max), nil
}
type noopFrontendRateLimiter struct{}
var NoopFrontendRateLimiter = &noopFrontendRateLimiter{}
func (n *noopFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
return true, nil
}
// truncateNow truncates the current timestamp
......
......@@ -20,32 +20,32 @@ func TestFrontendRateLimiter(t *testing.T) {
Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()),
})
max := 2
lims := []struct {
name string
frl FrontendRateLimiter
}{
{"memory", NewMemoryFrontendRateLimit(2 * time.Second)},
{"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)},
{"memory", NewMemoryFrontendRateLimit(2*time.Second, max)},
{"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second, max, "")},
}
max := 2
for _, cfg := range lims {
frl := cfg.frl
ctx := context.Background()
t.Run(cfg.name, func(t *testing.T) {
for i := 0; i < 4; i++ {
ok, err := frl.Take(ctx, "foo", max)
ok, err := frl.Take(ctx, "foo")
require.NoError(t, err)
require.Equal(t, i < max, ok)
ok, err = frl.Take(ctx, "bar", max)
ok, err = frl.Take(ctx, "bar")
require.NoError(t, err)
require.Equal(t, i < max, ok)
}
time.Sleep(2 * time.Second)
for i := 0; i < 4; i++ {
ok, _ := frl.Take(ctx, "foo", max)
ok, _ := frl.Take(ctx, "foo")
require.Equal(t, i < max, ok)
ok, _ = frl.Take(ctx, "bar", max)
ok, _ = frl.Take(ctx, "bar")
require.Equal(t, i < max, ok)
}
})
......
......@@ -261,6 +261,8 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
config.BackendOptions.MaxRetries = 2
// Setup redis to detect offline backends
config.Redis.URL = fmt.Sprintf("redis://127.0.0.1:%s", redis.Port())
redisClient, err := proxyd.NewRedisClient(config.Redis.URL)
require.NoError(t, err)
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse))
defer goodBackend.Close()
......@@ -285,7 +287,7 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
require.Equal(t, 1, len(badBackend.Requests()))
require.Equal(t, 1, len(goodBackend.Requests()))
rr, err := proxyd.NewRedisRateLimiter(config.Redis.URL)
rr := proxyd.NewRedisRateLimiter(redisClient)
require.NoError(t, err)
online, err := rr.IsBackendOnline("bad")
require.NoError(t, err)
......
......@@ -18,7 +18,8 @@ eth_chainId = "main"
eth_foobar = "main"
[rate_limit]
rate_per_second = 2
base_rate = 2
base_interval = "1s"
exempt_origins = ["exempt_origin"]
exempt_user_agents = ["exempt_agent"]
error_message = "over rate limit with special message"
......
......@@ -236,6 +236,12 @@ var (
100,
},
})
frontendRateLimitTakeErrors = promauto.NewCounter(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "rate_limit_take_errors",
Help: "Count of errors taking frontend rate limits",
})
)
func RecordRedisError(source string) {
......
......@@ -13,6 +13,7 @@ import (
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/log"
"github.com/go-redis/redis/v8"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sync/semaphore"
)
......@@ -34,25 +35,29 @@ func Start(config *Config) (func(), error) {
}
}
var redisURL string
var redisClient *redis.Client
if config.Redis.URL != "" {
rURL, err := ReadFromEnvOrConfig(config.Redis.URL)
if err != nil {
return nil, err
}
redisURL = rURL
redisClient, err = NewRedisClient(rURL)
if err != nil {
return nil, err
}
}
if redisClient == nil && config.RateLimit.UseRedis {
return nil, errors.New("must specify a Redis URL if UseRedis is true in rate limit config")
}
var lim BackendRateLimiter
var err error
if redisURL == "" {
if redisClient == nil {
log.Warn("redis is not configured, using local rate limiter")
lim = NewLocalBackendRateLimiter()
} else {
lim, err = NewRedisRateLimiter(redisURL)
if err != nil {
return nil, err
}
lim = NewRedisRateLimiter(redisClient)
}
// While modifying shared globals is a bad practice, the alternative
......@@ -206,13 +211,11 @@ func Start(config *Config) (func(), error) {
return nil, err
}
if redisURL != "" {
if cache, err = newRedisCache(redisURL); err != nil {
return nil, err
}
} else {
if redisClient == nil {
log.Warn("redis is not configured, using in-memory cache")
cache = newMemoryCache()
} else {
cache = newRedisCache(redisClient)
}
// Ideally, the BlocKSyncRPCURL should be the sequencer or a HA replica that's not far behind
ethClient, err := ethclient.Dial(blockSyncRPCURL)
......@@ -240,6 +243,7 @@ func Start(config *Config) (func(), error) {
config.Server.EnableRequestLog,
config.Server.MaxRequestBodyLogLen,
config.BatchConfig.MaxSize,
redisClient,
)
if err != nil {
return nil, fmt.Errorf("error creating server: %w", err)
......
package proxyd
import (
"context"
"time"
"github.com/go-redis/redis/v8"
)
func NewRedisClient(url string) (*redis.Client, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
client := redis.NewClient(opts)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
return client, nil
}
......@@ -13,11 +13,8 @@ import (
"sync"
"time"
"github.com/sethvargo/go-limiter"
"github.com/sethvargo/go-limiter/memorystore"
"github.com/sethvargo/go-limiter/noopstore"
"github.com/ethereum/go-ethereum/log"
"github.com/go-redis/redis/v8"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
......@@ -50,9 +47,8 @@ type Server struct {
maxUpstreamBatchSize int
maxBatchSize int
upgrader *websocket.Upgrader
mainLim limiter.Store
overrideLims map[string]limiter.Store
limConfig RateLimitConfig
mainLim FrontendRateLimiter
overrideLims map[string]FrontendRateLimiter
limExemptOrigins map[string]bool
limExemptUserAgents map[string]bool
rpcServer *http.Server
......@@ -77,6 +73,7 @@ func NewServer(
enableRequestLog bool,
maxRequestBodyLogLen int,
maxBatchSize int,
redisClient *redis.Client,
) (*Server, error) {
if cache == nil {
cache = &NoopRPCCache{}
......@@ -98,19 +95,19 @@ func NewServer(
maxBatchSize = MaxBatchRPCCallsHardLimit
}
var mainLim limiter.Store
limExemptOrigins := make(map[string]bool)
limExemptUserAgents := make(map[string]bool)
if rateLimitConfig.RatePerSecond > 0 {
var err error
mainLim, err = memorystore.New(&memorystore.Config{
Tokens: uint64(rateLimitConfig.RatePerSecond),
Interval: time.Second,
})
if err != nil {
return nil, err
limiterFactory := func(dur time.Duration, max int, prefix string) FrontendRateLimiter {
if rateLimitConfig.UseRedis {
return NewRedisFrontendRateLimiter(redisClient, dur, max, prefix)
}
return NewMemoryFrontendRateLimit(dur, max)
}
var mainLim FrontendRateLimiter
limExemptOrigins := make(map[string]bool)
limExemptUserAgents := make(map[string]bool)
if rateLimitConfig.BaseRate > 0 {
mainLim = limiterFactory(time.Duration(rateLimitConfig.BaseInterval), rateLimitConfig.BaseRate, "main")
for _, origin := range rateLimitConfig.ExemptOrigins {
limExemptOrigins[strings.ToLower(origin)] = true
}
......@@ -118,16 +115,13 @@ func NewServer(
limExemptUserAgents[strings.ToLower(agent)] = true
}
} else {
mainLim, _ = noopstore.New()
mainLim = NoopFrontendRateLimiter
}
overrideLims := make(map[string]limiter.Store)
overrideLims := make(map[string]FrontendRateLimiter)
for method, override := range rateLimitConfig.MethodOverrides {
var err error
overrideLims[method], err = memorystore.New(&memorystore.Config{
Tokens: uint64(override.Limit),
Interval: time.Duration(override.Interval),
})
overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method)
if err != nil {
return nil, err
}
......@@ -151,7 +145,6 @@ func NewServer(
},
mainLim: mainLim,
overrideLims: overrideLims,
limConfig: rateLimitConfig,
limExemptOrigins: limExemptOrigins,
limExemptUserAgents: limExemptUserAgents,
}, nil
......@@ -235,7 +228,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return false
}
var lim limiter.Store
var lim FrontendRateLimiter
if method == "" {
lim = s.mainLim
} else {
......@@ -246,7 +239,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return false
}
_, _, _, ok, _ := lim.Take(ctx, xff)
ok, err := lim.Take(ctx, xff)
if err != nil {
log.Warn("error taking rate limit", "err", err)
return true
}
return !ok
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment