Commit 01ae6625 authored by Matthew Slipper

proxyd: Integrate custom rate limiter

Integrates the custom rate limiter from the previous PR into the rest of the application. Also takes the opportunity to clean up how we instantiate Redis clients so that a single client can be shared among multiple services.
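As a rough sketch of that wiring (the helper name `newSharedRedisComponents` is illustrative, not code from this PR; the constructors and their signatures are the ones introduced below), a single client returned by `NewRedisClient` can now back both the RPC cache and the backend rate limiter:

```go
// Illustrative only: one shared Redis client feeding both the cache and the
// backend rate limiter, mirroring what Start() now does. Assumes package proxyd.
func newSharedRedisComponents(url string) (*redisCache, BackendRateLimiter, error) {
	rdb, err := NewRedisClient(url) // parses the URL and pings Redis with a 5s timeout
	if err != nil {
		return nil, nil, err
	}
	return newRedisCache(rdb), NewRedisRateLimiter(rdb), nil
}
```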

There are some config changes in this PR. Specifically, `rate_per_second` in the rate limit config is replaced by `base_rate` and `base_interval`, both of which must now be specified.
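For example, a config that previously used `rate_per_second = 2` would now look roughly like this (values are illustrative; `use_redis` is optional and switches the frontend limiter from the in-memory store to Redis):

```toml
[rate_limit]
# optional: store frontend rate-limit counters in Redis instead of memory
use_redis = true
# allow up to base_rate requests per base_interval for each requester
base_rate = 2
base_interval = "1s"
```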
parent c2b8efac
---
'@eth-optimism/proxyd': minor
---
Adds new Redis rate limiter
@@ -57,22 +57,14 @@ type RedisBackendRateLimiter struct {
     tkMtx     sync.Mutex
 }
 
-func NewRedisRateLimiter(url string) (BackendRateLimiter, error) {
-    opts, err := redis.ParseURL(url)
-    if err != nil {
-        return nil, err
-    }
-    rdb := redis.NewClient(opts)
-    if err := rdb.Ping(context.Background()).Err(); err != nil {
-        return nil, wrapErr(err, "error connecting to redis")
-    }
+func NewRedisRateLimiter(rdb *redis.Client) BackendRateLimiter {
     out := &RedisBackendRateLimiter{
         rdb:       rdb,
         randID:    randStr(20),
         touchKeys: make(map[string]time.Duration),
     }
     go out.touch()
-    return out, nil
+    return out
 }
 
 func (r *RedisBackendRateLimiter) IsBackendOnline(name string) (bool, error) {
......
@@ -46,16 +46,8 @@ type redisCache struct {
     rdb *redis.Client
 }
 
-func newRedisCache(url string) (*redisCache, error) {
-    opts, err := redis.ParseURL(url)
-    if err != nil {
-        return nil, err
-    }
-    rdb := redis.NewClient(opts)
-    if err := rdb.Ping(context.Background()).Err(); err != nil {
-        return nil, wrapErr(err, "error connecting to redis")
-    }
-    return &redisCache{rdb}, nil
+func newRedisCache(rdb *redis.Client) *redisCache {
+    return &redisCache{rdb}
 }
 
 func (c *redisCache) Get(ctx context.Context, key string) (string, error) {
......
@@ -42,7 +42,8 @@ type MetricsConfig struct {
 type RateLimitConfig struct {
     UseRedis         bool         `toml:"use_redis"`
-    RatePerSecond    int          `toml:"rate_per_second"`
+    BaseRate         int          `toml:"base_rate"`
+    BaseInterval     TOMLDuration `toml:"base_interval"`
     ExemptOrigins    []string     `toml:"exempt_origins"`
     ExemptUserAgents []string     `toml:"exempt_user_agents"`
     ErrorMessage     string       `toml:"error_message"`
......
@@ -17,7 +17,7 @@ type FrontendRateLimiter interface {
     //
     // No error will be returned if the limit could not be taken
     // as a result of the requestor being over the limit.
-    Take(ctx context.Context, key string, max int) (bool, error)
+    Take(ctx context.Context, key string) (bool, error)
 }
 
 // limitedKeys is a wrapper around a map that stores a truncated
@@ -58,16 +58,18 @@ func (l *limitedKeys) Take(key string, max int) bool {
 type MemoryFrontendRateLimiter struct {
     currGeneration *limitedKeys
     dur            time.Duration
+    max            int
     mtx            sync.Mutex
 }
 
-func NewMemoryFrontendRateLimit(dur time.Duration) FrontendRateLimiter {
+func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter {
     return &MemoryFrontendRateLimiter{
         dur: dur,
+        max: max,
     }
 }
 
-func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
+func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
     m.mtx.Lock()
     // Create truncated timestamp
     truncTS := truncateNow(m.dur)
@@ -83,35 +85,51 @@ func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
     m.mtx.Unlock()
 
-    return limiter.Take(key, max), nil
+    return limiter.Take(key, m.max), nil
 }
 
 // RedisFrontendRateLimiter is a rate limiter that stores data in Redis.
 // It uses the basic rate limiter pattern described on the Redis best
 // practices website: https://redis.com/redis-best-practices/basic-rate-limiting/.
 type RedisFrontendRateLimiter struct {
     r      *redis.Client
     dur    time.Duration
+    max    int
+    prefix string
 }
 
-func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration) FrontendRateLimiter {
-    return &RedisFrontendRateLimiter{r: r, dur: dur}
+func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter {
+    return &RedisFrontendRateLimiter{
+        r:      r,
+        dur:    dur,
+        max:    max,
+        prefix: prefix,
+    }
 }
 
-func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string, max int) (bool, error) {
+func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
     var incr *redis.IntCmd
     truncTS := truncateNow(r.dur)
-    fullKey := fmt.Sprintf("%s:%d", key, truncTS)
+    fullKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncTS)
     _, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error {
         incr = pipe.Incr(ctx, fullKey)
-        pipe.Expire(ctx, fullKey, r.dur-time.Second)
+        pipe.PExpire(ctx, fullKey, r.dur-time.Millisecond)
         return nil
     })
     if err != nil {
+        frontendRateLimitTakeErrors.Inc()
         return false, err
     }
 
-    return incr.Val()-1 < int64(max), nil
+    return incr.Val()-1 < int64(r.max), nil
+}
+
+type noopFrontendRateLimiter struct{}
+
+var NoopFrontendRateLimiter = &noopFrontendRateLimiter{}
+
+func (n *noopFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
+    return true, nil
 }
 
 // truncateNow truncates the current timestamp
......
@@ -20,32 +20,32 @@ func TestFrontendRateLimiter(t *testing.T) {
         Addr: fmt.Sprintf("127.0.0.1:%s", redisServer.Port()),
     })
 
+    max := 2
+
     lims := []struct {
         name string
         frl  FrontendRateLimiter
     }{
-        {"memory", NewMemoryFrontendRateLimit(2 * time.Second)},
-        {"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second)},
+        {"memory", NewMemoryFrontendRateLimit(2*time.Second, max)},
+        {"redis", NewRedisFrontendRateLimiter(redisClient, 2*time.Second, max, "")},
     }
 
-    max := 2
-
     for _, cfg := range lims {
         frl := cfg.frl
         ctx := context.Background()
         t.Run(cfg.name, func(t *testing.T) {
             for i := 0; i < 4; i++ {
-                ok, err := frl.Take(ctx, "foo", max)
+                ok, err := frl.Take(ctx, "foo")
                 require.NoError(t, err)
                 require.Equal(t, i < max, ok)
-                ok, err = frl.Take(ctx, "bar", max)
+                ok, err = frl.Take(ctx, "bar")
                 require.NoError(t, err)
                 require.Equal(t, i < max, ok)
             }
             time.Sleep(2 * time.Second)
             for i := 0; i < 4; i++ {
-                ok, _ := frl.Take(ctx, "foo", max)
+                ok, _ := frl.Take(ctx, "foo")
                 require.Equal(t, i < max, ok)
-                ok, _ = frl.Take(ctx, "bar", max)
+                ok, _ = frl.Take(ctx, "bar")
                 require.Equal(t, i < max, ok)
             }
         })
......
@@ -261,6 +261,8 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
     config.BackendOptions.MaxRetries = 2
     // Setup redis to detect offline backends
     config.Redis.URL = fmt.Sprintf("redis://127.0.0.1:%s", redis.Port())
+    redisClient, err := proxyd.NewRedisClient(config.Redis.URL)
+    require.NoError(t, err)
 
     goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse))
     defer goodBackend.Close()
@@ -285,7 +287,7 @@ func TestInfuraFailoverOnUnexpectedResponse(t *testing.T) {
     require.Equal(t, 1, len(badBackend.Requests()))
     require.Equal(t, 1, len(goodBackend.Requests()))
 
-    rr, err := proxyd.NewRedisRateLimiter(config.Redis.URL)
+    rr := proxyd.NewRedisRateLimiter(redisClient)
     require.NoError(t, err)
     online, err := rr.IsBackendOnline("bad")
     require.NoError(t, err)
......
@@ -18,7 +18,8 @@ eth_chainId = "main"
 eth_foobar = "main"
 
 [rate_limit]
-rate_per_second = 2
+base_rate = 2
+base_interval = "1s"
 exempt_origins = ["exempt_origin"]
 exempt_user_agents = ["exempt_agent"]
 error_message = "over rate limit with special message"
......
@@ -236,6 +236,12 @@ var (
             100,
         },
     })
+
+    frontendRateLimitTakeErrors = promauto.NewCounter(prometheus.CounterOpts{
+        Namespace: MetricsNamespace,
+        Name:      "rate_limit_take_errors",
+        Help:      "Count of errors taking frontend rate limits",
+    })
 )
 
 func RecordRedisError(source string) {
......
@@ -13,6 +13,7 @@ import (
     "github.com/ethereum/go-ethereum/common/math"
     "github.com/ethereum/go-ethereum/ethclient"
     "github.com/ethereum/go-ethereum/log"
+    "github.com/go-redis/redis/v8"
     "github.com/prometheus/client_golang/prometheus/promhttp"
     "golang.org/x/sync/semaphore"
 )
@@ -34,25 +35,29 @@ func Start(config *Config) (func(), error) {
         }
     }
 
-    var redisURL string
+    var redisClient *redis.Client
     if config.Redis.URL != "" {
         rURL, err := ReadFromEnvOrConfig(config.Redis.URL)
         if err != nil {
             return nil, err
         }
-        redisURL = rURL
+        redisClient, err = NewRedisClient(rURL)
+        if err != nil {
+            return nil, err
+        }
+    }
+
+    if redisClient == nil && config.RateLimit.UseRedis {
+        return nil, errors.New("must specify a Redis URL if UseRedis is true in rate limit config")
     }
 
     var lim BackendRateLimiter
     var err error
-    if redisURL == "" {
+    if redisClient == nil {
         log.Warn("redis is not configured, using local rate limiter")
         lim = NewLocalBackendRateLimiter()
     } else {
-        lim, err = NewRedisRateLimiter(redisURL)
-        if err != nil {
-            return nil, err
-        }
+        lim = NewRedisRateLimiter(redisClient)
     }
 
     // While modifying shared globals is a bad practice, the alternative
@@ -206,13 +211,11 @@ func Start(config *Config) (func(), error) {
         return nil, err
     }
 
-    if redisURL != "" {
-        if cache, err = newRedisCache(redisURL); err != nil {
-            return nil, err
-        }
-    } else {
+    if redisClient == nil {
         log.Warn("redis is not configured, using in-memory cache")
         cache = newMemoryCache()
+    } else {
+        cache = newRedisCache(redisClient)
     }
 
     // Ideally, the BlocKSyncRPCURL should be the sequencer or a HA replica that's not far behind
     ethClient, err := ethclient.Dial(blockSyncRPCURL)
@@ -240,6 +243,7 @@ func Start(config *Config) (func(), error) {
         config.Server.EnableRequestLog,
         config.Server.MaxRequestBodyLogLen,
         config.BatchConfig.MaxSize,
+        redisClient,
     )
     if err != nil {
         return nil, fmt.Errorf("error creating server: %w", err)
......
package proxyd

import (
    "context"
    "time"

    "github.com/go-redis/redis/v8"
)

func NewRedisClient(url string) (*redis.Client, error) {
    opts, err := redis.ParseURL(url)
    if err != nil {
        return nil, err
    }

    client := redis.NewClient(opts)

    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    if err := client.Ping(ctx).Err(); err != nil {
        return nil, wrapErr(err, "error connecting to redis")
    }

    return client, nil
}
@@ -13,11 +13,8 @@ import (
     "sync"
     "time"
 
-    "github.com/sethvargo/go-limiter"
-    "github.com/sethvargo/go-limiter/memorystore"
-    "github.com/sethvargo/go-limiter/noopstore"
-
     "github.com/ethereum/go-ethereum/log"
+    "github.com/go-redis/redis/v8"
     "github.com/gorilla/mux"
     "github.com/gorilla/websocket"
     "github.com/prometheus/client_golang/prometheus"
@@ -50,9 +47,8 @@ type Server struct {
     maxUpstreamBatchSize int
     maxBatchSize         int
     upgrader             *websocket.Upgrader
-    mainLim              limiter.Store
-    overrideLims         map[string]limiter.Store
-    limConfig            RateLimitConfig
+    mainLim              FrontendRateLimiter
+    overrideLims         map[string]FrontendRateLimiter
     limExemptOrigins     map[string]bool
     limExemptUserAgents  map[string]bool
     rpcServer            *http.Server
@@ -77,6 +73,7 @@ func NewServer(
     enableRequestLog bool,
     maxRequestBodyLogLen int,
     maxBatchSize int,
+    redisClient *redis.Client,
 ) (*Server, error) {
     if cache == nil {
         cache = &NoopRPCCache{}
@@ -98,19 +95,19 @@ func NewServer(
         maxBatchSize = MaxBatchRPCCallsHardLimit
     }
 
-    var mainLim limiter.Store
-    limExemptOrigins := make(map[string]bool)
-    limExemptUserAgents := make(map[string]bool)
-    if rateLimitConfig.RatePerSecond > 0 {
-        var err error
-        mainLim, err = memorystore.New(&memorystore.Config{
-            Tokens:   uint64(rateLimitConfig.RatePerSecond),
-            Interval: time.Second,
-        })
-        if err != nil {
-            return nil, err
+    limiterFactory := func(dur time.Duration, max int, prefix string) FrontendRateLimiter {
+        if rateLimitConfig.UseRedis {
+            return NewRedisFrontendRateLimiter(redisClient, dur, max, prefix)
         }
+
+        return NewMemoryFrontendRateLimit(dur, max)
+    }
+
+    var mainLim FrontendRateLimiter
+    limExemptOrigins := make(map[string]bool)
+    limExemptUserAgents := make(map[string]bool)
+    if rateLimitConfig.BaseRate > 0 {
+        mainLim = limiterFactory(time.Duration(rateLimitConfig.BaseInterval), rateLimitConfig.BaseRate, "main")
         for _, origin := range rateLimitConfig.ExemptOrigins {
             limExemptOrigins[strings.ToLower(origin)] = true
         }
@@ -118,16 +115,13 @@ func NewServer(
             limExemptUserAgents[strings.ToLower(agent)] = true
         }
     } else {
-        mainLim, _ = noopstore.New()
+        mainLim = NoopFrontendRateLimiter
     }
 
-    overrideLims := make(map[string]limiter.Store)
+    overrideLims := make(map[string]FrontendRateLimiter)
     for method, override := range rateLimitConfig.MethodOverrides {
         var err error
-        overrideLims[method], err = memorystore.New(&memorystore.Config{
-            Tokens:   uint64(override.Limit),
-            Interval: time.Duration(override.Interval),
-        })
+        overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method)
         if err != nil {
             return nil, err
         }
@@ -151,7 +145,6 @@ func NewServer(
         },
         mainLim:             mainLim,
         overrideLims:        overrideLims,
-        limConfig:           rateLimitConfig,
         limExemptOrigins:    limExemptOrigins,
         limExemptUserAgents: limExemptUserAgents,
     }, nil
@@ -235,7 +228,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
             return false
         }
 
-        var lim limiter.Store
+        var lim FrontendRateLimiter
         if method == "" {
             lim = s.mainLim
         } else {
@@ -246,7 +239,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
             return false
         }
 
-        _, _, _, ok, _ := lim.Take(ctx, xff)
+        ok, err := lim.Take(ctx, xff)
+        if err != nil {
+            log.Warn("error taking rate limit", "err", err)
+            return true
+        }
         return !ok
     }
......