Commit abe231bf authored by Matthew Slipper's avatar Matthew Slipper

go/proxyd: Make endpoints match Geth, better logging

parent b3363ac8
---
'@eth-optimism/proxyd': major
---
Make endpoints match Geth, better logging
......@@ -163,7 +163,12 @@ func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
res, err := b.doForward(req)
if err != nil {
lastError = err
log.Warn("backend request failed, trying again", "err", err, "name", b.Name)
log.Warn(
"backend request failed, trying again",
"name", b.Name,
"req_id", GetReqID(ctx),
"err", err,
)
respTimer.ObserveDuration()
RecordRPCError(ctx, b.Name, req.Method, err)
time.Sleep(calcBackoff(i))
......@@ -172,6 +177,20 @@ func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
respTimer.ObserveDuration()
if res.IsError() {
RecordRPCError(ctx, b.Name, req.Method, res.Error)
log.Info(
"backend responded with RPC error",
"code", res.Error.Code,
"msg", res.Error.Message,
"req_id", GetReqID(ctx),
"source", "rpc",
"auth", GetAuthCtx(ctx),
)
} else {
log.Info("forwarded RPC request",
"method", req.Method,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
}
return res, nil
}
......@@ -313,15 +332,31 @@ func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, er
return nil, err
}
if errors.Is(err, ErrBackendOffline) {
log.Debug("skipping offline backend", "name", back.Name)
log.Warn(
"skipping offline backend",
"name", back.Name,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
continue
}
if errors.Is(err, ErrBackendOverCapacity) {
log.Debug("skipping over-capacity backend", "name", back.Name)
log.Warn(
"skipping over-capacity backend",
"name", back.Name,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
continue
}
if err != nil {
log.Error("error forwarding request to backend", "err", err, "name", b.Name)
log.Error(
"error forwarding request to backend",
"name", b.Name,
"req_id", GetReqID(ctx),
"auth", GetAuthCtx(ctx),
"err", err,
)
continue
}
return res, nil
......@@ -331,19 +366,35 @@ func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, er
return nil, ErrNoBackends
}
func (b *BackendGroup) ProxyWS(clientConn *websocket.Conn, methodWhitelist *StringSet) (*WSProxier, error) {
func (b *BackendGroup) ProxyWS(ctx context.Context, clientConn *websocket.Conn, methodWhitelist *StringSet) (*WSProxier, error) {
for _, back := range b.Backends {
proxier, err := back.ProxyWS(clientConn, methodWhitelist)
if errors.Is(err, ErrBackendOffline) {
log.Debug("skipping offline backend", "name", back.Name)
log.Warn(
"skipping offline backend",
"name", back.Name,
"req_id", GetReqID(ctx),
"auth", GetAuthCtx(ctx),
)
continue
}
if errors.Is(err, ErrBackendOverCapacity) {
log.Debug("skipping over-capacity backend", "name", back.Name)
log.Warn(
"skipping over-capacity backend",
"name", back.Name,
"req_id", GetReqID(ctx),
"auth", GetAuthCtx(ctx),
)
continue
}
if err != nil {
log.Warn("error dialing ws backend", "name", back.Name, "err", err)
log.Warn(
"error dialing ws backend",
"name", back.Name,
"req_id", GetReqID(ctx),
"auth", GetAuthCtx(ctx),
"err", err,
)
continue
}
return proxier, nil
......@@ -411,7 +462,7 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
// Don't bother sending invalid requests to the backend,
// just handle them here.
req, err := w.parseClientMsg(msg)
req, err := w.prepareClientMsg(msg)
if err != nil {
var id *int
method := MethodUnknown
......@@ -419,11 +470,23 @@ func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
id = req.ID
method = req.Method
}
log.Info(
"error preparing client message",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
"err", err,
)
outConn = w.clientConn
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
RecordRPCError(ctx, BackendProxyd, method, err)
} else {
RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
log.Info(
"forwarded WS message to backend",
"method", req.Method,
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
}
err = outConn.WriteMessage(msgType, msg)
......@@ -465,7 +528,21 @@ func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
msg = mustMarshalJSON(NewRPCErrorRes(id, err))
}
if res.IsError() {
log.Info(
"backend responded with RPC error",
"code", res.Error.Code,
"msg", res.Error.Message,
"source", "ws",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error)
} else {
log.Info(
"forwarded WS message to client",
"auth", GetAuthCtx(ctx),
"req_id", GetReqID(ctx),
)
}
err = w.clientConn.WriteMessage(msgType, msg)
......@@ -485,15 +562,13 @@ func (w *WSProxier) close() {
activeBackendWsConnsGauge.WithLabelValues(w.backend.Name).Dec()
}
func (w *WSProxier) parseClientMsg(msg []byte) (*RPCReq, error) {
func (w *WSProxier) prepareClientMsg(msg []byte) (*RPCReq, error) {
req, err := ParseRPCReq(bytes.NewReader(msg))
if err != nil {
log.Warn("error parsing RPC request", "source", "ws", "err", err)
return nil, err
}
if !w.methodWhitelist.Has(req.Method) {
log.Info("blocked request for non-whitelisted method", "source", "ws", "method", req.Method)
return req, ErrMethodNotWhitelisted
}
......
package proxyd
type ServerConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
RPCHost string `toml:"rpc_host"`
RPCPort int `toml:"rpc_port"`
WSHost string `toml:"ws_host"`
WSPort int `toml:"ws_port"`
MaxBodySizeBytes int64 `toml:"max_body_size_bytes"`
}
......@@ -36,7 +38,6 @@ type BackendsConfig map[string]*BackendConfig
type BackendGroupConfig struct {
Backends []string `toml:"backends"`
WSEnabled bool `toml:"ws_enabled"`
}
type BackendGroupsConfig map[string]*BackendGroupConfig
......@@ -44,6 +45,7 @@ type BackendGroupsConfig map[string]*BackendGroupConfig
type MethodMappingsConfig map[string]string
type Config struct {
WSBackendGroup string `toml:"ws_backend_group"`
Server *ServerConfig `toml:"server"`
Redis *RedisConfig `toml:"redis"`
Metrics *MetricsConfig `toml:"metrics"`
......
......@@ -4,12 +4,18 @@ ws_method_whitelist = [
"eth_call",
"eth_chainId"
]
# Enable WS on this backend group. There can only be one WS-enabled backend group.
ws_backend_group = "main"
[server]
# Host for the proxyd server to listen on.
host = "0.0.0.0"
# Host for the proxyd RPC server to listen on.
rpc_host = "0.0.0.0"
# Port for the above.
port = 8080
rpc_port = 8080
# Host for the proxyd WS server to listen on.
ws_host = "0.0.0.0"
# Port for the above
ws_port = 8085
# Maximum client body size, in bytes, that the server will accept.
max_body_size_bytes = 10485760
......@@ -59,8 +65,6 @@ max_ws_conns = 1
[backend_groups]
[backend_groups.main]
backends = ["infura"]
# Enable WS on this backend group. There can only be one WS-enabled backend group.
ws_enabled = true
[backend_groups.alchemy]
backends = ["alchemy"]
......
......@@ -16,12 +16,12 @@ func Start(config *Config) error {
if len(config.Backends) == 0 {
return errors.New("must define at least one backend")
}
if len(config.BackendGroups) == 0 {
return errors.New("must define at least one backend group")
}
if len(config.RPCMethodMappings) == 0 {
return errors.New("must define at least one RPC method mapping")
}
if len(config.BackendGroups) == 0 {
return errors.New("must define at least one backend group")
}
if len(config.RPCMethodMappings) == 0 {
return errors.New("must define at least one RPC method mapping")
}
for authKey := range config.Authentication {
if authKey == "none" {
......@@ -35,7 +35,7 @@ func Start(config *Config) error {
}
backendNames := make([]string, 0)
backendsByName := make(map[string]*Backend)
backendsByName := make(map[string]*Backend)
for name, cfg := range config.Backends {
opts := make([]BackendOpt, 0)
......@@ -70,44 +70,49 @@ func Start(config *Config) error {
}
back := NewBackend(name, cfg.RPCURL, cfg.WSURL, redis, opts...)
backendNames = append(backendNames, name)
backendsByName[name] = back
backendsByName[name] = back
log.Info("configured backend", "name", name, "rpc_url", cfg.RPCURL, "ws_url", cfg.WSURL)
}
backendGroups := make(map[string]*BackendGroup)
backendGroups := make(map[string]*BackendGroup)
for bgName, bg := range config.BackendGroups {
backends := make([]*Backend, 0)
for _, bName := range bg.Backends {
if backendsByName[bName] == nil {
return fmt.Errorf("backend %s is not defined", bName)
}
backends = append(backends, backendsByName[bName])
}
group := &BackendGroup{
Name: bgName,
Backends: backends,
}
backendGroups[bgName] = group
}
var wsBackendGroup *BackendGroup
for bgName, bg := range config.BackendGroups {
backends := make([]*Backend, 0)
for _, bName := range bg.Backends {
if backendsByName[bName] == nil {
return fmt.Errorf("backend %s is not defined", bName)
}
backends = append(backends, backendsByName[bName])
}
group := &BackendGroup{
Name: bgName,
Backends: backends,
}
backendGroups[bgName] = group
if bg.WSEnabled {
if wsBackendGroup != nil {
return fmt.Errorf("cannot define more than one WS-enabled backend group")
}
wsBackendGroup = group
if config.WSBackendGroup != "" {
wsBackendGroup = backendGroups[config.WSBackendGroup]
if wsBackendGroup == nil {
return fmt.Errorf("ws backend group %s does not exist", config.WSBackendGroup)
}
}
for _, bg := range config.RPCMethodMappings {
if backendGroups[bg] == nil {
return fmt.Errorf("undefined backend group %s", bg)
}
if wsBackendGroup == nil && config.Server.WSPort != 0 {
return fmt.Errorf("a ws port was defined, but no ws group was defined")
}
for _, bg := range config.RPCMethodMappings {
if backendGroups[bg] == nil {
return fmt.Errorf("undefined backend group %s", bg)
}
}
srv := NewServer(
backendGroups,
wsBackendGroup,
NewStringSetFromStrings(config.WSMethodWhitelist),
config.RPCMethodMappings,
backendGroups,
wsBackendGroup,
NewStringSetFromStrings(config.WSMethodWhitelist),
config.RPCMethodMappings,
config.Server.MaxBodySizeBytes,
config.Authentication,
)
......@@ -118,15 +123,29 @@ func Start(config *Config) error {
go http.ListenAndServe(addr, promhttp.Handler())
}
go func() {
if err := srv.ListenAndServe(config.Server.Host, config.Server.Port); err != nil {
if errors.Is(err, http.ErrServerClosed) {
log.Info("server shut down")
return
if config.Server.RPCPort != 0 {
go func() {
if err := srv.RPCListenAndServe(config.Server.RPCHost, config.Server.RPCPort); err != nil {
if errors.Is(err, http.ErrServerClosed) {
log.Info("RPC server shut down")
return
}
log.Crit("error starting RPC server", "err", err)
}
log.Crit("error starting server", "err", err)
}
}()
}()
}
if config.Server.WSPort != 0 {
go func() {
if err := srv.WSListenAndServe(config.Server.WSHost, config.Server.WSPort); err != nil {
if errors.Is(err, http.ErrServerClosed) {
log.Info("WS server shut down")
return
}
log.Crit("error starting WS server", "err", err)
}
}()
}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
......
......@@ -16,7 +16,8 @@ import (
)
const (
ContextKeyAuth = "authorization"
ContextKeyAuth = "authorization"
ContextKeyReqID = "req_id"
)
type Server struct {
......@@ -27,7 +28,8 @@ type Server struct {
maxBodySize int64
authenticatedPaths map[string]string
upgrader *websocket.Upgrader
server *http.Server
rpcServer *http.Server
wsServer *http.Server
}
func NewServer(
......@@ -51,27 +53,46 @@ func NewServer(
}
}
func (s *Server) ListenAndServe(host string, port int) error {
func (s *Server) RPCListenAndServe(host string, port int) error {
hdlr := mux.NewRouter()
hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
hdlr.HandleFunc("/api/v1/rpc", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/api/v1/{authorization}/rpc", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/api/v1/ws", s.HandleWS)
hdlr.HandleFunc("/api/v1/{authorization}/ws", s.HandleWS)
hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
hdlr.HandleFunc("/{authorization}", s.HandleRPC).Methods("POST")
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
})
addr := fmt.Sprintf("%s:%d", host, port)
s.server = &http.Server{
s.rpcServer = &http.Server{
Handler: instrumentedHdlr(c.Handler(hdlr)),
Addr: addr,
}
log.Info("starting HTTP server", "addr", addr)
return s.server.ListenAndServe()
return s.rpcServer.ListenAndServe()
}
func (s *Server) WSListenAndServe(host string, port int) error {
hdlr := mux.NewRouter()
hdlr.HandleFunc("/", s.HandleWS)
hdlr.HandleFunc("/{authorization}", s.HandleWS)
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
})
addr := fmt.Sprintf("%s:%d", host, port)
s.wsServer = &http.Server{
Handler: instrumentedHdlr(c.Handler(hdlr)),
Addr: addr,
}
log.Info("starting WS server", "addr", addr)
return s.wsServer.ListenAndServe()
}
func (s *Server) Shutdown() {
s.server.Shutdown(context.Background())
if s.rpcServer != nil {
s.rpcServer.Shutdown(context.Background())
}
if s.wsServer != nil {
s.wsServer.Shutdown(context.Background())
}
}
func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
......@@ -79,11 +100,13 @@ func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
ctx := s.authenticate(w, r)
ctx := s.populateContext(w, r)
if ctx == nil {
return
}
log.Info("received RPC request", "req_id", GetReqID(ctx), "auth", GetAuthCtx(ctx))
req, err := ParseRPCReq(io.LimitReader(r.Body, s.maxBodySize))
if err != nil {
log.Info("rejected request with bad rpc request", "source", "rpc", "err", err)
......@@ -92,51 +115,66 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return
}
group := s.rpcMethodMappings[req.Method]
if group == "" {
// use unknown below to prevent DOS vector that fills up memory
// with arbitrary method names.
log.Info("blocked request for non-whitelisted method", "source", "ws", "method", req.Method)
RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
writeRPCError(w, req.ID, ErrMethodNotWhitelisted)
return
}
group := s.rpcMethodMappings[req.Method]
if group == "" {
// use unknown below to prevent DOS vector that fills up memory
// with arbitrary method names.
log.Info(
"blocked request for non-whitelisted method",
"source", "rpc",
"req_id", GetReqID(ctx),
"method", req.Method,
)
RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
writeRPCError(w, req.ID, ErrMethodNotWhitelisted)
return
}
backendRes, err := s.backendGroups[group].Forward(ctx, req)
if err != nil {
log.Error("error forwarding RPC request", "method", req.Method, "err", err)
log.Error(
"error forwarding RPC request",
"method", req.Method,
"req_id", GetReqID(ctx),
"err", err,
)
writeRPCError(w, req.ID, err)
return
}
enc := json.NewEncoder(w)
if err := enc.Encode(backendRes); err != nil {
log.Error("error encoding response", "err", err)
log.Error(
"error encoding response",
"req_id", GetReqID(ctx),
"err", err,
)
RecordRPCError(ctx, BackendProxyd, req.Method, err)
writeRPCError(w, req.ID, err)
return
}
log.Debug("forwarded RPC method", "method", req.Method)
}
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
ctx := s.authenticate(w, r)
ctx := s.populateContext(w, r)
if ctx == nil {
return
}
log.Info("received WS connection", "req_id", GetReqID(ctx))
clientConn, err := s.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Error("error upgrading client conn", "err", err)
log.Error("error upgrading client conn", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
return
}
proxier, err := s.wsBackendGroup.ProxyWS(clientConn, s.wsMethodWhitelist)
proxier, err := s.wsBackendGroup.ProxyWS(ctx, clientConn, s.wsMethodWhitelist)
if err != nil {
if errors.Is(err, ErrNoBackends) {
RecordUnserviceableRequest(ctx, RPCRequestSourceWS)
}
log.Error("error dialing ws backend", "err", err)
log.Error("error dialing ws backend", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
clientConn.Close()
return
}
......@@ -145,13 +183,15 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
go func() {
// Below call blocks so run it in a goroutine.
if err := proxier.Proxy(ctx); err != nil {
log.Error("error proxying websocket", "err", err)
log.Error("error proxying websocket", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
}
activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Dec()
}()
log.Info("accepted WS connection", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx))
}
func (s *Server) authenticate(w http.ResponseWriter, r *http.Request) context.Context {
func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context {
vars := mux.Vars(r)
authorization := vars["authorization"]
......@@ -159,18 +199,29 @@ func (s *Server) authenticate(w http.ResponseWriter, r *http.Request) context.Co
// handle the edge case where auth is disabled
// but someone sends in an auth key anyway
if authorization != "" {
log.Info("blocked authenticated request against unauthenticated proxy")
w.WriteHeader(404)
return nil
}
return r.Context()
return context.WithValue(
r.Context(),
ContextKeyReqID,
randStr(10),
)
}
if authorization == "" || s.authenticatedPaths[authorization] == "" {
log.Info("blocked unauthorized request", "authorization", authorization)
w.WriteHeader(401)
return nil
}
return context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization])
ctx := context.WithValue(r.Context(), ContextKeyAuth, s.authenticatedPaths[authorization])
return context.WithValue(
ctx,
ContextKeyReqID,
randStr(10),
)
}
func writeRPCError(w http.ResponseWriter, id *int, err error) {
......@@ -208,3 +259,11 @@ func GetAuthCtx(ctx context.Context) string {
return authUser
}
func GetReqID(ctx context.Context) string {
reqId, ok := ctx.Value(ContextKeyReqID).(string)
if !ok {
return ""
}
return reqId
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment