Commit 4096cd39 authored by Mark Tyneway's avatar Mark Tyneway Committed by GitHub

Merge pull request #1730 from mslipper/feat/alt-ws-server

go/proxyd: Make endpoints match Geth, better logging
parents e4ada1dd abe231bf
---
'@eth-optimism/proxyd': major
---
Make endpoints match Geth, better logging
...@@ -163,7 +163,12 @@ func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) { ...@@ -163,7 +163,12 @@ func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
res, err := b.doForward(req) res, err := b.doForward(req)
if err != nil { if err != nil {
lastError = err lastError = err
log.Warn("backend request failed, trying again", "err", err, "name", b.Name) log.Warn(
"backend request failed, trying again",
"name", b.Name,
"req_id", GetReqID(ctx),
"err", err,
)
respTimer.ObserveDuration() respTimer.ObserveDuration()
RecordRPCError(ctx, b.Name, req.Method, err) RecordRPCError(ctx, b.Name, req.Method, err)
time.Sleep(calcBackoff(i)) time.Sleep(calcBackoff(i))
...@@ -172,6 +177,20 @@ func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) { ...@@ -172,6 +177,20 @@ func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
respTimer.ObserveDuration() respTimer.ObserveDuration()
if res.IsError() { if res.IsError() {
RecordRPCError(ctx, b.Name, req.Method, res.Error) RecordRPCError(ctx, b.Name, req.Method, res.Error)
log.Info(
"backend responded with RPC error",
"code", res.Error.Code,
"msg", res.Error.Message,
"req_id", GetReqID(ctx),
"source", "rpc",
"auth", GetAuthCtx(ctx),
)
} else {
log.Info("forwarded RPC request",
"method", req.Method,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
} }
return res, nil return res, nil
} }
...@@ -313,15 +332,31 @@ func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, er ...@@ -313,15 +332,31 @@ func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, er
return nil, err return nil, err
} }
if errors.Is(err, ErrBackendOffline) { if errors.Is(err, ErrBackendOffline) {
log.Debug("skipping offline backend", "name", back.Name) log.Warn(
"skipping offline backend",
"name", back.Name,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
continue continue
} }
if errors.Is(err, ErrBackendOverCapacity) { if errors.Is(err, ErrBackendOverCapacity) {
log.Debug("skipping over-capacity backend", "name", back.Name) log.Warn(
"skipping over-capacity backend",
"name", back.Name,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
continue continue
} }
if err != nil { if err != nil {
log.Error("error forwarding request to backend", "err", err, "name", b.Name) log.Error(
"error forwarding request to backend",
"name", b.Name,
"req_id", GetReqID(ctx),
"auth", GetAuthCtx(ctx),
"err", err,
)
continue continue
} }
return res, nil return res, nil
...@@ -331,19 +366,35 @@ func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, er ...@@ -331,19 +366,35 @@ func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, er
return nil, ErrNoBackends return nil, ErrNoBackends
} }
func (b *BackendGroup) ProxyWS(clientConn *websocket.Conn, methodWhitelist *StringSet) (*WSProxier, error) { func (b *BackendGroup) ProxyWS(ctx context.Context, clientConn *websocket.Conn, methodWhitelist *StringSet) (*WSProxier, error) {
for _, back := range b.Backends { for _, back := range b.Backends {
proxier, err := back.ProxyWS(clientConn, methodWhitelist) proxier, err := back.ProxyWS(clientConn, methodWhitelist)
if errors.Is(err, ErrBackendOffline) { if errors.Is(err, ErrBackendOffline) {
log.Debug("skipping offline backend", "name", back.Name) log.Warn(
"skipping offline backend",
"name", back.Name,
"req_id", GetReqID(ctx),
"auth", GetAuthCtx(ctx),
)
continue continue
} }
if errors.Is(err, ErrBackendOverCapacity) { if errors.Is(err, ErrBackendOverCapacity) {
log.Debug("skipping over-capacity backend", "name", back.Name) log.Warn(
"skipping over-capacity backend",
"name", back.Name,
"req_id", GetReqID(ctx),
"auth", GetAuthCtx(ctx),
)
continue continue
} }
if err != nil { if err != nil {
log.Warn("error dialing ws backend", "name", back.Name, "err", err) log.Warn(
"error dialing ws backend",
"name", back.Name,
"req_id", GetReqID(ctx),
"auth", GetAuthCtx(ctx),
"err", err,
)
continue continue
} }
return proxier, nil return proxier, nil
...@@ -411,7 +462,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { ...@@ -411,7 +462,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
// Don't bother sending invalid requests to the backend, // Don't bother sending invalid requests to the backend,
// just handle them here. // just handle them here.
req, err := w.parseClientMsg(msg) req, err := w.prepareClientMsg(msg)
if err != nil { if err != nil {
var id *int var id *int
method := MethodUnknown method := MethodUnknown
...@@ -419,11 +470,23 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) { ...@@ -419,11 +470,23 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
id = req.ID id = req.ID
method = req.Method method = req.Method
} }
log.Info(
"error preparing client message",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
"err", err,
)
outConn = w.clientConn outConn = w.clientConn
msg = mustMarshalJSON(NewRPCErrorRes(id, err)) msg = mustMarshalJSON(NewRPCErrorRes(id, err))
RecordRPCError(ctx, BackendProxyd, method, err) RecordRPCError(ctx, BackendProxyd, method, err)
} else { } else {
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS) RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
log.Info(
"forwarded WS message to backend",
"method", req.Method,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
} }
err = outConn.WriteMessage(msgType, msg) err = outConn.WriteMessage(msgType, msg)
...@@ -465,7 +528,21 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) { ...@@ -465,7 +528,21 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
msg = mustMarshalJSON(NewRPCErrorRes(id, err)) msg = mustMarshalJSON(NewRPCErrorRes(id, err))
} }
if res.IsError() { if res.IsError() {
log.Info(
"backend responded with RPC error",
"code", res.Error.Code,
"msg", res.Error.Message,
"source", "ws",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error) RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error)
} else {
log.Info(
"forwarded WS message to client",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
} }
err = w.clientConn.WriteMessage(msgType, msg) err = w.clientConn.WriteMessage(msgType, msg)
...@@ -485,15 +562,13 @@ func (w *WSProxier) close() { ...@@ -485,15 +562,13 @@ func (w *WSProxier) close() {
activeBackendWsConnsGauge.WithLabelValues(w.backend.Name).Dec() activeBackendWsConnsGauge.WithLabelValues(w.backend.Name).Dec()
} }
func (w *WSProxier) parseClientMsg(msg []byte) (*RPCReq, error) { func (w *WSProxier) prepareClientMsg(msg []byte) (*RPCReq, error) {
req, err := ParseRPCReq(bytes.NewReader(msg)) req, err := ParseRPCReq(bytes.NewReader(msg))
if err != nil { if err != nil {
log.Warn("error parsing RPC request", "source", "ws", "err", err)
return nil, err return nil, err
} }
if !w.methodWhitelist.Has(req.Method) { if !w.methodWhitelist.Has(req.Method) {
log.Info("blocked request for non-whitelisted method", "source", "ws", "method", req.Method)
return req, ErrMethodNotWhitelisted return req, ErrMethodNotWhitelisted
} }
......
package proxyd package proxyd
type ServerConfig struct { type ServerConfig struct {
Host string `toml:"host"` RPCHost string `toml:"rpc_host"`
Port int `toml:"port"` RPCPort int `toml:"rpc_port"`
WSHost string `toml:"ws_host"`
WSPort int `toml:"ws_port"`
MaxBodySizeBytes int64 `toml:"max_body_size_bytes"` MaxBodySizeBytes int64 `toml:"max_body_size_bytes"`
} }
...@@ -36,7 +38,6 @@ type BackendsConfig map[string]*BackendConfig ...@@ -36,7 +38,6 @@ type BackendsConfig map[string]*BackendConfig
type BackendGroupConfig struct { type BackendGroupConfig struct {
Backends []string `toml:"backends"` Backends []string `toml:"backends"`
WSEnabled bool `toml:"ws_enabled"`
} }
type BackendGroupsConfig map[string]*BackendGroupConfig type BackendGroupsConfig map[string]*BackendGroupConfig
...@@ -44,6 +45,7 @@ type BackendGroupsConfig map[string]*BackendGroupConfig ...@@ -44,6 +45,7 @@ type BackendGroupsConfig map[string]*BackendGroupConfig
type MethodMappingsConfig map[string]string type MethodMappingsConfig map[string]string
type Config struct { type Config struct {
WSBackendGroup string `toml:"ws_backend_group"`
Server *ServerConfig `toml:"server"` Server *ServerConfig `toml:"server"`
Redis *RedisConfig `toml:"redis"` Redis *RedisConfig `toml:"redis"`
Metrics *MetricsConfig `toml:"metrics"` Metrics *MetricsConfig `toml:"metrics"`
......
...@@ -4,12 +4,18 @@ ws_method_whitelist = [ ...@@ -4,12 +4,18 @@ ws_method_whitelist = [
"eth_call", "eth_call",
"eth_chainId" "eth_chainId"
] ]
# Enable WS on this backend group. There can only be one WS-enabled backend group.
ws_backend_group = "main"
[server] [server]
# Host for the proxyd server to listen on. # Host for the proxyd RPC server to listen on.
host = "0.0.0.0" rpc_host = "0.0.0.0"
# Port for the above. # Port for the above.
port = 8080 rpc_port = 8080
# Host for the proxyd WS server to listen on.
ws_host = "0.0.0.0"
# Port for the above
ws_port = 8085
# Maximum client body size, in bytes, that the server will accept. # Maximum client body size, in bytes, that the server will accept.
max_body_size_bytes = 10485760 max_body_size_bytes = 10485760
...@@ -59,8 +65,6 @@ max_ws_conns = 1 ...@@ -59,8 +65,6 @@ max_ws_conns = 1
[backend_groups] [backend_groups]
[backend_groups.main] [backend_groups.main]
backends = ["infura"] backends = ["infura"]
# Enable WS on this backend group. There can only be one WS-enabled backend group.
ws_enabled = true
[backend_groups.alchemy] [backend_groups.alchemy]
backends = ["alchemy"] backends = ["alchemy"]
......
...@@ -75,7 +75,6 @@ func Start(config *Config) error { ...@@ -75,7 +75,6 @@ func Start(config *Config) error {
} }
backendGroups := make(map[string]*BackendGroup) backendGroups := make(map[string]*BackendGroup)
var wsBackendGroup *BackendGroup
for bgName, bg := range config.BackendGroups { for bgName, bg := range config.BackendGroups {
backends := make([]*Backend, 0) backends := make([]*Backend, 0)
for _, bName := range bg.Backends { for _, bName := range bg.Backends {
...@@ -89,14 +88,20 @@ func Start(config *Config) error { ...@@ -89,14 +88,20 @@ func Start(config *Config) error {
Backends: backends, Backends: backends,
} }
backendGroups[bgName] = group backendGroups[bgName] = group
if bg.WSEnabled {
if wsBackendGroup != nil {
return fmt.Errorf("cannot define more than one WS-enabled backend group")
} }
wsBackendGroup = group
var wsBackendGroup *BackendGroup
if config.WSBackendGroup != "" {
wsBackendGroup = backendGroups[config.WSBackendGroup]
if wsBackendGroup == nil {
return fmt.Errorf("ws backend group %s does not exist", config.WSBackendGroup)
} }
} }
if wsBackendGroup == nil && config.Server.WSPort != 0 {
return fmt.Errorf("a ws port was defined, but no ws group was defined")
}
for _, bg := range config.RPCMethodMappings { for _, bg := range config.RPCMethodMappings {
if backendGroups[bg] == nil { if backendGroups[bg] == nil {
return fmt.Errorf("undefined backend group %s", bg) return fmt.Errorf("undefined backend group %s", bg)
...@@ -118,15 +123,29 @@ func Start(config *Config) error { ...@@ -118,15 +123,29 @@ func Start(config *Config) error {
go http.ListenAndServe(addr, promhttp.Handler()) go http.ListenAndServe(addr, promhttp.Handler())
} }
if config.Server.RPCPort != 0 {
go func() {
if err := srv.RPCListenAndServe(config.Server.RPCHost, config.Server.RPCPort); err != nil {
if errors.Is(err, http.ErrServerClosed) {
log.Info("RPC server shut down")
return
}
log.Crit("error starting RPC server", "err", err)
}
}()
}
if config.Server.WSPort != 0 {
go func() { go func() {
if err := srv.ListenAndServe(config.Server.Host, config.Server.Port); err != nil { if err := srv.WSListenAndServe(config.Server.WSHost, config.Server.WSPort); err != nil {
if errors.Is(err, http.ErrServerClosed) { if errors.Is(err, http.ErrServerClosed) {
log.Info("server shut down") log.Info("WS server shut down")
return return
} }
log.Crit("error starting server", "err", err) log.Crit("error starting WS server", "err", err)
} }
}() }()
}
sig := make(chan os.Signal, 1) sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
const ( const (
ContextKeyAuth = "authorization" ContextKeyAuth = "authorization"
ContextKeyReqID = "req_id"
) )
type Server struct { type Server struct {
...@@ -27,7 +28,8 @@ type Server struct { ...@@ -27,7 +28,8 @@ type Server struct {
maxBodySize int64 maxBodySize int64
authenticatedPaths map[string]string authenticatedPaths map[string]string
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
server *http.Server rpcServer *http.Server
wsServer *http.Server
} }
func NewServer( func NewServer(
...@@ -51,27 +53,46 @@ func NewServer( ...@@ -51,27 +53,46 @@ func NewServer(
} }
} }
func (s *Server) ListenAndServe(host string, port int) error { func (s *Server) RPCListenAndServe(host string, port int) error {
hdlr := mux.NewRouter() hdlr := mux.NewRouter()
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET") hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
hdlr.HandleFunc("/api/v1/rpc", s.HandleRPC).Methods("POST") hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/api/v1/{authorization}/rpc", s.HandleRPC).Methods("POST") hdlr.HandleFunc("/{authorization}", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/api/v1/ws", s.HandleWS)
hdlr.HandleFunc("/api/v1/{authorization}/ws", s.HandleWS)
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowedOrigins: []string{"*"}, AllowedOrigins: []string{"*"},
}) })
addr := fmt.Sprintf("%s:%d", host, port) addr := fmt.Sprintf("%s:%d", host, port)
s.server = &http.Server{ s.rpcServer = &http.Server{
Handler: instrumentedHdlr(c.Handler(hdlr)), Handler: instrumentedHdlr(c.Handler(hdlr)),
Addr: addr, Addr: addr,
} }
log.Info("starting HTTP server", "addr", addr) log.Info("starting HTTP server", "addr", addr)
return s.server.ListenAndServe() return s.rpcServer.ListenAndServe()
}
func (s *Server) WSListenAndServe(host string, port int) error {
hdlr := mux.NewRouter()
hdlr.HandleFunc("/", s.HandleWS)
hdlr.HandleFunc("/{authorization}", s.HandleWS)
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
})
addr := fmt.Sprintf("%s:%d", host, port)
s.wsServer = &http.Server{
Handler: instrumentedHdlr(c.Handler(hdlr)),
Addr: addr,
}
log.Info("starting WS server", "addr", addr)
return s.wsServer.ListenAndServe()
} }
func (s *Server) Shutdown() { func (s *Server) Shutdown() {
s.server.Shutdown(context.Background()) if s.rpcServer != nil {
s.rpcServer.Shutdown(context.Background())
}
if s.wsServer != nil {
s.wsServer.Shutdown(context.Background())
}
} }
func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) { func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
...@@ -79,11 +100,13 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) { ...@@ -79,11 +100,13 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
ctx := s.authenticate(w, r) ctx := s.populateContext(w, r)
if ctx == nil { if ctx == nil {
return return
} }
log.Info("received RPC request", "req_id", GetReqID(ctx), "auth", GetAuthCtx(ctx))
req, err := ParseRPCReq(io.LimitReader(r.Body, s.maxBodySize)) req, err := ParseRPCReq(io.LimitReader(r.Body, s.maxBodySize))
if err != nil { if err != nil {
log.Info("rejected request with bad rpc request", "source", "rpc", "err", err) log.Info("rejected request with bad rpc request", "source", "rpc", "err", err)
...@@ -96,7 +119,12 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { ...@@ -96,7 +119,12 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
if group == "" { if group == "" {
// use unknown below to prevent DOS vector that fills up memory // use unknown below to prevent DOS vector that fills up memory
// with arbitrary method names. // with arbitrary method names.
log.Info("blocked request for non-whitelisted method", "source", "ws", "method", req.Method) log.Info(
"blocked request for non-whitelisted method",
"source", "rpc",
"req_id", GetReqID(ctx),
"method", req.Method,
)
RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted) RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
writeRPCError(w, req.ID, ErrMethodNotWhitelisted) writeRPCError(w, req.ID, ErrMethodNotWhitelisted)
return return
...@@ -104,39 +132,49 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) { ...@@ -104,39 +132,49 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
backendRes, err := s.backendGroups[group].Forward(ctx, req) backendRes, err := s.backendGroups[group].Forward(ctx, req)
if err != nil { if err != nil {
log.Error("error forwarding RPC request", "method", req.Method, "err", err) log.Error(
"error forwarding RPC request",
"method", req.Method,
"req_id", GetReqID(ctx),
"err", err,
)
writeRPCError(w, req.ID, err) writeRPCError(w, req.ID, err)
return return
} }
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
if err := enc.Encode(backendRes); err != nil { if err := enc.Encode(backendRes); err != nil {
log.Error("error encoding response", "err", err) log.Error(
"error encoding response",
"req_id", GetReqID(ctx),
"err", err,
)
RecordRPCError(ctx, BackendProxyd, req.Method, err) RecordRPCError(ctx, BackendProxyd, req.Method, err)
writeRPCError(w, req.ID, err) writeRPCError(w, req.ID, err)
return return
} }
log.Debug("forwarded RPC method", "method", req.Method)
} }
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
ctx := s.authenticate(w, r) ctx := s.populateContext(w, r)
if ctx == nil { if ctx == nil {
return return
} }
log.Info("received WS connection", "req_id", GetReqID(ctx))
clientConn, err := s.upgrader.Upgrade(w, r, nil) clientConn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Error("error upgrading client conn", "err", err) log.Error("error upgrading client conn", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
return return
} }
proxier, err := s.wsBackendGroup.ProxyWS(clientConn, s.wsMethodWhitelist) proxier, err := s.wsBackendGroup.ProxyWS(ctx, clientConn, s.wsMethodWhitelist)
if err != nil { if err != nil {
if errors.Is(err, ErrNoBackends) { if errors.Is(err, ErrNoBackends) {
RecordUnserviceableRequest(ctx, RPCRequestSourceWS) RecordUnserviceableRequest(ctx, RPCRequestSourceWS)
} }
log.Error("error dialing ws backend", "err", err) log.Error("error dialing ws backend", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
clientConn.Close() clientConn.Close()
return return
} }
...@@ -145,13 +183,15 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { ...@@ -145,13 +183,15 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
go func() { go func() {
// Below call blocks so run it in a goroutine. // Below call blocks so run it in a goroutine.
if err := proxier.Proxy(ctx); err != nil { if err := proxier.Proxy(ctx); err != nil {
log.Error("error proxying websocket", "err", err) log.Error("error proxying websocket", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
} }
activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Dec() activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Dec()
}() }()
log.Info("accepted WS connection", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx))
} }
func (s *Server) authenticate(w http.ResponseWriter, r *http.Request) context.Context { func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context {
vars := mux.Vars(r) vars := mux.Vars(r)
authorization := vars["authorization"] authorization := vars["authorization"]
...@@ -159,18 +199,29 @@ func (s *Server) authenticate(w http.ResponseWriter, r *http.Request) context.Co ...@@ -159,18 +199,29 @@ func (s *Server) authenticate(w http.ResponseWriter, r *http.Request) context.Co
// handle the edge case where auth is disabled // handle the edge case where auth is disabled
// but someone sends in an auth key anyway // but someone sends in an auth key anyway
if authorization != "" { if authorization != "" {
log.Info("blocked authenticated request against unauthenticated proxy")
w.WriteHeader(404) w.WriteHeader(404)
return nil return nil
} }
return r.Context() return context.WithValue(
r.Context(),
ContextKeyReqID,
randStr(10),
)
} }
if authorization == "" || s.authenticatedPaths[authorization] == "" { if authorization == "" || s.authenticatedPaths[authorization] == "" {
log.Info("blocked unauthorized request", "authorization", authorization)
w.WriteHeader(401) w.WriteHeader(401)
return nil return nil
} }
return context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization]) ctx := context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization])
return context.WithValue(
ctx,
ContextKeyReqID,
randStr(10),
)
} }
func writeRPCError(w http.ResponseWriter, id *int, err error) { func writeRPCError(w http.ResponseWriter, id *int, err error) {
...@@ -208,3 +259,11 @@ func GetAuthCtx(ctx context.Context) string { ...@@ -208,3 +259,11 @@ func GetAuthCtx(ctx context.Context) string {
return authUser return authUser
} }
func GetReqID(ctx context.Context) string {
reqId, ok := ctx.Value(ContextKeyReqID).(string)
if !ok {
return ""
}
return reqId
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment