Commit 8f39b07e authored by Mark Tyneway's avatar Mark Tyneway Committed by GitHub

Merge pull request #1679 from mslipper/feat/proxyd-ws

go/proxyd: rate limiting and WS support
parents c0574243 6c50098b
---
'@eth-optimism/proxyd': major
---
Update metrics, support WS
This diff is collapsed.
......@@ -6,6 +6,10 @@ type ServerConfig struct {
MaxBodySizeBytes int64 `toml:"max_body_size_bytes"`
}
type RedisConfig struct {
URL string `toml:"url"`
}
type MetricsConfig struct {
Enabled bool `toml:"enabled"`
Host string `toml:"host"`
......@@ -13,16 +17,19 @@ type MetricsConfig struct {
}
type BackendOptions struct {
ResponseTimeoutSeconds int `toml:"response_timeout_seconds"`
MaxResponseSizeBytes int64 `toml:"max_response_size_bytes"`
MaxRetries int `toml:"backend_retries"`
UnhealthyBackendRetryIntervalSeconds int64 `toml:"unhealthy_backend_retry_interval_seconds"`
ResponseTimeoutSeconds int `toml:"response_timeout_seconds"`
MaxResponseSizeBytes int64 `toml:"max_response_size_bytes"`
MaxRetries int `toml:"backend_retries"`
OutOfServiceSeconds int `toml:"out_of_service_seconds"`
}
type BackendConfig struct {
Username string `toml:"username"`
Password string `toml:"password"`
BaseURL string `toml:"base_url"`
Username string `toml:"username"`
Password string `toml:"password"`
RPCURL string `toml:"rpc_url"`
WSURL string `toml:"ws_url"`
MaxRPS int `toml:"max_rps"`
MaxWSConns int `toml:"max_ws_conns"`
}
type BackendsConfig map[string]*BackendConfig
......@@ -36,10 +43,11 @@ type BackendGroupsConfig map[string]*BackendGroupConfig
type MethodMappingsConfig map[string]string
type Config struct {
Server *ServerConfig `toml:"server"`
Metrics *MetricsConfig `toml:"metrics"`
BackendOptions *BackendOptions `toml:"backend"`
Backends BackendsConfig `toml:"backends"`
BackendGroups BackendGroupsConfig `toml:"backend_groups"`
MethodMappings MethodMappingsConfig `toml:"method_mappings"`
AllowedRPCMethods []string `toml:"allowed_rpc_methods"`
AllowedWSMethods []string `toml:"allowed_ws_methods"`
Server *ServerConfig `toml:"server"`
Redis *RedisConfig `toml:"redis"`
Metrics *MetricsConfig `toml:"metrics"`
BackendOptions *BackendOptions `toml:"backend_options"`
Backends BackendsConfig `toml:"backends"`
}
......@@ -5,7 +5,9 @@ go 1.16
require (
github.com/BurntSushi/toml v0.4.1
github.com/ethereum/go-ethereum v1.10.11
github.com/go-redis/redis/v8 v8.11.4
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2
github.com/prometheus/client_golang v1.11.0
github.com/rs/cors v1.8.0
)
This diff is collapsed.
package proxyd
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"strconv"
)
const (
MetricsNamespace = "proxyd"
RPCRequestSourceHTTP = "http"
RPCRequestSourceWS = "ws"
SourceClient = "client"
SourceBackend = "backend"
SourceProxyd = "proxyd"
)
var (
rpcRequestsTotal = promauto.NewCounter(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "rpc_requests_total",
Help: "Count of total client RPC requests.",
})
rpcForwardsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "rpc_backend_requests_total",
Help: "Count of total RPC requests forwarded to each backend.",
}, []string{
"backend_name",
"method_name",
"source",
})
rpcErrorsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "rpc_errors_total",
Help: "Count of total RPC errors.",
}, []string{
"source",
"error_code",
})
rpcBackendRequestDurationSumm = promauto.NewSummaryVec(prometheus.SummaryOpts{
Namespace: MetricsNamespace,
Name: "rpc_backend_request_duration_seconds",
Help: "Summary of backend response times broken down by backend and method name.",
Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.95: 0.005, 0.99: 0.001},
}, []string{
"backend_name",
"method_name",
})
activeClientWsConnsGauge = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Name: "active_client_ws_conns",
Help: "Gauge of active client WS connections.",
})
activeBackendWsConnsGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Name: "active_backend_ws_conns",
Help: "Gauge of active backend WS connections.",
}, []string{
"backend_name",
})
unserviceableRequestsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "unserviceable_requests_total",
Help: "Count of total requests that were rejected due to no backends being available.",
}, []string{
"source",
})
httpRequestsTotal = promauto.NewCounter(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "http_requests_total",
Help: "Count of total HTTP requests.",
})
httpRequestDurationSumm = promauto.NewSummary(prometheus.SummaryOpts{
Namespace: MetricsNamespace,
Name: "http_request_duration_seconds",
Help: "Summary of HTTP request durations, in seconds.",
Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.95: 0.005, 0.99: 0.001},
})
wsMessagesTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "ws_messages_total",
Help: "Count of total websocket messages including protocol control.",
}, []string{
"backend_name",
"source",
})
redisErrorsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "redis_errors_total",
Help: "Count of total Redis errors.",
}, []string{
"source",
})
)
func RecordRPCError(source string, err error) {
rpcErr, ok := err.(*RPCErr)
var code int
if ok {
code = rpcErr.Code
} else {
code = -1
}
rpcErrorsTotal.WithLabelValues(source, strconv.Itoa(code)).Inc()
}
func RecordRedisError(source string) {
redisErrorsTotal.WithLabelValues(source).Inc()
}
func RecordWSMessage(backendName, source string) {
wsMessagesTotal.WithLabelValues(backendName, source).Inc()
}
func RecordUnserviceableRequest(source string) {
unserviceableRequestsTotal.WithLabelValues(source).Inc()
}
......@@ -13,28 +13,35 @@ import (
)
func Start(config *Config) error {
backendsByName := make(map[string]*Backend)
groupsByName := make(map[string]*BackendGroup)
if len(config.Backends) == 0 {
return errors.New("must define at least one backend")
}
if len(config.BackendGroups) == 0 {
return errors.New("must define at least one backend group")
if len(config.AllowedRPCMethods) == 0 {
return errors.New("must define at least one allowed RPC method")
}
if len(config.MethodMappings) == 0 {
return errors.New("must define at least one method mapping")
allowedRPCs := NewStringSetFromStrings(config.AllowedRPCMethods)
allowedWSRPCs := allowedRPCs.Extend(config.AllowedWSMethods)
redis, err := NewRedis(config.Redis.URL)
if err != nil {
return err
}
backends := make([]*Backend, 0)
backendNames := make([]string, 0)
for name, cfg := range config.Backends {
opts := make([]BackendOpt, 0)
if cfg.BaseURL == "" {
return fmt.Errorf("must define a base URL for backend %s", name)
if cfg.RPCURL == "" {
return fmt.Errorf("must define an RPC URL for backend %s", name)
}
if cfg.WSURL == "" {
return fmt.Errorf("must define a WS URL for backend %s", name)
}
if config.BackendOptions.ResponseTimeoutSeconds != 0 {
timeout := time.Duration(config.BackendOptions.ResponseTimeoutSeconds) * time.Second
timeout := secondsToDuration(config.BackendOptions.ResponseTimeoutSeconds)
opts = append(opts, WithTimeout(timeout))
}
if config.BackendOptions.MaxRetries != 0 {
......@@ -43,42 +50,29 @@ func Start(config *Config) error {
if config.BackendOptions.MaxResponseSizeBytes != 0 {
opts = append(opts, WithMaxResponseSize(config.BackendOptions.MaxResponseSizeBytes))
}
if config.BackendOptions.UnhealthyBackendRetryIntervalSeconds != 0 {
opts = append(opts, WithUnhealthyRetryInterval(config.BackendOptions.UnhealthyBackendRetryIntervalSeconds))
if config.BackendOptions.OutOfServiceSeconds != 0 {
opts = append(opts, WithOutOfServiceDuration(secondsToDuration(config.BackendOptions.OutOfServiceSeconds)))
}
if cfg.Password != "" {
opts = append(opts, WithBasicAuth(cfg.Username, cfg.Password))
if cfg.MaxRPS != 0 {
opts = append(opts, WithMaxRPS(cfg.MaxRPS))
}
backendsByName[name] = NewBackend(name, cfg.BaseURL, opts...)
log.Info("configured backend", "name", name, "base_url", cfg.BaseURL)
}
for groupName, cfg := range config.BackendGroups {
backs := make([]*Backend, 0)
for _, backName := range cfg.Backends {
if backendsByName[backName] == nil {
return fmt.Errorf("undefined backend %s", backName)
}
backs = append(backs, backendsByName[backName])
log.Info("configured backend group", "name", groupName)
if cfg.MaxWSConns != 0 {
opts = append(opts, WithMaxWSConns(cfg.MaxWSConns))
}
groupsByName[groupName] = &BackendGroup{
Name: groupName,
backends: backs,
if cfg.Password != "" {
opts = append(opts, WithBasicAuth(cfg.Username, cfg.Password))
}
back := NewBackend(name, cfg.RPCURL, cfg.WSURL, allowedRPCs, allowedWSRPCs, redis, opts...)
backends = append(backends, back)
backendNames = append(backendNames, name)
log.Info("configured backend", "name", name, "rpc_url", cfg.RPCURL, "ws_url", cfg.WSURL)
}
mappings := make(map[string]*BackendGroup)
for method, groupName := range config.MethodMappings {
if groupsByName[groupName] == nil {
return fmt.Errorf("undefined backend group %s", groupName)
}
mappings[method] = groupsByName[groupName]
backendGroup := &BackendGroup{
Name: "main",
Backends: backends,
}
methodMappings := NewMethodMapping(mappings)
srv := NewServer(methodMappings, config.Server.MaxBodySizeBytes)
srv := NewServer(backendGroup, config.Server.MaxBodySizeBytes)
if config.Metrics.Enabled {
addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)
......@@ -88,6 +82,10 @@ func Start(config *Config) error {
go func() {
if err := srv.ListenAndServe(config.Server.Host, config.Server.Port); err != nil {
if errors.Is(err, http.ErrServerClosed) {
log.Info("server shut down")
return
}
log.Crit("error starting server", "err", err)
}
}()
......@@ -96,5 +94,13 @@ func Start(config *Config) error {
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
recvSig := <-sig
log.Info("caught signal, shutting down", "signal", recvSig)
srv.Shutdown()
if err := redis.FlushBackendWSConns(backendNames); err != nil {
log.Error("error flushing backend ws conns", "err", err)
}
return nil
}
func secondsToDuration(seconds int) time.Duration {
return time.Duration(seconds) * time.Second
}
package proxyd
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"github.com/ethereum/go-ethereum/log"
"github.com/go-redis/redis/v8"
"sync"
"time"
)
const MaxRPSScript = `
local current
current = redis.call("incr", KEYS[1])
if current == 1 then
redis.call("expire", KEYS[1], 1)
end
return current
`
const MaxConcurrentWSConnsScript = `
redis.call("sadd", KEYS[1], KEYS[2])
local total = 0
local scanres = redis.call("sscan", KEYS[1], 0)
for _, k in ipairs(scanres[2]) do
local value = redis.call("get", k)
if value then
total = total + value
end
end
if total < tonumber(ARGV[1]) then
redis.call("incr", KEYS[2])
redis.call("expire", KEYS[2], 300)
return true
end
return false
`
type Redis interface {
IsBackendOnline(name string) (bool, error)
SetBackendOffline(name string, duration time.Duration) error
IncBackendRPS(name string) (int, error)
IncBackendWSConns(name string, max int) (bool, error)
DecBackendWSConns(name string) error
FlushBackendWSConns(names []string) error
}
type RedisImpl struct {
rdb *redis.Client
randID string
touchKeys map[string]time.Duration
tkMtx sync.Mutex
}
func NewRedis(url string) (Redis, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
rdb := redis.NewClient(opts)
if err := rdb.Ping(context.Background()).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
out := &RedisImpl{
rdb: rdb,
randID: randStr(20),
touchKeys: make(map[string]time.Duration),
}
go out.touch()
return out, nil
}
func (r *RedisImpl) IsBackendOnline(name string) (bool, error) {
exists, err := r.rdb.Exists(context.Background(), fmt.Sprintf("backend:%s:offline", name)).Result()
if err != nil {
RecordRedisError("IsBackendOnline")
return false, wrapErr(err, "error getting backend availability")
}
return exists == 0, nil
}
func (r *RedisImpl) SetBackendOffline(name string, duration time.Duration) error {
err := r.rdb.SetEX(
context.Background(),
fmt.Sprintf("backend:%s:offline", name),
1,
duration,
).Err()
if err != nil {
RecordRedisError("SetBackendOffline")
return wrapErr(err, "error setting backend unavailable")
}
return nil
}
func (r *RedisImpl) IncBackendRPS(name string) (int, error) {
cmd := r.rdb.Eval(
context.Background(),
MaxRPSScript,
[]string{fmt.Sprintf("backend:%s:ratelimit", name)},
)
rps, err := cmd.Int()
if err != nil {
RecordRedisError("IncBackendRPS")
return -1, wrapErr(err, "error upserting backend rate limit")
}
return rps, nil
}
func (r *RedisImpl) IncBackendWSConns(name string, max int) (bool, error) {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
r.tkMtx.Lock()
r.touchKeys[connsKey] = 5 * time.Minute
r.tkMtx.Unlock()
cmd := r.rdb.Eval(
context.Background(),
MaxConcurrentWSConnsScript,
[]string{
fmt.Sprintf("backend:%s:proxies", name),
connsKey,
},
max,
)
incremented, err := cmd.Bool()
// false gets coerced to redis.nil, see https://redis.io/commands/eval#conversion-between-lua-and-redis-data-types
if err == redis.Nil {
return false, nil
}
if err != nil {
RecordRedisError("IncBackendWSConns")
return false, wrapErr(err, "error incrementing backend ws conns")
}
return incremented, nil
}
func (r *RedisImpl) DecBackendWSConns(name string) error {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
err := r.rdb.Decr(context.Background(), connsKey).Err()
if err != nil {
RecordRedisError("DecBackendWSConns")
return wrapErr(err, "error decrementing backend ws conns")
}
return nil
}
func (r *RedisImpl) FlushBackendWSConns(names []string) error {
ctx := context.Background()
for _, name := range names {
connsKey := fmt.Sprintf("proxy:%s:wsconns:%s", r.randID, name)
err := r.rdb.SRem(
ctx,
fmt.Sprintf("backend:%s:proxies", name),
connsKey,
).Err()
if err != nil {
return wrapErr(err, "error flushing backend ws conns")
}
err = r.rdb.Del(ctx, connsKey).Err()
if err != nil {
return wrapErr(err, "error flushing backend ws conns")
}
}
return nil
}
func (r *RedisImpl) touch() {
for {
r.tkMtx.Lock()
for key, dur := range r.touchKeys {
if err := r.rdb.Expire(context.Background(), key, dur).Err(); err != nil {
RecordRedisError("touch")
log.Error("error touching redis key", "key", key, "err", err)
}
}
r.tkMtx.Unlock()
time.Sleep(5 * time.Second)
}
}
func randStr(l int) string {
b := make([]byte, l)
if _, err := rand.Read(b); err != nil {
panic(err)
}
return hex.EncodeToString(b)
}
package proxyd
import (
"encoding/json"
"io"
"io/ioutil"
)
type RPCReq struct {
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params json.RawMessage `json:"params"`
ID *int `json:"id"`
}
type RPCRes struct {
JSONRPC string `json:"jsonrpc"`
Result interface{} `json:"result,omitempty"`
Error *RPCErr `json:"error,omitempty"`
ID *int `json:"id"`
}
func (r *RPCRes) IsError() bool {
return r.Error != nil
}
type RPCErr struct {
Code int `json:"code"`
Message string `json:"message"`
}
func (r *RPCErr) Error() string {
return r.Message
}
func ParseRPCReq(r io.Reader) (*RPCReq, error) {
body, err := ioutil.ReadAll(r)
if err != nil {
return nil, wrapErr(err, "error reading request body")
}
req := new(RPCReq)
if err := json.Unmarshal(body, req); err != nil {
return nil, ErrParseErr
}
if req.JSONRPC != JSONRPCVersion {
return nil, ErrInvalidRequest
}
if req.Method == "" {
return nil, ErrInvalidRequest
}
return req, nil
}
func ParseRPCRes(r io.Reader) (*RPCRes, error) {
body, err := ioutil.ReadAll(r)
if err != nil {
return nil, wrapErr(err, "error reading RPC response")
}
res := new(RPCRes)
if err := json.Unmarshal(body, res); err != nil {
return nil, wrapErr(err, "error unmarshaling RPC response")
}
return res, nil
}
func NewRPCErrorRes(id *int, err error) *RPCRes {
var rpcErr *RPCErr
if rr, ok := err.(*RPCErr); ok {
rpcErr = rr
} else {
rpcErr = &RPCErr{
Code: JSONRPCErrorInternal,
Message: err.Error(),
}
}
return &RPCRes{
JSONRPC: JSONRPCVersion,
Error: rpcErr,
ID: id,
}
}
package proxyd
import (
"encoding/json"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/log"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/rs/cors"
"io"
"io/ioutil"
"net/http"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/log"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"github.com/rs/cors"
"io"
"net/http"
"time"
)
var (
httpRequestsCtr = promauto.NewCounter(prometheus.CounterOpts{
Namespace: "proxyd",
Name: "http_requests_total",
Help: "Count of total HTTP requests.",
})
httpRequestDurationSummary = promauto.NewSummary(prometheus.SummaryOpts{
Namespace: "proxyd",
Name: "http_request_duration_seconds",
Help: "Summary of HTTP request durations, in seconds.",
Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.95: 0.005, 0.99: 0.001},
})
rpcRequestsCtr = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "proxyd",
Name: "rpc_requests_total",
Help: "Count of RPC requests.",
}, []string{
"method_name",
})
blockedRPCsCtr = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "proxyd",
Name: "blocked_rpc_requests_total",
Help: "Count of blocked RPC requests.",
}, []string{
"method_name",
})
rpcErrorsCtr = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "proxyd",
Name: "rpc_errors_total",
Help: "Count of RPC errors.",
}, []string{
"error_code",
})
)
type RPCReq struct {
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params json.RawMessage `json:"params"`
ID *int `json:"id"`
}
type RPCRes struct {
JSONRPC string `json:"jsonrpc"`
Result interface{} `json:"result,omitempty"`
Error *RPCErr `json:"error,omitempty"`
ID *int `json:"id"`
}
type RPCErr struct {
Code int `json:"code"`
Message string `json:"message"`
}
type Server struct {
mappings *MethodMapping
backends *BackendGroup
maxBodySize int64
upgrader *websocket.Upgrader
server *http.Server
}
func NewServer(mappings *MethodMapping, maxBodySize int64) *Server {
func NewServer(
backends *BackendGroup,
maxBodySize int64,
) *Server {
return &Server{
mappings: mappings,
backends: backends,
maxBodySize: maxBodySize,
upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second,
},
}
}
......@@ -88,16 +39,21 @@ func (s *Server) ListenAndServe(host string, port int) error {
hdlr := mux.NewRouter()
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/ws", s.HandleWS)
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
})
addr := fmt.Sprintf("%s:%d", host, port)
server := &http.Server{
s.server = &http.Server{
Handler: instrumentedHdlr(c.Handler(hdlr)),
Addr: addr,
}
log.Info("starting HTTP server", "addr", addr)
return server.ListenAndServe()
return s.server.ListenAndServe()
}
func (s *Server) Shutdown() {
s.server.Shutdown(context.Background())
}
func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
......@@ -105,77 +61,93 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(io.LimitReader(r.Body, s.maxBodySize))
req, err := ParseRPCReq(io.LimitReader(r.Body, s.maxBodySize))
if err != nil {
log.Error("error reading request body", "err", err)
rpcErrorsCtr.WithLabelValues("-32700").Inc()
writeRPCError(w, nil, -32700, "could not read request body")
log.Info("rejected request with bad rpc request", "source", "rpc", "err", err)
RecordRPCError(SourceClient, err)
writeRPCError(w, nil, err)
return
}
req := new(RPCReq)
if err := json.Unmarshal(body, req); err != nil {
rpcErrorsCtr.WithLabelValues("-32700").Inc()
writeRPCError(w, nil, -32700, "invalid JSON")
backendRes, err := s.backends.Forward(req)
if err != nil {
if errors.Is(err, ErrNoBackends) {
RecordUnserviceableRequest(RPCRequestSourceHTTP)
RecordRPCError(SourceProxyd, err)
} else if errors.Is(err, ErrMethodNotWhitelisted) {
RecordRPCError(SourceClient, err)
} else {
RecordRPCError(SourceBackend, err)
}
log.Error("error forwarding RPC request", "method", req.Method, "err", err)
writeRPCError(w, req.ID, err)
return
}
if backendRes.IsError() {
RecordRPCError(SourceBackend, backendRes.Error)
}
if req.JSONRPC != JSONRPCVersion {
rpcErrorsCtr.WithLabelValues("-32600").Inc()
writeRPCError(w, nil, -32600, "invalid json-rpc version")
enc := json.NewEncoder(w)
if err := enc.Encode(backendRes); err != nil {
log.Error("error encoding response", "err", err)
RecordRPCError(SourceProxyd, err)
writeRPCError(w, req.ID, err)
return
}
group, err := s.mappings.BackendGroupFor(req.Method)
log.Debug("forwarded RPC method", "method", req.Method)
}
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
clientConn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
rpcErrorsCtr.WithLabelValues("-32601").Inc()
blockedRPCsCtr.WithLabelValues(req.Method).Inc()
log.Info("blocked request for non-whitelisted method", "method", req.Method)
writeRPCError(w, req.ID, -32601, "method not found")
log.Error("error upgrading client conn", "err", err)
return
}
backendRes, err := group.Forward(req)
proxier, err := s.backends.ProxyWS(clientConn)
if err != nil {
log.Error("error forwarding RPC request", "group", group.Name, "method", req.Method, "err", err)
rpcErrorsCtr.WithLabelValues("-32603").Inc()
msg := "error fetching data from upstream"
if errors.Is(err, ErrBackendsInconsistent) {
msg = ErrBackendsInconsistent.Error()
if errors.Is(err, ErrNoBackends) {
RecordUnserviceableRequest(RPCRequestSourceWS)
}
writeRPCError(w, req.ID, -32603, msg)
log.Error("error dialing ws backend", "err", err)
clientConn.Close()
return
}
enc := json.NewEncoder(w)
if err := enc.Encode(backendRes); err != nil {
log.Error("error encoding response", "err", err)
return
}
rpcRequestsCtr.WithLabelValues(req.Method).Inc()
log.Debug("forwarded RPC method", "method", req.Method, "group", group.Name)
activeClientWsConnsGauge.Inc()
go func() {
// Below call blocks so run it in a goroutine.
if err := proxier.Proxy(); err != nil {
log.Error("error proxying websocket", "err", err)
}
activeClientWsConnsGauge.Dec()
}()
}
func writeRPCError(w http.ResponseWriter, id *int, code int, msg string) {
func writeRPCError(w http.ResponseWriter, id *int, err error) {
enc := json.NewEncoder(w)
w.WriteHeader(200)
body := &RPCRes{
ID: id,
Error: &RPCErr{
Code: code,
Message: msg,
},
var body *RPCRes
if r, ok := err.(*RPCErr); ok {
body = NewRPCErrorRes(id, r)
} else {
body = NewRPCErrorRes(id, &RPCErr{
Code: JSONRPCErrorInternal,
Message: "internal error",
})
}
if err := enc.Encode(body); err != nil {
log.Error("error writing RPC error", "err", err)
log.Error("error writing rpc error", "err", err)
}
}
func instrumentedHdlr(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
httpRequestsCtr.Inc()
timer := prometheus.NewTimer(httpRequestDurationSummary)
defer timer.ObserveDuration()
httpRequestsTotal.Inc()
respTimer := prometheus.NewTimer(httpRequestDurationSumm)
h.ServeHTTP(w, r)
respTimer.ObserveDuration()
}
}
package proxyd
import "sync"
type StringSet struct {
underlying map[string]bool
mtx sync.RWMutex
}
func NewStringSet() *StringSet {
return &StringSet{
underlying: make(map[string]bool),
}
}
func NewStringSetFromStrings(in []string) *StringSet {
underlying := make(map[string]bool)
for _, str := range in {
underlying[str] = true
}
return &StringSet{
underlying: underlying,
}
}
func (s *StringSet) Has(test string) bool {
s.mtx.RLock()
defer s.mtx.RUnlock()
return s.underlying[test]
}
func (s *StringSet) Add(str string) {
s.mtx.Lock()
defer s.mtx.Unlock()
s.underlying[str] = true
}
func (s *StringSet) Entries() []string {
s.mtx.RLock()
defer s.mtx.RUnlock()
out := make([]string, len(s.underlying))
var i int
for entry := range s.underlying {
out[i] = entry
i++
}
return out
}
func (s *StringSet) Extend(in []string) *StringSet {
out := NewStringSetFromStrings(in)
for k := range s.underlying {
out.Add(k)
}
return out
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment