Commit dbb3ae87 authored by Matthew Slipper's avatar Matthew Slipper Committed by GitHub

Merge pull request #1701 from mslipper/feat/proxyd-auth

go/proxyd: ENG-1596 support authenticated client requests
parents 8f39b07e 9ba4c5e0
---
'@eth-optimism/proxyd': minor
---
Update metrics, support authenticated endpoints
...@@ -11,45 +11,13 @@ This tool implements `proxyd`, an RPC request router and proxy. It does the foll ...@@ -11,45 +11,13 @@ This tool implements `proxyd`, an RPC request router and proxy. It does the foll
Run `make proxyd` to build the binary. No additional dependencies are necessary. Run `make proxyd` to build the binary. No additional dependencies are necessary.
To configure `proxyd` for use, you'll need to create a configuration file to define your proxy backends and routing rules. An example config that routes `eth_chainId` between Infura and Alchemy is below: To configure `proxyd` for use, you'll need to create a configuration file to define your proxy backends and routing rules. Check out [example.config.toml](./example.config.toml) for how to do this alongside a full list of all options with commentary.
```toml
[backends]
[backends.infura]
base_url = "url-here"
[backends.alchemy]
base_url = "url-here"
[backend_groups]
[backend_groups.main]
backends = ["infura", "alchemy"]
[method_mappings]
eth_chainId = "main"
```
Check out [example.config.toml](./example.config.toml) for a full list of all options with commentary.
Once you have a config file, start the daemon via `proxyd <path-to-config>.toml`. Once you have a config file, start the daemon via `proxyd <path-to-config>.toml`.
## Metrics ## Metrics
The following Prometheus metrics are exported: See `metrics.go` for a list of all available metrics.
| Name | Description | Flags |
|------------------------------------------------|-------------------------------------------------------------------------------------------------|----------------------------------------|
| `proxyd_backend_requests_total` | Count of all successful requests to a backend. | backend_name: The name of the backend. |
| `proxyd_backend_errors_total` | Count of all backend errors. | backend_name: The name of the backend |
| `proxyd_http_requests_total` | Count of all HTTP requests, successful or not. | |
| `proxyd_http_request_duration_histogram_seconds` | Histogram of HTTP request durations. | |
| `proxyd_rpc_requests_total` | Count of all RPC requests. | method_name: The RPC method requested. |
| `proxyd_blocked_rpc_requests_total` | Count of all RPC requests with a blacklisted method. | method_name: The RPC method requested. |
| `proxyd_rpc_errors_total` | Count of all RPC errors. **NOTE:** Does not include errors sent from the backend to the client. |
The metrics port is configurable via the `metrics.port` and `metrics.host` keys in the config. The metrics port is configurable via the `metrics.port` and `metrics.host` keys in the config.
## Errata
- RPC errors originating from the backend (e.g., any backend response containing an `error` key) are passed on to the client directly. This simplifies the code and avoids having to marshal/unmarshal the backend's response JSON.
- Requests are distributed round-robin between backends in a group.
\ No newline at end of file
...@@ -2,6 +2,7 @@ package proxyd ...@@ -2,6 +2,7 @@ package proxyd
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -149,7 +150,7 @@ func NewBackend( ...@@ -149,7 +150,7 @@ func NewBackend(
return backend return backend
} }
func (b *Backend) Forward(req *RPCReq) (*RPCRes, error) { func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
if !b.allowedRPCMethods.Has(req.Method) { if !b.allowedRPCMethods.Has(req.Method) {
return nil, ErrMethodNotWhitelisted return nil, ErrMethodNotWhitelisted
} }
...@@ -164,7 +165,7 @@ func (b *Backend) Forward(req *RPCReq) (*RPCRes, error) { ...@@ -164,7 +165,7 @@ func (b *Backend) Forward(req *RPCReq) (*RPCRes, error) {
// <= to account for the first attempt not technically being // <= to account for the first attempt not technically being
// a retry // a retry
for i := 0; i <= b.maxRetries; i++ { for i := 0; i <= b.maxRetries; i++ {
rpcForwardsTotal.WithLabelValues(b.Name, req.Method, RPCRequestSourceHTTP).Inc() RecordRPCForward(ctx, b.Name, req.Method, RPCRequestSourceHTTP)
respTimer := prometheus.NewTimer(rpcBackendRequestDurationSumm.WithLabelValues(b.Name, req.Method)) respTimer := prometheus.NewTimer(rpcBackendRequestDurationSumm.WithLabelValues(b.Name, req.Method))
resB, err := b.doForward(req) resB, err := b.doForward(req)
if err != nil { if err != nil {
...@@ -305,11 +306,11 @@ type BackendGroup struct { ...@@ -305,11 +306,11 @@ type BackendGroup struct {
Backends []*Backend Backends []*Backend
} }
func (b *BackendGroup) Forward(rpcReq *RPCReq) (*RPCRes, error) { func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, error) {
rpcRequestsTotal.Inc() rpcRequestsTotal.Inc()
for _, back := range b.Backends { for _, back := range b.Backends {
res, err := back.Forward(rpcReq) res, err := back.Forward(ctx, rpcReq)
if errors.Is(err, ErrMethodNotWhitelisted) { if errors.Is(err, ErrMethodNotWhitelisted) {
return nil, err return nil, err
} }
...@@ -364,7 +365,7 @@ type WSProxier struct { ...@@ -364,7 +365,7 @@ type WSProxier struct {
backendConn *websocket.Conn backendConn *websocket.Conn
} }
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, ) *WSProxier { func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn) *WSProxier {
return &WSProxier{ return &WSProxier{
backend: backend, backend: backend,
clientConn: clientConn, clientConn: clientConn,
...@@ -372,16 +373,16 @@ func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, ) * ...@@ -372,16 +373,16 @@ func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, ) *
} }
} }
func (w *WSProxier) Proxy() error { func (w *WSProxier) Proxy(ctx context.Context) error {
errC := make(chan error, 2) errC := make(chan error, 2)
go w.clientPump(errC) go w.clientPump(ctx, errC)
go w.backendPump(errC) go w.backendPump(ctx, errC)
err := <-errC err := <-errC
w.close() w.close()
return err return err
} }
func (w *WSProxier) clientPump(errC chan error) { func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
for { for {
outConn := w.backendConn outConn := w.backendConn
// Block until we get a message. // Block until we get a message.
...@@ -392,7 +393,7 @@ func (w *WSProxier) clientPump(errC chan error) { ...@@ -392,7 +393,7 @@ func (w *WSProxier) clientPump(errC chan error) {
return return
} }
RecordWSMessage(w.backend.Name, SourceClient) RecordWSMessage(ctx, w.backend.Name, SourceClient)
// Route control messages to the backend. These don't // Route control messages to the backend. These don't
// count towards the total RPC requests count. // count towards the total RPC requests count.
...@@ -417,9 +418,9 @@ func (w *WSProxier) clientPump(errC chan error) { ...@@ -417,9 +418,9 @@ func (w *WSProxier) clientPump(errC chan error) {
} }
outConn = w.clientConn outConn = w.clientConn
msg = mustMarshalJSON(NewRPCErrorRes(id, err)) msg = mustMarshalJSON(NewRPCErrorRes(id, err))
RecordRPCError(SourceClient, err) RecordRPCError(ctx, SourceClient, err)
} else { } else {
rpcForwardsTotal.WithLabelValues(w.backend.Name, req.Method, RPCRequestSourceWS).Inc() RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
} }
err = outConn.WriteMessage(msgType, msg) err = outConn.WriteMessage(msgType, msg)
...@@ -430,7 +431,7 @@ func (w *WSProxier) clientPump(errC chan error) { ...@@ -430,7 +431,7 @@ func (w *WSProxier) clientPump(errC chan error) {
} }
} }
func (w *WSProxier) backendPump(errC chan error) { func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
for { for {
// Block until we get a message. // Block until we get a message.
msgType, msg, err := w.backendConn.ReadMessage() msgType, msg, err := w.backendConn.ReadMessage()
...@@ -440,7 +441,7 @@ func (w *WSProxier) backendPump(errC chan error) { ...@@ -440,7 +441,7 @@ func (w *WSProxier) backendPump(errC chan error) {
return return
} }
RecordWSMessage(w.backend.Name, SourceBackend) RecordWSMessage(ctx, w.backend.Name, SourceBackend)
// Route control messages directly to the client. // Route control messages directly to the client.
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
...@@ -461,7 +462,7 @@ func (w *WSProxier) backendPump(errC chan error) { ...@@ -461,7 +462,7 @@ func (w *WSProxier) backendPump(errC chan error) {
msg = mustMarshalJSON(NewRPCErrorRes(id, err)) msg = mustMarshalJSON(NewRPCErrorRes(id, err))
} }
if res.IsError() { if res.IsError() {
RecordRPCError(SourceBackend, res.Error) RecordRPCError(ctx, SourceBackend, res.Error)
} }
err = w.clientConn.WriteMessage(msgType, msg) err = w.clientConn.WriteMessage(msgType, msg)
......
...@@ -50,4 +50,5 @@ type Config struct { ...@@ -50,4 +50,5 @@ type Config struct {
Metrics *MetricsConfig `toml:"metrics"` Metrics *MetricsConfig `toml:"metrics"`
BackendOptions *BackendOptions `toml:"backend_options"` BackendOptions *BackendOptions `toml:"backend_options"`
Backends BackendsConfig `toml:"backends"` Backends BackendsConfig `toml:"backends"`
Authentication map[string]string `toml:"authentication"`
} }
# List of allowed RPC methods.
allowed_rpc_methods = [
"eth_call",
"eth_blockNumber",
"eth_gasPrice",
"eth_chainId"
]
# list of allowed WS methods. Will be combined with allowed_rpc_methods.
allowed_ws_methods = [
"eth_subscribe"
]
[server] [server]
# Host for the proxyd server to listen on. # Host for the proxyd server to listen on.
host = "0.0.0.0" host = "0.0.0.0"
...@@ -6,6 +19,10 @@ port = 8080 ...@@ -6,6 +19,10 @@ port = 8080
# Maximum client body size, in bytes, that the server will accept. # Maximum client body size, in bytes, that the server will accept.
max_body_size_bytes = 10485760 max_body_size_bytes = 10485760
[redis]
# URL to a Redis instance.
url = "redis://localhost:6379"
[metrics] [metrics]
# Whether or not to enable Prometheus metrics. # Whether or not to enable Prometheus metrics.
enabled = true enabled = true
...@@ -20,9 +37,9 @@ response_timeout_seconds = 5 ...@@ -20,9 +37,9 @@ response_timeout_seconds = 5
# Maximum response size, in bytes, that proxyd will accept from a backend. # Maximum response size, in bytes, that proxyd will accept from a backend.
max_response_size_bytes = 5242880 max_response_size_bytes = 5242880
# Maximum number of times proxyd will try a backend before giving up. # Maximum number of times proxyd will try a backend before giving up.
max_retries = 0 max_retries = 3
# Number of seconds to wait before trying an unhealthy backend again. # Number of seconds to wait before trying an unhealthy backend again.
unhealthy_backend_retry_interval_seconds = 600 out_of_service_seconds = 600
[backends] [backends]
# A map of backends by name. # A map of backends by name.
...@@ -33,15 +50,16 @@ base_url = "url-here" ...@@ -33,15 +50,16 @@ base_url = "url-here"
username = "" username = ""
# HTTP basic auth password to use with the backend. # HTTP basic auth password to use with the backend.
password = "" password = ""
# Maximum RPC requests per second before rate limiting.
# This number is global across multiple proxyd instances.
max_rps = 3
# Maximum number of concurrent WS connections before dropping them.
# This number is global across multiple proxyd instances.
max_ws_conns = 1
[backend_groups] # If the authentication group below is in the config,
# A map of backend groups by name. # proxyd will only accept authenticated requests.
[backend_groups.main] [authentication]
# A list of backend names to place in the group. # Mapping of auth key to alias. The alias is used to provide a human-
backends = ["infura", "alchemy"] # readable name for the auth key in monitoring.
secret = "uniswap"
[method_mappings] \ No newline at end of file
# A mapping between RPC methods and the backend groups that should serve them.
eth_call = "main"
eth_chainId = "main"
# other mappings go here
package proxyd package proxyd
import ( import (
"context"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
"strconv" "strconv"
...@@ -26,9 +27,10 @@ var ( ...@@ -26,9 +27,10 @@ var (
rpcForwardsTotal = promauto.NewCounterVec(prometheus.CounterOpts{ rpcForwardsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace, Namespace: MetricsNamespace,
Name: "rpc_backend_requests_total", Name: "rpc_forwards_total",
Help: "Count of total RPC requests forwarded to each backend.", Help: "Count of total RPC requests forwarded to each backend.",
}, []string{ }, []string{
"auth",
"backend_name", "backend_name",
"method_name", "method_name",
"source", "source",
...@@ -39,6 +41,7 @@ var ( ...@@ -39,6 +41,7 @@ var (
Name: "rpc_errors_total", Name: "rpc_errors_total",
Help: "Count of total RPC errors.", Help: "Count of total RPC errors.",
}, []string{ }, []string{
"auth",
"source", "source",
"error_code", "error_code",
}) })
...@@ -53,10 +56,12 @@ var ( ...@@ -53,10 +56,12 @@ var (
"method_name", "method_name",
}) })
activeClientWsConnsGauge = promauto.NewGauge(prometheus.GaugeOpts{ activeClientWsConnsGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: MetricsNamespace, Namespace: MetricsNamespace,
Name: "active_client_ws_conns", Name: "active_client_ws_conns",
Help: "Gauge of active client WS connections.", Help: "Gauge of active client WS connections.",
}, []string{
"auth",
}) })
activeBackendWsConnsGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{ activeBackendWsConnsGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{
...@@ -72,6 +77,7 @@ var ( ...@@ -72,6 +77,7 @@ var (
Name: "unserviceable_requests_total", Name: "unserviceable_requests_total",
Help: "Count of total requests that were rejected due to no backends being available.", Help: "Count of total requests that were rejected due to no backends being available.",
}, []string{ }, []string{
"auth",
"source", "source",
}) })
...@@ -93,6 +99,7 @@ var ( ...@@ -93,6 +99,7 @@ var (
Name: "ws_messages_total", Name: "ws_messages_total",
Help: "Count of total websocket messages including protocol control.", Help: "Count of total websocket messages including protocol control.",
}, []string{ }, []string{
"auth",
"backend_name", "backend_name",
"source", "source",
}) })
...@@ -106,7 +113,11 @@ var ( ...@@ -106,7 +113,11 @@ var (
}) })
) )
func RecordRPCError(source string, err error) { func RecordRedisError(source string) {
redisErrorsTotal.WithLabelValues(source).Inc()
}
func RecordRPCError(ctx context.Context, source string, err error) {
rpcErr, ok := err.(*RPCErr) rpcErr, ok := err.(*RPCErr)
var code int var code int
if ok { if ok {
...@@ -115,17 +126,17 @@ func RecordRPCError(source string, err error) { ...@@ -115,17 +126,17 @@ func RecordRPCError(source string, err error) {
code = -1 code = -1
} }
rpcErrorsTotal.WithLabelValues(source, strconv.Itoa(code)).Inc() rpcErrorsTotal.WithLabelValues(GetAuthCtx(ctx), source, strconv.Itoa(code)).Inc()
} }
func RecordRedisError(source string) { func RecordWSMessage(ctx context.Context, backendName, source string) {
redisErrorsTotal.WithLabelValues(source).Inc() wsMessagesTotal.WithLabelValues(GetAuthCtx(ctx), backendName, source).Inc()
} }
func RecordWSMessage(backendName, source string) { func RecordUnserviceableRequest(ctx context.Context, source string) {
wsMessagesTotal.WithLabelValues(backendName, source).Inc() unserviceableRequestsTotal.WithLabelValues(GetAuthCtx(ctx), source).Inc()
} }
func RecordUnserviceableRequest(source string) { func RecordRPCForward(ctx context.Context, backendName, method, source string) {
unserviceableRequestsTotal.WithLabelValues(source).Inc() rpcForwardsTotal.WithLabelValues(GetAuthCtx(ctx), backendName, method, source).Inc()
} }
...@@ -20,6 +20,12 @@ func Start(config *Config) error { ...@@ -20,6 +20,12 @@ func Start(config *Config) error {
return errors.New("must define at least one allowed RPC method") return errors.New("must define at least one allowed RPC method")
} }
for authKey := range config.Authentication {
if authKey == "none" {
return errors.New("cannot use none as an auth key")
}
}
allowedRPCs := NewStringSetFromStrings(config.AllowedRPCMethods) allowedRPCs := NewStringSetFromStrings(config.AllowedRPCMethods)
allowedWSRPCs := allowedRPCs.Extend(config.AllowedWSMethods) allowedWSRPCs := allowedRPCs.Extend(config.AllowedWSMethods)
...@@ -72,7 +78,11 @@ func Start(config *Config) error { ...@@ -72,7 +78,11 @@ func Start(config *Config) error {
Name: "main", Name: "main",
Backends: backends, Backends: backends,
} }
srv := NewServer(backendGroup, config.Server.MaxBodySizeBytes) srv := NewServer(
backendGroup,
config.Server.MaxBodySizeBytes,
config.Authentication,
)
if config.Metrics.Enabled { if config.Metrics.Enabled {
addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port) addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)
......
...@@ -15,9 +15,14 @@ import ( ...@@ -15,9 +15,14 @@ import (
"time" "time"
) )
const (
ContextKeyAuth = "authorization"
)
type Server struct { type Server struct {
backends *BackendGroup backends *BackendGroup
maxBodySize int64 maxBodySize int64
authenticatedPaths map[string]string
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
server *http.Server server *http.Server
} }
...@@ -25,10 +30,12 @@ type Server struct { ...@@ -25,10 +30,12 @@ type Server struct {
func NewServer( func NewServer(
backends *BackendGroup, backends *BackendGroup,
maxBodySize int64, maxBodySize int64,
authenticatedPaths map[string]string,
) *Server { ) *Server {
return &Server{ return &Server{
backends: backends, backends: backends,
maxBodySize: maxBodySize, maxBodySize: maxBodySize,
authenticatedPaths: authenticatedPaths,
upgrader: &websocket.Upgrader{ upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
}, },
...@@ -38,8 +45,10 @@ func NewServer( ...@@ -38,8 +45,10 @@ func NewServer(
func (s *Server) ListenAndServe(host string, port int) error { func (s *Server) ListenAndServe(host string, port int) error {
hdlr := mux.NewRouter() hdlr := mux.NewRouter()
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET") hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST") hdlr.HandleFunc("/api/v1/rpc", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/ws", s.HandleWS) hdlr.HandleFunc("/api/v1/{authorization}/rpc", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/api/v1/ws", s.HandleWS)
hdlr.HandleFunc("/api/v1/{authorization}/ws", s.HandleWS)
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
}) })
...@@ -61,36 +70,41 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) { ...@@ -61,36 +70,41 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
ctx := s.authenticate(w, r)
if ctx == nil {
return
}
req, err := ParseRPCReq(io.LimitReader(r.Body, s.maxBodySize)) req, err := ParseRPCReq(io.LimitReader(r.Body, s.maxBodySize))
if err != nil { if err != nil {
log.Info("rejected request with bad rpc request", "source", "rpc", "err", err) log.Info("rejected request with bad rpc request", "source", "rpc", "err", err)
RecordRPCError(SourceClient, err) RecordRPCError(ctx, SourceClient, err)
writeRPCError(w, nil, err) writeRPCError(w, nil, err)
return return
} }
backendRes, err := s.backends.Forward(req) backendRes, err := s.backends.Forward(ctx, req)
if err != nil { if err != nil {
if errors.Is(err, ErrNoBackends) { if errors.Is(err, ErrNoBackends) {
RecordUnserviceableRequest(RPCRequestSourceHTTP) RecordUnserviceableRequest(ctx, RPCRequestSourceHTTP)
RecordRPCError(SourceProxyd, err) RecordRPCError(ctx, SourceProxyd, err)
} else if errors.Is(err, ErrMethodNotWhitelisted) { } else if errors.Is(err, ErrMethodNotWhitelisted) {
RecordRPCError(SourceClient, err) RecordRPCError(ctx, SourceClient, err)
} else { } else {
RecordRPCError(SourceBackend, err) RecordRPCError(ctx, SourceBackend, err)
} }
log.Error("error forwarding RPC request", "method", req.Method, "err", err) log.Error("error forwarding RPC request", "method", req.Method, "err", err)
writeRPCError(w, req.ID, err) writeRPCError(w, req.ID, err)
return return
} }
if backendRes.IsError() { if backendRes.IsError() {
RecordRPCError(SourceBackend, backendRes.Error) RecordRPCError(ctx, SourceBackend, backendRes.Error)
} }
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
if err := enc.Encode(backendRes); err != nil { if err := enc.Encode(backendRes); err != nil {
log.Error("error encoding response", "err", err) log.Error("error encoding response", "err", err)
RecordRPCError(SourceProxyd, err) RecordRPCError(ctx, SourceProxyd, err)
writeRPCError(w, req.ID, err) writeRPCError(w, req.ID, err)
return return
} }
...@@ -99,6 +113,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { ...@@ -99,6 +113,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
ctx := s.authenticate(w, r)
if ctx == nil {
return
}
clientConn, err := s.upgrader.Upgrade(w, r, nil) clientConn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Error("error upgrading client conn", "err", err) log.Error("error upgrading client conn", "err", err)
...@@ -108,23 +127,45 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { ...@@ -108,23 +127,45 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
proxier, err := s.backends.ProxyWS(clientConn) proxier, err := s.backends.ProxyWS(clientConn)
if err != nil { if err != nil {
if errors.Is(err, ErrNoBackends) { if errors.Is(err, ErrNoBackends) {
RecordUnserviceableRequest(RPCRequestSourceWS) RecordUnserviceableRequest(ctx, RPCRequestSourceWS)
} }
log.Error("error dialing ws backend", "err", err) log.Error("error dialing ws backend", "err", err)
clientConn.Close() clientConn.Close()
return return
} }
activeClientWsConnsGauge.Inc() activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Inc()
go func() { go func() {
// Below call blocks so run it in a goroutine. // Below call blocks so run it in a goroutine.
if err := proxier.Proxy(); err != nil { if err := proxier.Proxy(ctx); err != nil {
log.Error("error proxying websocket", "err", err) log.Error("error proxying websocket", "err", err)
} }
activeClientWsConnsGauge.Dec() activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Dec()
}() }()
} }
func (s *Server) authenticate(w http.ResponseWriter, r *http.Request) context.Context {
vars := mux.Vars(r)
authorization := vars["authorization"]
if s.authenticatedPaths == nil {
// handle the edge case where auth is disabled
// but someone sends in an auth key anyway
if authorization != "" {
w.WriteHeader(404)
return nil
}
return r.Context()
}
if authorization == "" || s.authenticatedPaths[authorization] == "" {
w.WriteHeader(401)
return nil
}
return context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization])
}
func writeRPCError(w http.ResponseWriter, id *int, err error) { func writeRPCError(w http.ResponseWriter, id *int, err error) {
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
w.WriteHeader(200) w.WriteHeader(200)
...@@ -151,3 +192,12 @@ func instrumentedHdlr(h http.Handler) http.HandlerFunc { ...@@ -151,3 +192,12 @@ func instrumentedHdlr(h http.Handler) http.HandlerFunc {
respTimer.ObserveDuration() respTimer.ObserveDuration()
} }
} }
func GetAuthCtx(ctx context.Context) string {
authUser, ok := ctx.Value(ContextKeyAuth).(string)
if !ok {
return "none"
}
return authUser
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment