Commit 8d8219f5, authored by protolambda and committed by GitHub

op-heartbeat: handle heartbeat spamming (#4507)

parent 0669b2eb
package op_heartbeat
import (
"fmt"
"strconv"
"sync/atomic"
"time"
"github.com/ethereum-optimism/optimism/op-node/heartbeat"
lru "github.com/hashicorp/golang-lru"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/ethereum-optimism/optimism/op-node/heartbeat"
)
const MetricsNamespace = "op_heartbeat"
// Settings for the heartbeat metrics and the per-sender de-duplication cache.
const (
// MetricsNamespace is the Prometheus namespace given to the metrics created in NewMetrics.
MetricsNamespace = "op_heartbeat"
// MinHeartbeatInterval is the span during which repeated heartbeats for the same
// IP/version/chain ID key are tallied instead of being counted as new heartbeats.
// It is 10 seconds shorter than 10 minutes; the client-side SendInterval (10 minutes)
// is documented as having to be larger than this value.
MinHeartbeatInterval = 10*time.Minute - 10*time.Second
// UsersCacheSize is the size passed to lru.New for the cache of heartbeat senders.
UsersCacheSize = 10_000
)
// Metrics is the recording interface returned by NewMetrics and consumed by Handler.
type Metrics interface {
// NOTE(review): two RecordHeartbeat signatures are listed; this appears to be the form
// before the ip parameter was added, followed by the form that takes it — confirm.
RecordHeartbeat(payload heartbeat.Payload)
// RecordHeartbeat records a heartbeat payload together with the sender IP string.
RecordHeartbeat(payload heartbeat.Payload, ip string)
// RecordVersion records the given version string (implementation not visible here).
RecordVersion(version string)
}
// metrics is the Metrics implementation built by NewMetrics.
type metrics struct {
// heartbeats counts heartbeats, labelled by chain_id and version.
heartbeats *prometheus.CounterVec
// version is a gauge vector.
// NOTE(review): presumably written by RecordVersion; its body is not visible here — confirm.
version *prometheus.GaugeVec
// sameIP is a histogram of the number of heartbeats seen for one key within a
// single heartbeat interval, labelled by chain_id and version.
sameIP *prometheus.HistogramVec
// Groups heartbeats per unique IP, version and chain ID combination.
// Key is the string "ip;version;chainID" (semicolon-separated, built in RecordHeartbeat);
// value is a *heartbeatEntry.
heartbeatUsers *lru.Cache
}
// heartbeatEntry is the per-key state stored in metrics.heartbeatUsers: how many
// heartbeats arrived in the current interval and when that interval started.
type heartbeatEntry struct {
// Count number of heartbeats per interval, atomically updated
// (after creation, RecordHeartbeat touches it only via sync/atomic).
// NOTE(review): keep Count as the first field. sync/atomic documents that 64-bit
// atomic operations need 64-bit alignment on 32-bit platforms, and that the first
// word of an allocated struct can be relied upon to be aligned.
Count uint64
// Changes once per heartbeat interval: RecordHeartbeat sets it to the current time
// when the previous interval has ended.
// NOTE(review): unlike Count, this field is read and assigned without atomics or a
// lock. If heartbeats for the same key are recorded concurrently, that looks like
// an unsynchronized access — confirm whether a mutex is needed.
Time time.Time
}
// NewMetrics creates the heartbeat metrics through promauto.With(r) and returns a
// Metrics that also holds an LRU cache of UsersCacheSize heartbeat senders.
// Part of this function is elided in this view; only the visible lines are documented.
func NewMetrics(r *prometheus.Registry) Metrics {
// The error from lru.New is discarded.
// NOTE(review): golang-lru reports an error only for a non-positive size, and
// UsersCacheSize is a positive constant — confirm against the vendored version.
lruCache, _ := lru.New(UsersCacheSize)
m := &metrics{
// Counter of heartbeats, partitioned by the chain_id and version labels.
heartbeats: promauto.With(r).NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "heartbeats",
Help: "Counts number of heartbeats by chain ID",
Help: "Counts number of heartbeats by chain ID, version and filtered to unique IPs",
}, []string{
"chain_id",
"version",
......@@ -37,11 +59,21 @@ func NewMetrics(r *prometheus.Registry) Metrics {
}, []string{
"version",
}),
// Histogram of how many heartbeats one key produced within a heartbeat interval.
// Buckets are powers of two from 1 to 128.
sameIP: promauto.With(r).NewHistogramVec(prometheus.HistogramOpts{
Namespace: MetricsNamespace,
Name: "heartbeat_same_ip",
Buckets: []float64{1, 2, 4, 8, 16, 32, 64, 128},
Help: "Histogram of events within same heartbeat interval per unique IP, by chain ID and version",
}, []string{
"chain_id",
"version",
}),
heartbeatUsers: lruCache,
}
return m
}
// RecordHeartbeat meters one heartbeat received from ip.
//
// Heartbeats are grouped by the key "ip;version;chainID". The heartbeats counter is
// incremented for the first heartbeat of a key and again only once the key's current
// interval is at least MinHeartbeatInterval old; heartbeats arriving inside an
// interval are only added to the entry's Count. That Count is observed in the sameIP
// histogram when the interval ends.
// Part of this function is elided in this view; only the visible lines are documented.
func (m *metrics) RecordHeartbeat(payload heartbeat.Payload) {
func (m *metrics) RecordHeartbeat(payload heartbeat.Payload, ip string) {
var chainID string
// Only chain IDs present in AllowedChainIDs are used verbatim as a label value.
if AllowedChainIDs[payload.ChainID] {
chainID = strconv.FormatUint(payload.ChainID, 10)
......@@ -54,7 +86,32 @@ func (m *metrics) RecordHeartbeat(payload heartbeat.Payload) {
} else {
version = "unknown"
}
m.heartbeats.WithLabelValues(chainID, version).Inc()
// Cache key: the sender IP plus the already-normalized version and chain ID labels.
key := fmt.Sprintf("%s;%s;%s", ip, version, chainID)
now := time.Now()
// PeekOrAdd inserts a fresh entry when the key is unknown; otherwise it returns the
// existing one without refreshing its recency. The third result (eviction) is ignored.
previous, ok, _ := m.heartbeatUsers.PeekOrAdd(key, &heartbeatEntry{Time: now, Count: 1})
if !ok {
// if it's a new entry, observe it and exit.
m.sameIP.WithLabelValues(chainID, version).Observe(1)
m.heartbeats.WithLabelValues(chainID, version).Inc()
return
}
entry := previous.(*heartbeatEntry)
// NOTE(review): the interval check below and the reset in the else branch are not
// atomic as a unit, and entry.Time is accessed without synchronization. Concurrent
// heartbeats for the same key could both take the else branch — confirm.
if now.Sub(entry.Time) < MinHeartbeatInterval {
// if the span is still going, then add it up
atomic.AddUint64(&entry.Count, 1)
} else {
// if the span ended, then meter it, and reset it
m.sameIP.WithLabelValues(chainID, version).Observe(float64(atomic.LoadUint64(&entry.Count)))
entry.Time = now
atomic.StoreUint64(&entry.Count, 1)
m.heartbeats.WithLabelValues(chainID, version).Inc()
}
// always add, to keep LRU accurate
// (PeekOrAdd above does not mark the key as recently used; Add does.)
m.heartbeatUsers.Add(key, entry)
}
func (m *metrics) RecordVersion(version string) {
......
......@@ -10,19 +10,25 @@ import (
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/urfave/cli"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum-optimism/optimism/op-node/heartbeat"
"github.com/ethereum-optimism/optimism/op-service/httputil"
oplog "github.com/ethereum-optimism/optimism/op-service/log"
opmetrics "github.com/ethereum-optimism/optimism/op-service/metrics"
oppprof "github.com/ethereum-optimism/optimism/op-service/pprof"
"github.com/ethereum/go-ethereum/log"
"github.com/urfave/cli"
)
const HTTPMaxBodySize = 1024 * 1024
// Size limits for incoming HTTP requests.
const (
// HTTPMaxHeaderSize (10 KiB) is the value assigned to http.Server.MaxHeaderBytes in Start.
HTTPMaxHeaderSize = 10 * 1024
// HTTPMaxBodySize is 1 MiB.
// NOTE(review): presumably the request-body limit; the code enforcing it is not
// visible here — confirm.
HTTPMaxBodySize = 1024 * 1024
)
func Main(version string) func(ctx *cli.Context) error {
return func(cliCtx *cli.Context) error {
......@@ -87,7 +93,7 @@ func Start(ctx context.Context, l log.Logger, cfg Config, version string) error
server := &http.Server{
Addr: net.JoinHostPort(cfg.HTTPAddr, strconv.Itoa(cfg.HTTPPort)),
MaxHeaderBytes: HTTPMaxBodySize,
MaxHeaderBytes: HTTPMaxHeaderSize,
Handler: mw,
WriteTimeout: 30 * time.Second,
IdleTimeout: time.Minute,
......@@ -99,8 +105,14 @@ func Start(ctx context.Context, l log.Logger, cfg Config, version string) error
func Handler(l log.Logger, metrics Metrics) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ipStr := r.Header.Get("X-Forwarded-For")
// XFF can be a comma-separated list. Left-most is the original client.
if i := strings.Index(ipStr, ","); i >= 0 {
ipStr = ipStr[:i]
}
innerL := l.New(
"xff", r.Header.Get("X-Forwarded-For"),
"ip", ipStr,
"user_agent", r.Header.Get("User-Agent"),
"remote_addr", r.RemoteAddr,
)
......@@ -122,7 +134,7 @@ func Handler(l log.Logger, metrics Metrics) http.HandlerFunc {
"chain_id", payload.ChainID,
)
metrics.RecordHeartbeat(payload)
metrics.RecordHeartbeat(payload, ipStr)
w.WriteHeader(204)
}
......
......@@ -11,10 +11,11 @@ import (
"testing"
"time"
"github.com/ethereum-optimism/optimism/op-node/heartbeat"
opmetrics "github.com/ethereum-optimism/optimism/op-service/metrics"
"github.com/ethereum/go-ethereum/log"
"github.com/stretchr/testify/require"
"github.com/ethereum-optimism/optimism/op-node/heartbeat"
opmetrics "github.com/ethereum-optimism/optimism/op-service/metrics"
)
func TestService(t *testing.T) {
......@@ -45,58 +46,82 @@ func TestService(t *testing.T) {
}
tests := []struct {
name string
hb heartbeat.Payload
metricName string
metricValue int
name string
hbs []heartbeat.Payload
metric string
ip string
}{
{
"no whitelisted version",
heartbeat.Payload{
[]heartbeat.Payload{{
Version: "not_whitelisted",
Meta: "whatever",
Moniker: "whatever",
PeerID: "1X2398ug",
ChainID: 10,
},
`op_heartbeat_heartbeats{chain_id="10",version="unknown"}`,
1,
}},
`op_heartbeat_heartbeats{chain_id="10",version="unknown"} 1`,
"1.2.3.100",
},
{
"no whitelisted chain",
heartbeat.Payload{
[]heartbeat.Payload{{
Version: "v0.1.0-beta.1",
Meta: "whatever",
Moniker: "whatever",
PeerID: "1X2398ug",
ChainID: 999,
},
`op_heartbeat_heartbeats{chain_id="unknown",version="v0.1.0-beta.1"}`,
1,
}},
`op_heartbeat_heartbeats{chain_id="unknown",version="v0.1.0-beta.1"} 1`,
"1.2.3.101",
},
{
"both whitelisted",
heartbeat.Payload{
[]heartbeat.Payload{{
Version: "v0.1.0-beta.1",
Meta: "whatever",
Moniker: "whatever",
PeerID: "1X2398ug",
ChainID: 10,
}},
`op_heartbeat_heartbeats{chain_id="10",version="v0.1.0-beta.1"} 1`,
"1.2.3.102",
},
{
"spamming",
[]heartbeat.Payload{
{
Version: "v0.1.0-goerli-rehearsal.1",
Meta: "whatever",
Moniker: "alice",
PeerID: "1X2398ug",
ChainID: 10,
},
{
Version: "v0.1.0-goerli-rehearsal.1",
Meta: "whatever",
Moniker: "bob",
PeerID: "1X2398ug",
ChainID: 10,
},
},
`op_heartbeat_heartbeats{chain_id="10",version="v0.1.0-beta.1"}`,
1,
`op_heartbeat_heartbeat_same_ip_bucket{chain_id="10",version="v0.1.0-goerli-rehearsal.1",le="32"} 1`,
"1.2.3.103",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.hb)
require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://127.0.0.1:%d", httpPort), bytes.NewReader(data))
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, res.StatusCode, 204)
for _, hb := range tt.hbs {
data, err := json.Marshal(hb)
require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://127.0.0.1:%d", httpPort), bytes.NewReader(data))
require.NoError(t, err)
req.Header.Set("X-Forwarded-For", tt.ip)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
res.Body.Close()
require.Equal(t, res.StatusCode, 204)
}
metricsRes, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d", metricsPort))
require.NoError(t, err)
......@@ -104,7 +129,7 @@ func TestService(t *testing.T) {
require.NoError(t, err)
metricsBody, err := io.ReadAll(metricsRes.Body)
require.NoError(t, err)
require.Contains(t, string(metricsBody), fmt.Sprintf("%s %d", tt.metricName, tt.metricValue))
require.Contains(t, string(metricsBody), tt.metric)
})
}
......
......@@ -11,7 +11,8 @@ import (
"github.com/ethereum/go-ethereum/log"
)
var SendInterval = 10 * time.Minute
// SendInterval determines the delay between requests. This must be larger than the
// MinHeartbeatInterval in the server: the server counts a heartbeat as a repeat, not
// as a new heartbeat, when it arrives less than MinHeartbeatInterval after the start
// of the sender's current interval.
const SendInterval = 10 * time.Minute
type Payload struct {
Version string `json:"version"`
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment