Commit dbb3ae87 authored by Matthew Slipper's avatar Matthew Slipper Committed by GitHub

Merge pull request #1701 from mslipper/feat/proxyd-auth

go/proxyd: ENG-1596 support authenticated client requests
parents 8f39b07e 9ba4c5e0
---
'@eth-optimism/proxyd': minor
---
Update metrics, support authenticated endpoints
......@@ -11,45 +11,13 @@ This tool implements `proxyd`, an RPC request router and proxy. It does the foll
Run `make proxyd` to build the binary. No additional dependencies are necessary.
To configure `proxyd` for use, you'll need to create a configuration file to define your proxy backends and routing rules. An example config that routes `eth_chainId` between Infura and Alchemy is below:
```toml
[backends]
[backends.infura]
base_url = "url-here"
[backends.alchemy]
base_url = "url-here"
[backend_groups]
[backend_groups.main]
backends = ["infura", "alchemy"]
[method_mappings]
eth_chainId = "main"
```
Check out [example.config.toml](./example.config.toml) for a full list of all options with commentary.
To configure `proxyd` for use, you'll need to create a configuration file to define your proxy backends and routing rules. Check out [example.config.toml](./example.config.toml) for how to do this alongside a full list of all options with commentary.
Once you have a config file, start the daemon via `proxyd <path-to-config>.toml`.
## Metrics
The following Prometheus metrics are exported:
| Name | Description | Flags |
|------------------------------------------------|-------------------------------------------------------------------------------------------------|----------------------------------------|
| `proxyd_backend_requests_total` | Count of all successful requests to a backend. | backend_name: The name of the backend. |
| `proxyd_backend_errors_total` | Count of all backend errors. | backend_name: The name of the backend |
| `proxyd_http_requests_total` | Count of all HTTP requests, successful or not. | |
| `proxyd_http_request_duration_histogram_seconds` | Histogram of HTTP request durations. | |
| `proxyd_rpc_requests_total` | Count of all RPC requests. | method_name: The RPC method requested. |
| `proxyd_blocked_rpc_requests_total` | Count of all RPC requests with a blacklisted method. | method_name: The RPC method requested. |
| `proxyd_rpc_errors_total` | Count of all RPC errors. **NOTE:** Does not include errors sent from the backend to the client. |
See `metrics.go` for a list of all available metrics.
The metrics port is configurable via the `metrics.port` and `metrics.host` keys in the config.
## Errata
- RPC errors originating from the backend (e.g., any backend response containing an `error` key) are passed on to the client directly. This simplifies the code and avoids having to marshal/unmarshal the backend's response JSON.
- Requests are distributed round-robin between backends in a group.
\ No newline at end of file
......@@ -2,6 +2,7 @@ package proxyd
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
......@@ -149,7 +150,7 @@ func NewBackend(
return backend
}
func (b *Backend) Forward(req *RPCReq) (*RPCRes, error) {
func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
if !b.allowedRPCMethods.Has(req.Method) {
return nil, ErrMethodNotWhitelisted
}
......@@ -164,7 +165,7 @@ func (b *Backend) Forward(req *RPCReq) (*RPCRes, error) {
// <= to account for the first attempt not technically being
// a retry
for i := 0; i <= b.maxRetries; i++ {
rpcForwardsTotal.WithLabelValues(b.Name, req.Method, RPCRequestSourceHTTP).Inc()
RecordRPCForward(ctx, b.Name, req.Method, RPCRequestSourceHTTP)
respTimer := prometheus.NewTimer(rpcBackendRequestDurationSumm.WithLabelValues(b.Name, req.Method))
resB, err := b.doForward(req)
if err != nil {
......@@ -305,11 +306,11 @@ type BackendGroup struct {
Backends []*Backend
}
func (b *BackendGroup) Forward(rpcReq *RPCReq) (*RPCRes, error) {
func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, error) {
rpcRequestsTotal.Inc()
for _, back := range b.Backends {
res, err := back.Forward(rpcReq)
res, err := back.Forward(ctx, rpcReq)
if errors.Is(err, ErrMethodNotWhitelisted) {
return nil, err
}
......@@ -364,7 +365,7 @@ type WSProxier struct {
backendConn *websocket.Conn
}
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, ) *WSProxier {
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn) *WSProxier {
return &WSProxier{
backend: backend,
clientConn: clientConn,
......@@ -372,16 +373,16 @@ func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, ) *
}
}
func (w *WSProxier) Proxy() error {
func (w *WSProxier) Proxy(ctx context.Context) error {
errC := make(chan error, 2)
go w.clientPump(errC)
go w.backendPump(errC)
go w.clientPump(ctx, errC)
go w.backendPump(ctx, errC)
err := <-errC
w.close()
return err
}
func (w *WSProxier) clientPump(errC chan error) {
func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
for {
outConn := w.backendConn
// Block until we get a message.
......@@ -392,7 +393,7 @@ func (w *WSProxier) clientPump(errC chan error) {
return
}
RecordWSMessage(w.backend.Name, SourceClient)
RecordWSMessage(ctx, w.backend.Name, SourceClient)
// Route control messages to the backend. These don't
// count towards the total RPC requests count.
......@@ -417,9 +418,9 @@ func (w *WSProxier) clientPump(errC chan error) {
}
outConn = w.clientConn
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
RecordRPCError(SourceClient, err)
RecordRPCError(ctx, SourceClient, err)
} else {
rpcForwardsTotal.WithLabelValues(w.backend.Name, req.Method, RPCRequestSourceWS).Inc()
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
}
err = outConn.WriteMessage(msgType, msg)
......@@ -430,7 +431,7 @@ func (w *WSProxier) clientPump(errC chan error) {
}
}
func (w *WSProxier) backendPump(errC chan error) {
func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
for {
// Block until we get a message.
msgType, msg, err := w.backendConn.ReadMessage()
......@@ -440,7 +441,7 @@ func (w *WSProxier) backendPump(errC chan error) {
return
}
RecordWSMessage(w.backend.Name, SourceBackend)
RecordWSMessage(ctx, w.backend.Name, SourceBackend)
// Route control messages directly to the client.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
......@@ -461,7 +462,7 @@ func (w *WSProxier) backendPump(errC chan error) {
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
}
if res.IsError() {
RecordRPCError(SourceBackend, res.Error)
RecordRPCError(ctx, SourceBackend, res.Error)
}
err = w.clientConn.WriteMessage(msgType, msg)
......
......@@ -43,11 +43,12 @@ type BackendGroupsConfig map[string]*BackendGroupConfig
type MethodMappingsConfig map[string]string
type Config struct {
AllowedRPCMethods []string `toml:"allowed_rpc_methods"`
AllowedWSMethods []string `toml:"allowed_ws_methods"`
Server *ServerConfig `toml:"server"`
Redis *RedisConfig `toml:"redis"`
Metrics *MetricsConfig `toml:"metrics"`
BackendOptions *BackendOptions `toml:"backend_options"`
Backends BackendsConfig `toml:"backends"`
AllowedRPCMethods []string `toml:"allowed_rpc_methods"`
AllowedWSMethods []string `toml:"allowed_ws_methods"`
Server *ServerConfig `toml:"server"`
Redis *RedisConfig `toml:"redis"`
Metrics *MetricsConfig `toml:"metrics"`
BackendOptions *BackendOptions `toml:"backend_options"`
Backends BackendsConfig `toml:"backends"`
Authentication map[string]string `toml:"authentication"`
}
# List of allowed RPC methods.
allowed_rpc_methods = [
"eth_call",
"eth_blockNumber",
"eth_gasPrice",
"eth_chainId"
]
# list of allowed WS methods. Will be combined with allowed_rpc_methods.
allowed_ws_methods = [
"eth_subscribe"
]
[server]
# Host for the proxyd server to listen on.
host = "0.0.0.0"
......@@ -6,6 +19,10 @@ port = 8080
# Maximum client body size, in bytes, that the server will accept.
max_body_size_bytes = 10485760
[redis]
# URL to a Redis instance.
url = "redis://localhost:6379"
[metrics]
# Whether or not to enable Prometheus metrics.
enabled = true
......@@ -20,9 +37,9 @@ response_timeout_seconds = 5
# Maximum response size, in bytes, that proxyd will accept from a backend.
max_response_size_bytes = 5242880
# Maximum number of times proxyd will try a backend before giving up.
max_retries = 0
max_retries = 3
# Number of seconds to wait before trying an unhealthy backend again.
unhealthy_backend_retry_interval_seconds = 600
out_of_service_seconds = 600
[backends]
# A map of backends by name.
......@@ -33,15 +50,16 @@ base_url = "url-here"
username = ""
# HTTP basic auth password to use with the backend.
password = ""
# Maximum RPC requests per second before rate limiting.
# This number is global across multiple proxyd instances.
max_rps = 3
# Maximum number of concurrent WS connections before dropping them.
# This number is global across multiple proxyd instances.
max_ws_conns = 1
[backend_groups]
# A map of backend groups by name.
[backend_groups.main]
# A list of backend names to place in the group.
backends = ["infura", "alchemy"]
[method_mappings]
# A mapping between RPC methods and the backend groups that should serve them.
eth_call = "main"
eth_chainId = "main"
# other mappings go here
# If the authentication group below is in the config,
# proxyd will only accept authenticated requests.
[authentication]
# Mapping of auth key to alias. The alias is used to provide a human-
# readable name for the auth key in monitoring.
secret = "uniswap"
\ No newline at end of file
package proxyd
import (
"context"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"strconv"
......@@ -26,9 +27,10 @@ var (
rpcForwardsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "rpc_backend_requests_total",
Name: "rpc_forwards_total",
Help: "Count of total RPC requests forwarded to each backend.",
}, []string{
"auth",
"backend_name",
"method_name",
"source",
......@@ -39,6 +41,7 @@ var (
Name: "rpc_errors_total",
Help: "Count of total RPC errors.",
}, []string{
"auth",
"source",
"error_code",
})
......@@ -53,10 +56,12 @@ var (
"method_name",
})
activeClientWsConnsGauge = promauto.NewGauge(prometheus.GaugeOpts{
activeClientWsConnsGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: MetricsNamespace,
Name: "active_client_ws_conns",
Help: "Gauge of active client WS connections.",
}, []string{
"auth",
})
activeBackendWsConnsGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{
......@@ -72,6 +77,7 @@ var (
Name: "unserviceable_requests_total",
Help: "Count of total requests that were rejected due to no backends being available.",
}, []string{
"auth",
"source",
})
......@@ -93,6 +99,7 @@ var (
Name: "ws_messages_total",
Help: "Count of total websocket messages including protocol control.",
}, []string{
"auth",
"backend_name",
"source",
})
......@@ -106,7 +113,11 @@ var (
})
)
func RecordRPCError(source string, err error) {
func RecordRedisError(source string) {
redisErrorsTotal.WithLabelValues(source).Inc()
}
func RecordRPCError(ctx context.Context, source string, err error) {
rpcErr, ok := err.(*RPCErr)
var code int
if ok {
......@@ -115,17 +126,17 @@ func RecordRPCError(source string, err error) {
code = -1
}
rpcErrorsTotal.WithLabelValues(source, strconv.Itoa(code)).Inc()
rpcErrorsTotal.WithLabelValues(GetAuthCtx(ctx), source, strconv.Itoa(code)).Inc()
}
func RecordRedisError(source string) {
redisErrorsTotal.WithLabelValues(source).Inc()
func RecordWSMessage(ctx context.Context, backendName, source string) {
wsMessagesTotal.WithLabelValues(GetAuthCtx(ctx), backendName, source).Inc()
}
func RecordWSMessage(backendName, source string) {
wsMessagesTotal.WithLabelValues(backendName, source).Inc()
func RecordUnserviceableRequest(ctx context.Context, source string) {
unserviceableRequestsTotal.WithLabelValues(GetAuthCtx(ctx), source).Inc()
}
func RecordUnserviceableRequest(source string) {
unserviceableRequestsTotal.WithLabelValues(source).Inc()
func RecordRPCForward(ctx context.Context, backendName, method, source string) {
rpcForwardsTotal.WithLabelValues(GetAuthCtx(ctx), backendName, method, source).Inc()
}
......@@ -20,6 +20,12 @@ func Start(config *Config) error {
return errors.New("must define at least one allowed RPC method")
}
for authKey := range config.Authentication {
if authKey == "none" {
return errors.New("cannot use none as an auth key")
}
}
allowedRPCs := NewStringSetFromStrings(config.AllowedRPCMethods)
allowedWSRPCs := allowedRPCs.Extend(config.AllowedWSMethods)
......@@ -72,7 +78,11 @@ func Start(config *Config) error {
Name: "main",
Backends: backends,
}
srv := NewServer(backendGroup, config.Server.MaxBodySizeBytes)
srv := NewServer(
backendGroup,
config.Server.MaxBodySizeBytes,
config.Authentication,
)
if config.Metrics.Enabled {
addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)
......
......@@ -15,20 +15,27 @@ import (
"time"
)
const (
ContextKeyAuth = "authorization"
)
type Server struct {
backends *BackendGroup
maxBodySize int64
upgrader *websocket.Upgrader
server *http.Server
backends *BackendGroup
maxBodySize int64
authenticatedPaths map[string]string
upgrader *websocket.Upgrader
server *http.Server
}
func NewServer(
backends *BackendGroup,
maxBodySize int64,
authenticatedPaths map[string]string,
) *Server {
return &Server{
backends: backends,
maxBodySize: maxBodySize,
backends: backends,
maxBodySize: maxBodySize,
authenticatedPaths: authenticatedPaths,
upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second,
},
......@@ -38,8 +45,10 @@ func NewServer(
func (s *Server) ListenAndServe(host string, port int) error {
hdlr := mux.NewRouter()
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/ws", s.HandleWS)
hdlr.HandleFunc("/api/v1/rpc", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/api/v1/{authorization}/rpc", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/api/v1/ws", s.HandleWS)
hdlr.HandleFunc("/api/v1/{authorization}/ws", s.HandleWS)
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
})
......@@ -61,36 +70,41 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
ctx := s.authenticate(w, r)
if ctx == nil {
return
}
req, err := ParseRPCReq(io.LimitReader(r.Body, s.maxBodySize))
if err != nil {
log.Info("rejected request with bad rpc request", "source", "rpc", "err", err)
RecordRPCError(SourceClient, err)
RecordRPCError(ctx, SourceClient, err)
writeRPCError(w, nil, err)
return
}
backendRes, err := s.backends.Forward(req)
backendRes, err := s.backends.Forward(ctx, req)
if err != nil {
if errors.Is(err, ErrNoBackends) {
RecordUnserviceableRequest(RPCRequestSourceHTTP)
RecordRPCError(SourceProxyd, err)
RecordUnserviceableRequest(ctx, RPCRequestSourceHTTP)
RecordRPCError(ctx, SourceProxyd, err)
} else if errors.Is(err, ErrMethodNotWhitelisted) {
RecordRPCError(SourceClient, err)
RecordRPCError(ctx, SourceClient, err)
} else {
RecordRPCError(SourceBackend, err)
RecordRPCError(ctx, SourceBackend, err)
}
log.Error("error forwarding RPC request", "method", req.Method, "err", err)
writeRPCError(w, req.ID, err)
return
}
if backendRes.IsError() {
RecordRPCError(SourceBackend, backendRes.Error)
RecordRPCError(ctx, SourceBackend, backendRes.Error)
}
enc := json.NewEncoder(w)
if err := enc.Encode(backendRes); err != nil {
log.Error("error encoding response", "err", err)
RecordRPCError(SourceProxyd, err)
RecordRPCError(ctx, SourceProxyd, err)
writeRPCError(w, req.ID, err)
return
}
......@@ -99,6 +113,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
ctx := s.authenticate(w, r)
if ctx == nil {
return
}
clientConn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Error("error upgrading client conn", "err", err)
......@@ -108,23 +127,45 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
proxier, err := s.backends.ProxyWS(clientConn)
if err != nil {
if errors.Is(err, ErrNoBackends) {
RecordUnserviceableRequest(RPCRequestSourceWS)
RecordUnserviceableRequest(ctx, RPCRequestSourceWS)
}
log.Error("error dialing ws backend", "err", err)
clientConn.Close()
return
}
activeClientWsConnsGauge.Inc()
activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Inc()
go func() {
// Below call blocks so run it in a goroutine.
if err := proxier.Proxy(); err != nil {
if err := proxier.Proxy(ctx); err != nil {
log.Error("error proxying websocket", "err", err)
}
activeClientWsConnsGauge.Dec()
activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Dec()
}()
}
func (s *Server) authenticate(w http.ResponseWriter, r *http.Request) context.Context {
vars := mux.Vars(r)
authorization := vars["authorization"]
if s.authenticatedPaths == nil {
// handle the edge case where auth is disabled
// but someone sends in an auth key anyway
if authorization != "" {
w.WriteHeader(404)
return nil
}
return r.Context()
}
if authorization == "" || s.authenticatedPaths[authorization] == "" {
w.WriteHeader(401)
return nil
}
return context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization])
}
func writeRPCError(w http.ResponseWriter, id *int, err error) {
enc := json.NewEncoder(w)
w.WriteHeader(200)
......@@ -151,3 +192,12 @@ func instrumentedHdlr(h http.Handler) http.HandlerFunc {
respTimer.ObserveDuration()
}
}
func GetAuthCtx(ctx context.Context) string {
authUser, ok := ctx.Value(ContextKeyAuth).(string)
if !ok {
return "none"
}
return authUser
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment