Commit 32638bbb authored by protolambda's avatar protolambda Committed by GitHub

op-service: fix RPC websocket support (#13465)

parent 38ac5e49
...@@ -47,6 +47,7 @@ func (cfg *Config) Setup(ctx context.Context, logger log.Logger) (SubSystem, err ...@@ -47,6 +47,7 @@ func (cfg *Config) Setup(ctx context.Context, logger log.Logger) (SubSystem, err
} }
out := &ManagedMode{} out := &ManagedMode{}
out.srv = rpc.NewServer(cfg.RPCAddr, cfg.RPCPort, "v0.0.0", out.srv = rpc.NewServer(cfg.RPCAddr, cfg.RPCPort, "v0.0.0",
rpc.WithLogger(logger),
rpc.WithWebsocketEnabled(), rpc.WithJWTSecret(jwtSecret[:])) rpc.WithWebsocketEnabled(), rpc.WithJWTSecret(jwtSecret[:]))
return out, nil return out, nil
} else { } else {
......
package httputil package httputil
import "net/http" import (
"bufio"
"fmt"
"net"
"net/http"
)
type WrappedResponseWriter struct { type WrappedResponseWriter struct {
StatusCode int StatusCode int
...@@ -8,8 +13,12 @@ type WrappedResponseWriter struct { ...@@ -8,8 +13,12 @@ type WrappedResponseWriter struct {
w http.ResponseWriter w http.ResponseWriter
wroteHeader bool wroteHeader bool
UpgradeAttempt bool
} }
var _ http.Hijacker = (*WrappedResponseWriter)(nil)
func NewWrappedResponseWriter(w http.ResponseWriter) *WrappedResponseWriter { func NewWrappedResponseWriter(w http.ResponseWriter) *WrappedResponseWriter {
return &WrappedResponseWriter{ return &WrappedResponseWriter{
StatusCode: 200, StatusCode: 200,
...@@ -36,3 +45,14 @@ func (w *WrappedResponseWriter) WriteHeader(statusCode int) { ...@@ -36,3 +45,14 @@ func (w *WrappedResponseWriter) WriteHeader(statusCode int) {
w.StatusCode = statusCode w.StatusCode = statusCode
w.w.WriteHeader(statusCode) w.w.WriteHeader(statusCode)
} }
// Hijack implements http.Hijacker, so the WrappedResponseWriter is
// compatible as middleware for websocket-upgrades that take over the connection.
func (w *WrappedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
w.UpgradeAttempt = true
h, ok := w.w.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("response-writer is not a http.Hijacker, cannot turn it into raw connection")
}
return h.Hijack()
}
...@@ -20,6 +20,7 @@ func NewLoggingMiddleware(lgr log.Logger, next http.Handler) http.Handler { ...@@ -20,6 +20,7 @@ func NewLoggingMiddleware(lgr log.Logger, next http.Handler) http.Handler {
"path", r.URL.EscapedPath(), "path", r.URL.EscapedPath(),
"duration", time.Since(start), "duration", time.Since(start),
"remote_addr", r.RemoteAddr, "remote_addr", r.RemoteAddr,
"upgrade_attempt", ww.UpgradeAttempt,
) )
}) })
} }
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
...@@ -38,6 +39,7 @@ type Server struct { ...@@ -38,6 +39,7 @@ type Server struct {
log log.Logger log log.Logger
tls *ServerTLSConfig tls *ServerTLSConfig
middlewares []Middleware middlewares []Middleware
rpcServer *rpc.Server
} }
type ServerTLSConfig struct { type ServerTLSConfig struct {
...@@ -73,12 +75,16 @@ func WithVHosts(hosts []string) ServerOption { ...@@ -73,12 +75,16 @@ func WithVHosts(hosts []string) ServerOption {
} }
} }
// WithWebsocketEnabled allows `ws://host:port/`, `ws://host:port/ws` and `ws://host:port/ws/`
// to be upgraded to a websocket JSON RPC connection.
func WithWebsocketEnabled() ServerOption { func WithWebsocketEnabled() ServerOption {
return func(b *Server) { return func(b *Server) {
b.wsEnabled = true b.wsEnabled = true
} }
} }
// WithJWTSecret adds authentication to the RPCs (HTTP, and WS pre-upgrade if enabled).
// The health endpoint is still available without authentication.
func WithJWTSecret(secret []byte) ServerOption { func WithJWTSecret(secret []byte) ServerOption {
return func(b *Server) { return func(b *Server) {
b.jwtSecret = secret b.jwtSecret = secret
...@@ -140,6 +146,7 @@ func NewServer(host string, port int, appVersion string, opts ...ServerOption) * ...@@ -140,6 +146,7 @@ func NewServer(host string, port int, appVersion string, opts ...ServerOption) *
Addr: endpoint, Addr: endpoint,
}, },
log: log.Root(), log: log.Root(),
rpcServer: rpc.NewServer(),
} }
for _, opt := range opts { for _, opt := range opts {
opt(bs) opt(bs)
...@@ -156,6 +163,7 @@ func NewServer(host string, port int, appVersion string, opts ...ServerOption) * ...@@ -156,6 +163,7 @@ func NewServer(host string, port int, appVersion string, opts ...ServerOption) *
return bs return bs
} }
// Endpoint returns the HTTP endpoint without http / ws protocol prefix.
func (b *Server) Endpoint() string { func (b *Server) Endpoint() string {
return b.listener.Addr().String() return b.listener.Addr().String()
} }
...@@ -165,36 +173,41 @@ func (b *Server) AddAPI(api rpc.API) { ...@@ -165,36 +173,41 @@ func (b *Server) AddAPI(api rpc.API) {
} }
func (b *Server) Start() error { func (b *Server) Start() error {
srv := rpc.NewServer() // Register all APIs to the RPC server.
for _, api := range b.apis { for _, api := range b.apis {
if err := srv.RegisterName(api.Namespace, api.Service); err != nil { if err := b.rpcServer.RegisterName(api.Namespace, api.Service); err != nil {
return fmt.Errorf("failed to register API %s: %w", api.Namespace, err) return fmt.Errorf("failed to register API %s: %w", api.Namespace, err)
} }
b.log.Info("registered API", "namespace", api.Namespace) b.log.Info("registered API", "namespace", api.Namespace)
} }
// rpc middleware // http handler stack.
var nodeHdlr http.Handler = srv var handler http.Handler
for _, middleware := range b.middlewares {
nodeHdlr = middleware(nodeHdlr) // default to 404 not-found
} handler = http.HandlerFunc(http.NotFound)
nodeHdlr = node.NewHTTPHandlerStack(nodeHdlr, b.corsHosts, b.vHosts, b.jwtSecret)
// Health endpoint is lowest priority.
handler = b.newHealthMiddleware(handler)
// serve RPC on configured RPC path (but not on arbitrary paths)
handler = b.newHttpRPCMiddleware(handler)
mux := http.NewServeMux() // Conditionally enable Websocket support.
mux.Handle(b.rpcPath, nodeHdlr) if b.wsEnabled { // prioritize WS RPC, if it's an upgrade request
mux.Handle(b.healthzPath, b.healthzHandler) handler = b.newWsMiddleWare(handler)
}
if b.wsEnabled { // Apply user middlewares
wsHandler := node.NewWSHandlerStack(srv.WebsocketHandler(b.corsHosts), b.jwtSecret) for _, middleware := range b.middlewares {
mux.Handle("/ws", wsHandler) handler = middleware(handler)
} }
// http middleware // Outer-most middlewares: logging, metrics, TLS
var handler http.Handler = mux
handler = optls.NewPeerTLSMiddleware(handler) handler = optls.NewPeerTLSMiddleware(handler)
handler = opmetrics.NewHTTPRecordingMiddleware(b.httpRecorder, handler) handler = opmetrics.NewHTTPRecordingMiddleware(b.httpRecorder, handler)
handler = oplog.NewLoggingMiddleware(b.log, handler) handler = oplog.NewLoggingMiddleware(b.log, handler)
b.httpServer.Handler = handler b.httpServer.Handler = handler
listener, err := net.Listen("tcp", b.endpoint) listener, err := net.Listen("tcp", b.endpoint)
...@@ -230,10 +243,45 @@ func (b *Server) Start() error { ...@@ -230,10 +243,45 @@ func (b *Server) Start() error {
} }
} }
func (b *Server) newHealthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == b.healthzPath {
b.healthzHandler.ServeHTTP(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (b *Server) newHttpRPCMiddleware(next http.Handler) http.Handler {
// Only allow RPC handlers behind the appropriate CORS / vhost / JWT (optional) setup.
// Note that websockets have their own handler-stack, also configured with CORS and JWT, separately.
httpHandler := node.NewHTTPHandlerStack(b.rpcServer, b.corsHosts, b.vHosts, b.jwtSecret)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == b.rpcPath {
httpHandler.ServeHTTP(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (b *Server) newWsMiddleWare(next http.Handler) http.Handler {
wsHandler := node.NewWSHandlerStack(b.rpcServer.WebsocketHandler(b.corsHosts), b.jwtSecret)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isWebsocket(r) && (r.URL.Path == "/" || r.URL.Path == "/ws" || r.URL.Path == "/ws/") {
wsHandler.ServeHTTP(w, r)
return
}
next.ServeHTTP(w, r)
})
}
func (b *Server) Stop() error { func (b *Server) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
_ = b.httpServer.Shutdown(ctx) _ = b.httpServer.Shutdown(ctx)
b.rpcServer.Stop()
return nil return nil
} }
...@@ -256,3 +304,8 @@ type healthzAPI struct { ...@@ -256,3 +304,8 @@ type healthzAPI struct {
func (h *healthzAPI) Status() string { func (h *healthzAPI) Status() string {
return h.appVersion return h.appVersion
} }
func isWebsocket(r *http.Request) bool {
return strings.EqualFold(r.Header.Get("Upgrade"), "websocket") &&
strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade")
}
package rpc package rpc
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/rpc"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rpc"
"github.com/ethereum-optimism/optimism/op-service/testlog"
) )
type testAPI struct{} type testAPI struct{}
...@@ -20,24 +26,26 @@ func (t *testAPI) Frobnicate(n int) int { ...@@ -20,24 +26,26 @@ func (t *testAPI) Frobnicate(n int) int {
func TestBaseServer(t *testing.T) { func TestBaseServer(t *testing.T) {
appVersion := "test" appVersion := "test"
logger := testlog.Logger(t, log.LevelTrace)
log.SetDefault(log.NewLogger(logger.Handler()))
server := NewServer( server := NewServer(
"127.0.0.1", "127.0.0.1",
0, 0,
appVersion, appVersion,
WithLogger(logger),
WithAPIs([]rpc.API{ WithAPIs([]rpc.API{
{ {
Namespace: "test", Namespace: "test",
Service: new(testAPI), Service: new(testAPI),
}, },
}), }),
WithWebsocketEnabled(),
) )
require.NoError(t, server.Start()) require.NoError(t, server.Start(), "must start")
defer func() {
_ = server.Stop()
}()
rpcClient, err := rpc.Dial(fmt.Sprintf("http://%s", server.endpoint)) rpcClient, err := rpc.Dial(fmt.Sprintf("http://%s", server.endpoint))
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(rpcClient.Close)
t.Run("supports GET /healthz", func(t *testing.T) { t.Run("supports GET /healthz", func(t *testing.T) {
res, err := http.Get(fmt.Sprintf("http://%s/healthz", server.endpoint)) res, err := http.Get(fmt.Sprintf("http://%s/healthz", server.endpoint))
...@@ -68,4 +76,19 @@ func TestBaseServer(t *testing.T) { ...@@ -68,4 +76,19 @@ func TestBaseServer(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Greater(t, port, 0) require.Greater(t, port, 0)
}) })
t.Run("supports websocket", func(t *testing.T) {
endpoint := "ws://" + server.Endpoint()
t.Log("connecting to", endpoint)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
wsCl, err := rpc.DialContext(ctx, endpoint)
require.NoError(t, err)
defer wsCl.Close()
var res int
require.NoError(t, wsCl.Call(&res, "test_frobnicate", 42))
require.Equal(t, 42*2, res)
})
require.NoError(t, server.Stop(), "must stop")
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment