Commit 7e4821e7 authored by protolambda

op-service: Stop(ctx) calls, shutdown testing, op-node rpc server update

parent 246d7dbe
...@@ -52,7 +52,11 @@ func Main(version string) func(cliCtx *cli.Context) error { ...@@ -52,7 +52,11 @@ func Main(version string) func(cliCtx *cli.Context) error {
l.Error("error starting metrics server", err) l.Error("error starting metrics server", err)
return err return err
} }
defer srv.Close() defer func() {
if err := srv.Stop(cliCtx.Context); err != nil {
l.Error("failed to stop metrics server", "err", err)
}
}()
opio.BlockOnInterrupts() opio.BlockOnInterrupts()
return nil return nil
......
...@@ -2,15 +2,18 @@ package api ...@@ -2,15 +2,18 @@ package api
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"runtime/debug" "runtime/debug"
"strconv"
"sync" "sync"
"github.com/ethereum-optimism/optimism/indexer/api/routes" "github.com/ethereum-optimism/optimism/indexer/api/routes"
"github.com/ethereum-optimism/optimism/indexer/config" "github.com/ethereum-optimism/optimism/indexer/config"
"github.com/ethereum-optimism/optimism/indexer/database" "github.com/ethereum-optimism/optimism/indexer/database"
"github.com/ethereum-optimism/optimism/op-service/httputil"
"github.com/ethereum-optimism/optimism/op-service/metrics" "github.com/ethereum-optimism/optimism/op-service/metrics"
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/log"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
...@@ -120,41 +123,42 @@ func (a *API) Port() int { ...@@ -120,41 +123,42 @@ func (a *API) Port() int {
// startServer ... Starts the API server // startServer ... Starts the API server
func (a *API) startServer(ctx context.Context) error { func (a *API) startServer(ctx context.Context) error {
a.log.Info("api server listening...", "port", a.serverConfig.Port) a.log.Debug("api server listening...", "port", a.serverConfig.Port)
server := http.Server{Addr: fmt.Sprintf(":%d", a.serverConfig.Port), Handler: a.router} addr := net.JoinHostPort(a.serverConfig.Host, strconv.Itoa(a.serverConfig.Port))
srv, err := httputil.StartHTTPServer(addr, a.router)
if err != nil {
return fmt.Errorf("failed to start API server: %w", err)
}
addr := fmt.Sprintf(":%d", a.serverConfig.Port) host, portStr, err := net.SplitHostPort(srv.Addr().String())
listener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
a.log.Error("Listen:", err) return errors.Join(err, srv.Close())
return err
} }
tcpAddr, ok := listener.Addr().(*net.TCPAddr) port, err := strconv.Atoi(portStr)
if !ok { if err != nil {
return fmt.Errorf("failed to get TCP address from network listener") return errors.Join(err, srv.Close())
} }
// Update the port in the config in case the OS chose a different port // Update the port in the config in case the OS chose a different port
// than the one we requested (e.g. using port 0 to fetch a random open port) // than the one we requested (e.g. using port 0 to fetch a random open port)
a.serverConfig.Port = tcpAddr.Port a.serverConfig.Host = host
a.serverConfig.Port = port
err = http.Serve(listener, server.Handler) <-ctx.Done()
if err != nil { if err := srv.Stop(context.Background()); err != nil {
a.log.Error("api server stopped with error", "err", err) return fmt.Errorf("failed to shutdown api server: %w", err)
} else {
a.log.Info("api server stopped")
} }
return err return nil
} }
// startMetricsServer ... Starts the metrics server // startMetricsServer ... Starts the metrics server
func (a *API) startMetricsServer(ctx context.Context) error { func (a *API) startMetricsServer(ctx context.Context) error {
a.log.Info("starting metrics server...", "port", a.metricsConfig.Port) a.log.Debug("starting metrics server...", "port", a.metricsConfig.Port)
srv, err := metrics.StartServer(a.metricsRegistry, a.metricsConfig.Host, a.metricsConfig.Port) srv, err := metrics.StartServer(a.metricsRegistry, a.metricsConfig.Host, a.metricsConfig.Port)
if err != nil { if err != nil {
return fmt.Errorf("failed to start metrics server: %w", err) return fmt.Errorf("failed to start metrics server: %w", err)
} }
<-ctx.Done() <-ctx.Done()
defer a.log.Info("metrics server stopped") defer a.log.Info("metrics server stopped")
return srv.Close() return srv.Stop(context.Background())
} }
...@@ -119,7 +119,7 @@ func (i *Indexer) startHttpServer(ctx context.Context) error { ...@@ -119,7 +119,7 @@ func (i *Indexer) startHttpServer(ctx context.Context) error {
i.log.Info("http server started", "addr", srv.Addr()) i.log.Info("http server started", "addr", srv.Addr())
<-ctx.Done() <-ctx.Done()
defer i.log.Info("http server stopped") defer i.log.Info("http server stopped")
return srv.Close() return srv.Stop(context.Background())
} }
func (i *Indexer) startMetricsServer(ctx context.Context) error { func (i *Indexer) startMetricsServer(ctx context.Context) error {
...@@ -131,7 +131,7 @@ func (i *Indexer) startMetricsServer(ctx context.Context) error { ...@@ -131,7 +131,7 @@ func (i *Indexer) startMetricsServer(ctx context.Context) error {
i.log.Info("metrics server started", "addr", srv.Addr()) i.log.Info("metrics server started", "addr", srv.Addr())
<-ctx.Done() <-ctx.Done()
defer i.log.Info("metrics server stopped") defer i.log.Info("metrics server stopped")
return srv.Close() return srv.Stop(context.Background())
} }
// Start starts the indexing service on L1 and L2 chains // Start starts the indexing service on L1 and L2 chains
......
...@@ -61,7 +61,11 @@ func Main(version string, cliCtx *cli.Context) error { ...@@ -61,7 +61,11 @@ func Main(version string, cliCtx *cli.Context) error {
return err return err
} }
l.Info("started pprof server", "addr", pprofSrv.Addr()) l.Info("started pprof server", "addr", pprofSrv.Addr())
defer pprofSrv.Close() defer func() {
if err := pprofSrv.Stop(context.Background()); err != nil {
l.Error("failed to stop pprof server", "err", err)
}
}()
} }
metricsCfg := cfg.MetricsConfig metricsCfg := cfg.MetricsConfig
...@@ -72,7 +76,11 @@ func Main(version string, cliCtx *cli.Context) error { ...@@ -72,7 +76,11 @@ func Main(version string, cliCtx *cli.Context) error {
return fmt.Errorf("failed to start metrics server: %w", err) return fmt.Errorf("failed to start metrics server: %w", err)
} }
l.Info("started metrics server", "addr", metricsSrv.Addr()) l.Info("started metrics server", "addr", metricsSrv.Addr())
defer metricsSrv.Close() defer func() {
if err := metricsSrv.Stop(context.Background()); err != nil {
	// This branch reports a failure to stop the metrics server, so the
	// message names that server rather than the pprof server handled above.
	l.Error("failed to stop metrics server", "err", err)
}
}()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
m.StartBalanceMetrics(ctx, l, batchSubmitter.L1Client, batchSubmitter.TxManager.From()) m.StartBalanceMetrics(ctx, l, batchSubmitter.L1Client, batchSubmitter.TxManager.From())
......
...@@ -3,8 +3,6 @@ package metrics ...@@ -3,8 +3,6 @@ package metrics
import ( import (
"context" "context"
"github.com/ethereum-optimism/optimism/op-service/httputil"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
...@@ -13,6 +11,7 @@ import ( ...@@ -13,6 +11,7 @@ import (
"github.com/ethereum-optimism/optimism/op-node/rollup/derive" "github.com/ethereum-optimism/optimism/op-node/rollup/derive"
"github.com/ethereum-optimism/optimism/op-service/eth" "github.com/ethereum-optimism/optimism/op-service/eth"
"github.com/ethereum-optimism/optimism/op-service/httputil"
opmetrics "github.com/ethereum-optimism/optimism/op-service/metrics" opmetrics "github.com/ethereum-optimism/optimism/op-service/metrics"
txmetrics "github.com/ethereum-optimism/optimism/op-service/txmgr/metrics" txmetrics "github.com/ethereum-optimism/optimism/op-service/txmgr/metrics"
) )
......
...@@ -78,7 +78,11 @@ func Main(cliCtx *cli.Context) error { ...@@ -78,7 +78,11 @@ func Main(cliCtx *cli.Context) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to start metrics server: %w", err) return fmt.Errorf("failed to start metrics server: %w", err)
} }
defer metricsSrv.Close() defer func() {
if err := metricsSrv.Stop(context.Background()); err != nil {
log.Error("failed to stop metrics server", "err", err)
}
}()
log.Info("started metrics server", "addr", metricsSrv.Addr()) log.Info("started metrics server", "addr", metricsSrv.Addr())
m.RecordUp() m.RecordUp()
} }
......
...@@ -31,13 +31,16 @@ type Service struct { ...@@ -31,13 +31,16 @@ type Service struct {
metricsSrv *httputil.HTTPServer metricsSrv *httputil.HTTPServer
} }
func (s *Service) Close() error { func (s *Service) Stop(ctx context.Context) error {
var result error var result error
if s.sched != nil {
result = errors.Join(result, s.sched.Close())
}
if s.pprofSrv != nil { if s.pprofSrv != nil {
result = errors.Join(result, s.pprofSrv.Close()) result = errors.Join(result, s.pprofSrv.Stop(ctx))
} }
if s.metricsSrv != nil { if s.metricsSrv != nil {
result = errors.Join(result, s.metricsSrv.Close()) result = errors.Join(result, s.metricsSrv.Stop(ctx))
} }
return result return result
} }
...@@ -66,7 +69,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se ...@@ -66,7 +69,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se
logger.Debug("starting pprof", "addr", pprofConfig.ListenAddr, "port", pprofConfig.ListenPort) logger.Debug("starting pprof", "addr", pprofConfig.ListenAddr, "port", pprofConfig.ListenPort)
pprofSrv, err := oppprof.StartServer(pprofConfig.ListenAddr, pprofConfig.ListenPort) pprofSrv, err := oppprof.StartServer(pprofConfig.ListenAddr, pprofConfig.ListenPort)
if err != nil { if err != nil {
return nil, errors.Join(fmt.Errorf("failed to start pprof server: %w", err), s.Close()) return nil, errors.Join(fmt.Errorf("failed to start pprof server: %w", err), s.Stop(ctx))
} }
s.pprofSrv = pprofSrv s.pprofSrv = pprofSrv
logger.Info("started pprof server", "addr", pprofSrv.Addr()) logger.Info("started pprof server", "addr", pprofSrv.Addr())
...@@ -77,7 +80,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se ...@@ -77,7 +80,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se
logger.Debug("starting metrics server", "addr", metricsCfg.ListenAddr, "port", metricsCfg.ListenPort) logger.Debug("starting metrics server", "addr", metricsCfg.ListenAddr, "port", metricsCfg.ListenPort)
metricsSrv, err := m.Start(metricsCfg.ListenAddr, metricsCfg.ListenPort) metricsSrv, err := m.Start(metricsCfg.ListenAddr, metricsCfg.ListenPort)
if err != nil { if err != nil {
return nil, errors.Join(fmt.Errorf("failed to start metrics server: %w", err), s.Close()) return nil, errors.Join(fmt.Errorf("failed to start metrics server: %w", err), s.Stop(ctx))
} }
logger.Info("started metrics server", "addr", metricsSrv.Addr()) logger.Info("started metrics server", "addr", metricsSrv.Addr())
s.metricsSrv = metricsSrv s.metricsSrv = metricsSrv
...@@ -86,7 +89,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se ...@@ -86,7 +89,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se
factory, err := bindings.NewDisputeGameFactory(cfg.GameFactoryAddress, l1Client) factory, err := bindings.NewDisputeGameFactory(cfg.GameFactoryAddress, l1Client)
if err != nil { if err != nil {
return nil, errors.Join(fmt.Errorf("failed to bind the fault dispute game factory contract: %w", err), s.Close()) return nil, errors.Join(fmt.Errorf("failed to bind the fault dispute game factory contract: %w", err), s.Stop(ctx))
} }
loader := NewGameLoader(factory) loader := NewGameLoader(factory)
...@@ -102,7 +105,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se ...@@ -102,7 +105,7 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se
pollClient, err := opClient.NewRPCWithClient(ctx, logger, cfg.L1EthRpc, opClient.NewBaseRPCClient(l1Client.Client()), cfg.PollInterval) pollClient, err := opClient.NewRPCWithClient(ctx, logger, cfg.L1EthRpc, opClient.NewBaseRPCClient(l1Client.Client()), cfg.PollInterval)
if err != nil { if err != nil {
return nil, errors.Join(fmt.Errorf("failed to create RPC client: %w", err), s.Close()) return nil, errors.Join(fmt.Errorf("failed to create RPC client: %w", err), s.Stop(ctx))
} }
s.monitor = newGameMonitor(logger, cl, loader, s.sched, cfg.GameWindow, l1Client.BlockNumber, cfg.GameAllowlist, pollClient) s.monitor = newGameMonitor(logger, cl, loader, s.sched, cfg.GameWindow, l1Client.BlockNumber, cfg.GameAllowlist, pollClient)
...@@ -115,7 +118,9 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se ...@@ -115,7 +118,9 @@ func NewService(ctx context.Context, logger log.Logger, cfg *config.Config) (*Se
// MonitorGame monitors the fault dispute game and attempts to progress it. // MonitorGame monitors the fault dispute game and attempts to progress it.
func (s *Service) MonitorGame(ctx context.Context) error { func (s *Service) MonitorGame(ctx context.Context) error {
s.sched.Start(ctx) s.sched.Start(ctx)
defer s.sched.Close() err := s.monitor.MonitorGames(ctx)
defer s.Close() // The other ctx is the close-trigger.
return s.monitor.MonitorGames(ctx) // We need to refactor Service more to allow for graceful/force-shutdown granularity.
err = errors.Join(err, s.Stop(context.Background()))
return err
} }
package op_heartbeat package op_heartbeat
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
...@@ -41,7 +42,7 @@ func Main(version string) func(ctx *cli.Context) error { ...@@ -41,7 +42,7 @@ func Main(version string) func(ctx *cli.Context) error {
oplog.SetGlobalLogHandler(l.GetHandler()) oplog.SetGlobalLogHandler(l.GetHandler())
l.Info("starting heartbeat monitor", "version", version) l.Info("starting heartbeat monitor", "version", version)
srv, err := Start(l, cfg, version) srv, err := Start(cliCtx.Context, l, cfg, version)
if err != nil { if err != nil {
l.Crit("error starting application", "err", err) l.Crit("error starting application", "err", err)
} }
...@@ -54,7 +55,7 @@ func Main(version string) func(ctx *cli.Context) error { ...@@ -54,7 +55,7 @@ func Main(version string) func(ctx *cli.Context) error {
syscall.SIGQUIT, syscall.SIGQUIT,
}...) }...)
<-doneCh <-doneCh
return srv.Close() return srv.Stop(context.Background())
} }
} }
...@@ -62,21 +63,21 @@ type HeartbeatService struct { ...@@ -62,21 +63,21 @@ type HeartbeatService struct {
pprof, metrics, http *httputil.HTTPServer pprof, metrics, http *httputil.HTTPServer
} }
func (hs *HeartbeatService) Close() error { func (hs *HeartbeatService) Stop(ctx context.Context) error {
var result error var result error
if hs.pprof != nil { if hs.pprof != nil {
result = errors.Join(result, hs.pprof.Close()) result = errors.Join(result, hs.pprof.Stop(ctx))
} }
if hs.metrics != nil { if hs.metrics != nil {
result = errors.Join(result, hs.metrics.Close()) result = errors.Join(result, hs.metrics.Stop(ctx))
} }
if hs.http != nil { if hs.http != nil {
result = errors.Join(result, hs.http.Close()) result = errors.Join(result, hs.http.Stop(ctx))
} }
return result return result
} }
func Start(l log.Logger, cfg Config, version string) (*HeartbeatService, error) { func Start(ctx context.Context, l log.Logger, cfg Config, version string) (*HeartbeatService, error) {
hs := &HeartbeatService{} hs := &HeartbeatService{}
registry := opmetrics.NewRegistry() registry := opmetrics.NewRegistry()
...@@ -85,7 +86,7 @@ func Start(l log.Logger, cfg Config, version string) (*HeartbeatService, error) ...@@ -85,7 +86,7 @@ func Start(l log.Logger, cfg Config, version string) (*HeartbeatService, error)
l.Debug("starting metrics server", "addr", metricsCfg.ListenAddr, "port", metricsCfg.ListenPort) l.Debug("starting metrics server", "addr", metricsCfg.ListenAddr, "port", metricsCfg.ListenPort)
metricsSrv, err := opmetrics.StartServer(registry, metricsCfg.ListenAddr, metricsCfg.ListenPort) metricsSrv, err := opmetrics.StartServer(registry, metricsCfg.ListenAddr, metricsCfg.ListenPort)
if err != nil { if err != nil {
return nil, errors.Join(fmt.Errorf("failed to start metrics server: %w", err), hs.Close()) return nil, errors.Join(fmt.Errorf("failed to start metrics server: %w", err), hs.Stop(ctx))
} }
hs.metrics = metricsSrv hs.metrics = metricsSrv
l.Info("started metrics server", "addr", metricsSrv.Addr()) l.Info("started metrics server", "addr", metricsSrv.Addr())
...@@ -96,7 +97,7 @@ func Start(l log.Logger, cfg Config, version string) (*HeartbeatService, error) ...@@ -96,7 +97,7 @@ func Start(l log.Logger, cfg Config, version string) (*HeartbeatService, error)
l.Debug("starting pprof", "addr", pprofCfg.ListenAddr, "port", pprofCfg.ListenPort) l.Debug("starting pprof", "addr", pprofCfg.ListenAddr, "port", pprofCfg.ListenPort)
pprofSrv, err := oppprof.StartServer(pprofCfg.ListenAddr, pprofCfg.ListenPort) pprofSrv, err := oppprof.StartServer(pprofCfg.ListenAddr, pprofCfg.ListenPort)
if err != nil { if err != nil {
return nil, errors.Join(fmt.Errorf("failed to start pprof server: %w", err), hs.Close()) return nil, errors.Join(fmt.Errorf("failed to start pprof server: %w", err), hs.Stop(ctx))
} }
l.Info("started pprof server", "addr", pprofSrv.Addr()) l.Info("started pprof server", "addr", pprofSrv.Addr())
hs.pprof = pprofSrv hs.pprof = pprofSrv
...@@ -121,7 +122,7 @@ func Start(l log.Logger, cfg Config, version string) (*HeartbeatService, error) ...@@ -121,7 +122,7 @@ func Start(l log.Logger, cfg Config, version string) (*HeartbeatService, error)
}), }),
httputil.WithMaxHeaderBytes(HTTPMaxHeaderSize)) httputil.WithMaxHeaderBytes(HTTPMaxHeaderSize))
if err != nil { if err != nil {
return nil, errors.Join(fmt.Errorf("failed to start HTTP server: %w", err), hs.Close()) return nil, errors.Join(fmt.Errorf("failed to start HTTP server: %w", err), hs.Stop(ctx))
} }
hs.http = srv hs.http = srv
......
...@@ -31,13 +31,13 @@ func TestService(t *testing.T) { ...@@ -31,13 +31,13 @@ func TestService(t *testing.T) {
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
srv, err := Start(log.New(), cfg, "foobar") srv, err := Start(ctx, log.New(), cfg, "foobar")
// Make sure that the service properly starts // Make sure that the service properly starts
require.NoError(t, err) require.NoError(t, err)
defer cancel() defer cancel()
defer func() { defer func() {
require.NoError(t, srv.Close(), "close heartbeat server") require.NoError(t, srv.Stop(ctx), "close heartbeat server")
}() }()
tests := []struct { tests := []struct {
......
...@@ -555,7 +555,9 @@ func (n *OpNode) Stop(ctx context.Context) error { ...@@ -555,7 +555,9 @@ func (n *OpNode) Stop(ctx context.Context) error {
var result *multierror.Error var result *multierror.Error
if n.server != nil { if n.server != nil {
n.server.Stop() if err := n.server.Stop(ctx); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close RPC server: %w", err))
}
} }
if n.p2pNode != nil { if n.p2pNode != nil {
if err := n.p2pNode.Close(); err != nil { if err := n.p2pNode.Close(); err != nil {
...@@ -623,12 +625,12 @@ func (n *OpNode) Stop(ctx context.Context) error { ...@@ -623,12 +625,12 @@ func (n *OpNode) Stop(ctx context.Context) error {
// Close metrics and pprof only after we are done idling // Close metrics and pprof only after we are done idling
if n.pprofSrv != nil { if n.pprofSrv != nil {
if err := n.pprofSrv.Close(); err != nil { if err := n.pprofSrv.Stop(ctx); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close pprof server: %w", err)) result = multierror.Append(result, fmt.Errorf("failed to close pprof server: %w", err))
} }
} }
if n.metricsSrv != nil { if n.metricsSrv != nil {
if err := n.metricsSrv.Close(); err != nil { if err := n.metricsSrv.Stop(ctx); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close metrics server: %w", err)) result = multierror.Append(result, fmt.Errorf("failed to close metrics server: %w", err))
} }
} }
...@@ -640,10 +642,9 @@ func (n *OpNode) Stopped() bool { ...@@ -640,10 +642,9 @@ func (n *OpNode) Stopped() bool {
return n.closed.Load() return n.closed.Load()
} }
func (n *OpNode) ListenAddr() string {
return n.server.listenAddr.String()
}
func (n *OpNode) HTTPEndpoint() string { func (n *OpNode) HTTPEndpoint() string {
return fmt.Sprintf("http://%s", n.ListenAddr()) if n.server == nil {
return ""
}
return fmt.Sprintf("http://%s", n.server.Addr().String())
} }
...@@ -2,7 +2,7 @@ package node ...@@ -2,7 +2,7 @@ package node
import ( import (
"context" "context"
"errors" "fmt"
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
...@@ -21,9 +21,8 @@ import ( ...@@ -21,9 +21,8 @@ import (
type rpcServer struct { type rpcServer struct {
endpoint string endpoint string
apis []rpc.API apis []rpc.API
httpServer *http.Server httpServer *ophttp.HTTPServer
appVersion string appVersion string
listenAddr net.Addr
log log.Logger log log.Logger
sources.L2Client sources.L2Client
} }
...@@ -79,27 +78,20 @@ func (s *rpcServer) Start() error { ...@@ -79,27 +78,20 @@ func (s *rpcServer) Start() error {
mux.Handle("/", nodeHandler) mux.Handle("/", nodeHandler)
mux.HandleFunc("/healthz", healthzHandler(s.appVersion)) mux.HandleFunc("/healthz", healthzHandler(s.appVersion))
listener, err := net.Listen("tcp", s.endpoint) hs, err := ophttp.StartHTTPServer(s.endpoint, mux)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to start HTTP RPC server: %w", err)
} }
s.listenAddr = listener.Addr() s.httpServer = hs
s.httpServer = ophttp.NewHttpServer(mux)
go func() {
if err := s.httpServer.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { // todo improve error handling
s.log.Error("http server failed", "err", err)
}
}()
return nil return nil
} }
func (r *rpcServer) Stop() { func (r *rpcServer) Stop(ctx context.Context) error {
_ = r.httpServer.Shutdown(context.Background()) return r.httpServer.Stop(ctx)
} }
func (r *rpcServer) Addr() net.Addr { func (r *rpcServer) Addr() net.Addr {
return r.listenAddr return r.httpServer.Addr()
} }
func healthzHandler(appVersion string) http.HandlerFunc { func healthzHandler(appVersion string) http.HandlerFunc {
......
...@@ -104,7 +104,9 @@ func TestOutputAtBlock(t *testing.T) { ...@@ -104,7 +104,9 @@ func TestOutputAtBlock(t *testing.T) {
server, err := newRPCServer(context.Background(), rpcCfg, rollupCfg, l2Client, drClient, log, "0.0", metrics.NoopMetrics) server, err := newRPCServer(context.Background(), rpcCfg, rollupCfg, l2Client, drClient, log, "0.0", metrics.NoopMetrics)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, server.Start()) require.NoError(t, server.Start())
defer server.Stop() defer func() {
require.NoError(t, server.Stop(context.Background()))
}()
client, err := rpcclient.NewRPC(context.Background(), log, "http://"+server.Addr().String(), rpcclient.WithDialBackoff(3)) client, err := rpcclient.NewRPC(context.Background(), log, "http://"+server.Addr().String(), rpcclient.WithDialBackoff(3))
require.NoError(t, err) require.NoError(t, err)
...@@ -136,7 +138,9 @@ func TestVersion(t *testing.T) { ...@@ -136,7 +138,9 @@ func TestVersion(t *testing.T) {
server, err := newRPCServer(context.Background(), rpcCfg, rollupCfg, l2Client, drClient, log, "0.0", metrics.NoopMetrics) server, err := newRPCServer(context.Background(), rpcCfg, rollupCfg, l2Client, drClient, log, "0.0", metrics.NoopMetrics)
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, server.Start()) assert.NoError(t, server.Start())
defer server.Stop() defer func() {
require.NoError(t, server.Stop(context.Background()))
}()
client, err := rpcclient.NewRPC(context.Background(), log, "http://"+server.Addr().String(), rpcclient.WithDialBackoff(3)) client, err := rpcclient.NewRPC(context.Background(), log, "http://"+server.Addr().String(), rpcclient.WithDialBackoff(3))
assert.NoError(t, err) assert.NoError(t, err)
...@@ -180,7 +184,9 @@ func TestSyncStatus(t *testing.T) { ...@@ -180,7 +184,9 @@ func TestSyncStatus(t *testing.T) {
server, err := newRPCServer(context.Background(), rpcCfg, rollupCfg, l2Client, drClient, log, "0.0", metrics.NoopMetrics) server, err := newRPCServer(context.Background(), rpcCfg, rollupCfg, l2Client, drClient, log, "0.0", metrics.NoopMetrics)
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, server.Start()) assert.NoError(t, server.Start())
defer server.Stop() defer func() {
require.NoError(t, server.Stop(context.Background()))
}()
client, err := rpcclient.NewRPC(context.Background(), log, "http://"+server.Addr().String(), rpcclient.WithDialBackoff(3)) client, err := rpcclient.NewRPC(context.Background(), log, "http://"+server.Addr().String(), rpcclient.WithDialBackoff(3))
assert.NoError(t, err) assert.NoError(t, err)
......
...@@ -79,7 +79,11 @@ func Main(version string, cliCtx *cli.Context) error { ...@@ -79,7 +79,11 @@ func Main(version string, cliCtx *cli.Context) error {
return err return err
} }
l.Info("started pprof server", "addr", pprofSrv.Addr()) l.Info("started pprof server", "addr", pprofSrv.Addr())
defer pprofSrv.Close() defer func() {
if err := pprofSrv.Stop(context.Background()); err != nil {
l.Error("failed to stop pprof server", "err", err)
}
}()
} }
metricsCfg := cfg.MetricsConfig metricsCfg := cfg.MetricsConfig
...@@ -90,7 +94,11 @@ func Main(version string, cliCtx *cli.Context) error { ...@@ -90,7 +94,11 @@ func Main(version string, cliCtx *cli.Context) error {
return fmt.Errorf("failed to start metrics server: %w", err) return fmt.Errorf("failed to start metrics server: %w", err)
} }
l.Info("started metrics server", "addr", metricsSrv.Addr()) l.Info("started metrics server", "addr", metricsSrv.Addr())
defer metricsSrv.Close() defer func() {
if err := metricsSrv.Stop(context.Background()); err != nil {
l.Error("failed to stop metrics server", "err", err)
}
}()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
m.StartBalanceMetrics(ctx, l, proposerConfig.L1Client, proposerConfig.TxManager.From()) m.StartBalanceMetrics(ctx, l, proposerConfig.L1Client, proposerConfig.TxManager.From())
......
package httputil package httputil
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net" "net"
...@@ -25,21 +26,27 @@ func StartHTTPServer(addr string, handler http.Handler, opts ...HTTPOption) (*HT ...@@ -25,21 +26,27 @@ func StartHTTPServer(addr string, handler http.Handler, opts ...HTTPOption) (*HT
return nil, fmt.Errorf("failed to bind to address %q: %w", addr, err) return nil, fmt.Errorf("failed to bind to address %q: %w", addr, err)
} }
srvCtx, srvCancel := context.WithCancel(context.Background())
srv := &http.Server{ srv := &http.Server{
Handler: handler, Handler: handler,
ReadTimeout: DefaultTimeouts.ReadTimeout, ReadTimeout: DefaultTimeouts.ReadTimeout,
ReadHeaderTimeout: DefaultTimeouts.ReadHeaderTimeout, ReadHeaderTimeout: DefaultTimeouts.ReadHeaderTimeout,
WriteTimeout: DefaultTimeouts.WriteTimeout, WriteTimeout: DefaultTimeouts.WriteTimeout,
IdleTimeout: DefaultTimeouts.IdleTimeout, IdleTimeout: DefaultTimeouts.IdleTimeout,
BaseContext: func(listener net.Listener) context.Context {
return srvCtx
},
} }
out := &HTTPServer{listener: listener, srv: srv} out := &HTTPServer{listener: listener, srv: srv}
for _, opt := range opts { for _, opt := range opts {
if err := opt(out); err != nil { if err := opt(out); err != nil {
srvCancel()
return nil, errors.Join(fmt.Errorf("failed to apply HTTP option: %w", err), listener.Close()) return nil, errors.Join(fmt.Errorf("failed to apply HTTP option: %w", err), listener.Close())
} }
} }
go func() { go func() {
err := out.srv.Serve(listener) err := out.srv.Serve(listener)
srvCancel()
// no error, unless ErrServerClosed (or unused base context closes, or unused http2 config error) // no error, unless ErrServerClosed (or unused base context closes, or unused http2 config error)
if errors.Is(err, http.ErrServerClosed) { if errors.Is(err, http.ErrServerClosed) {
out.closed.Store(true) out.closed.Store(true)
...@@ -54,17 +61,35 @@ func (s *HTTPServer) Closed() bool { ...@@ -54,17 +61,35 @@ func (s *HTTPServer) Closed() bool {
return s.closed.Load() return s.closed.Load()
} }
// Stop first attempts a graceful Shutdown bounded by ctx. When the shutdown
// fails with the ctx's own error (i.e. the caller gave up waiting), it falls
// back to Close to force-close whatever is still open, and returns the result
// of that Close instead of the ctx error. Any other shutdown error is
// returned unchanged.
func (s *HTTPServer) Stop(ctx context.Context) error {
	shutdownErr := s.Shutdown(ctx)
	if shutdownErr == nil {
		return nil
	}
	// Only escalate to a force-close when the failure is the ctx expiring.
	if !errors.Is(shutdownErr, ctx.Err()) {
		return shutdownErr
	}
	return s.Close()
}
// Shutdown gracefully stops the HTTP server by delegating to the wrapped
// http.Server's Shutdown, which also closes the underlying listener while
// letting active connections finish.
// If ctx is cancelled before that completes, the listener is closed but
// active connections may remain; call Close() to force-close those.
func (s *HTTPServer) Shutdown(ctx context.Context) error {
	err := s.srv.Shutdown(ctx)
	return err
}
// Close force-closes the HTTPServer, its listener, and all its active connections.
func (s *HTTPServer) Close() error { func (s *HTTPServer) Close() error {
// closes the underlying listener too // closes the underlying listener too
err := s.srv.Close() return s.srv.Close()
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
} }
func (s *HTTPServer) Addr() string { func (s *HTTPServer) Addr() net.Addr {
return s.listener.Addr().String() return s.listener.Addr()
} }
func WithMaxHeaderBytes(max int) HTTPOption { func WithMaxHeaderBytes(max int) HTTPOption {
......
package httputil
import (
"context"
"net/http"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestStartHTTPServer exercises StartHTTPServer and the three ways the
// returned HTTPServer is stopped in this file: a plain request followed by
// Close, a force-close while a response is in flight, and a graceful Shutdown
// while a handler is still waiting.
func TestStartHTTPServer(t *testing.T) {
	// testSetup starts a server on an OS-chosen port ("localhost:0") whose
	// handler lets the test decide when each response is produced.
	// For every incoming request the handler publishes a fresh channel on
	// reqRespBlock; the test takes it, sends a "block" channel over it to start
	// the response, and the handler then sends on that block channel before
	// writing the status code. If the request context ends first, the handler
	// answers with 503 instead.
	testSetup := func(t *testing.T) (srv *HTTPServer, reqRespBlock chan chan chan struct{}) {
		// Buffered so the handler can hand over its channel without the test
		// having to be receiving at that exact moment.
		reqRespBlock = make(chan chan chan struct{}, 10)
		h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// The request context must still be live when the handler starts.
			require.NoError(t, r.Context().Err())
			respBlock := make(chan chan struct{})
			reqRespBlock <- respBlock
			select {
			case block := <-respBlock:
				// Unbuffered send: the handler stays here until the test reads from block.
				block <- struct{}{}
				w.WriteHeader(http.StatusTeapot)
			case <-r.Context().Done():
				w.WriteHeader(http.StatusServiceUnavailable)
			}
		})
		// One-minute timeouts everywhere, so the server's own timeouts do not
		// interfere with the hand-off choreography in the sub-tests.
		srv, err := StartHTTPServer("localhost:0", h, WithTimeouts(HTTPTimeouts{
			ReadTimeout:       time.Minute,
			ReadHeaderTimeout: time.Minute,
			WriteTimeout:      time.Minute,
			IdleTimeout:       time.Minute,
		}))
		require.NoError(t, err)
		require.False(t, srv.Closed())
		return srv, reqRespBlock
	}

	// basics: a request that is allowed to complete gets the handler's 418,
	// and Close afterwards marks the server as closed.
	t.Run("basics", func(t *testing.T) {
		srv, reqRespBlock := testSetup(t)
		// test basics
		go func() {
			req := <-reqRespBlock // take request
			block := make(chan struct{})
			req <- block // start response
			<-block      // unblock response
		}()
		resp, err := http.Get("http://" + srv.Addr().String() + "/")
		require.NoError(t, err)
		assert.NoError(t, resp.Body.Close())
		assert.Equal(t, http.StatusTeapot, resp.StatusCode, "I am a teapot")
		assert.NoError(t, srv.Close())
		assert.True(t, srv.Closed())
	})

	// force-shutdown: Close is called while the handler is stuck mid-response,
	// so the client is expected to see its connection dropped.
	t.Run("force-shutdown", func(t *testing.T) {
		srv, reqRespBlock := testSetup(t)
		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			resp, err := http.Get("http://" + srv.Addr().String() + "/")
			assert.ErrorContains(t, err, "EOF") // error must indicate connection is force-closed
			if resp != nil {
				assert.NoError(t, resp.Body.Close()) // makes linter happy
			}
			wg.Done()
		}()
		req := <-reqRespBlock // take the request
		block := make(chan struct{})
		req <- block // start response
		// The handler is now parked on its send to block, since nothing reads it yet.
		// just force-shutdown the server
		assert.NoError(t, srv.Close())
		// Wait until the client goroutine has observed the failed request.
		wg.Wait()
		// only now unblock the response
		<-block
		require.True(t, srv.Closed())
	})

	// graceful: Shutdown is called while the handler is still waiting for the
	// test to start the response; the client is expected to receive a 503.
	t.Run("graceful", func(t *testing.T) {
		srv, reqRespBlock := testSetup(t)
		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			resp, err := http.Get("http://" + srv.Addr().String() + "/")
			assert.NoError(t, err)
			assert.NoError(t, resp.Body.Close())
			assert.Equal(t, http.StatusServiceUnavailable, resp.StatusCode, "service unavailable when shutting down")
			wg.Done()
		}()
		// Wait for a request, but don't start a response to it, just try to shut down the server
		// The base-context will be shut down, allowing the server to stop waiting for the user,
		// and gracefully tell the user it's not able to continue.
		<-reqRespBlock
		assert.NoError(t, srv.Shutdown(context.Background()))
		wg.Wait()
		require.True(t, srv.Closed())
	})
}
...@@ -423,7 +423,11 @@ var ( ...@@ -423,7 +423,11 @@ var (
if err != nil { if err != nil {
return fmt.Errorf("failed to start metrics server: %w", err) return fmt.Errorf("failed to start metrics server: %w", err)
} }
defer metricsSrv.Close() defer func() {
if err := metricsSrv.Stop(context.Background()); err != nil {
	// The logger takes a message followed by key/value pairs, not a format
	// string: a "%w" verb would be printed literally and a lone err would be
	// a key without a value. Pass the error under the "err" key instead.
	l.Error("failed to stop metrics server", "err", err)
}
}()
} }
return engine.Auto(ctx, metrics, client, l, shutdown, settings) return engine.Auto(ctx, metrics, client, l, shutdown, settings)
}) })
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment