Commit a3b09d17 authored by Matthew Slipper

proxyd: Handle oversize requests/backend responses

Under certain conditions, users could send batch requests that caused the upstream Geth nodes to return very large responses. These responses were handled improperly: rather than returning an error, proxyd truncated them, which led to backends being marked as offline. The same issue affected oversize client requests.

This PR also enables `pprof` on proxyd, which was used to debug this problem.
parent a7f93a08
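
For context on the failure mode: the standard library's `io.LimitReader` reports a clean `io.EOF` once its byte limit is reached, so `io.ReadAll` returns the truncated bytes with no error at all. A minimal sketch of that behavior (illustrative only, not part of this diff):

package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	// io.LimitReader signals io.EOF at the limit, so ReadAll
	// "succeeds" and silently hands back a truncated body.
	body := strings.NewReader(`{"jsonrpc":"2.0","result":"0xa","id":1}`)
	b, err := io.ReadAll(io.LimitReader(body, 8))
	fmt.Printf("%q %v\n", b, err) // prints "{\"jsonrp" <nil>
}

The truncated JSON then presumably failed to decode downstream and counted against backend health instead of surfacing as an error. The `LimitReader` added in this commit instead returns a sentinel error (`ErrLimitReaderOverLimit`), which callers translate into proper JSON-RPC errors.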
@@ -98,6 +98,18 @@ var (
 		HTTPErrorCode: 400,
 	}
+	ErrRequestBodyTooLarge = &RPCErr{
+		Code:          JSONRPCErrorInternal - 21,
+		Message:       "request body too large",
+		HTTPErrorCode: 413,
+	}
+	ErrBackendResponseTooLarge = &RPCErr{
+		Code:          JSONRPCErrorInternal - 20,
+		Message:       "backend response too large",
+		HTTPErrorCode: 500,
+	}
 	ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response")
 	ErrConsensusGetReceiptsCantBeBatched = errors.New("consensus_getReceipts cannot be batched")
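
With `JSONRPCErrorInternal` at `-32000`, the new errors surface to clients as codes `-32021` and `-32020`, which is exactly what the integration tests below assert.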
@@ -339,6 +351,14 @@ func (b *Backend) Forward(ctx context.Context, reqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
 	res, err := b.doForward(ctx, reqs, isBatch)
 	switch err {
 	case nil: // do nothing
+	case ErrBackendResponseTooLarge:
+		log.Warn(
+			"backend response too large",
+			"name", b.Name,
+			"req_id", GetReqID(ctx),
+			"max", b.maxResponseSize,
+		)
+		RecordBatchRPCError(ctx, b.Name, reqs, err)
 	case ErrConsensusGetReceiptsCantBeBatched:
 		log.Warn(
 			"Received unsupported batch request for consensus_getReceipts",
@@ -543,7 +563,10 @@ func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
 	}
 	defer httpRes.Body.Close()

-	resB, err := io.ReadAll(io.LimitReader(httpRes.Body, b.maxResponseSize))
+	resB, err := io.ReadAll(LimitReader(httpRes.Body, b.maxResponseSize))
+	if errors.Is(err, ErrLimitReaderOverLimit) {
+		return nil, ErrBackendResponseTooLarge
+	}
 	if err != nil {
 		b.networkErrorsSlidingWindow.Incr()
 		RecordBackendNetworkErrorRateSlidingWindow(b, b.ErrorRate())
@@ -726,6 +749,8 @@ func (bg *BackendGroup) Forward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool) ([]*RPCRes, string, error) {
 		res := make([]*RPCRes, 0)
 		var err error

+		servedBy := fmt.Sprintf("%s/%s", bg.Name, back.Name)
+
 		if len(rpcReqs) > 0 {
 			res, err = back.Forward(ctx, rpcReqs, isBatch)
 			if errors.Is(err, ErrConsensusGetReceiptsCantBeBatched) ||
@@ -733,6 +758,9 @@ func (bg *BackendGroup) Forward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool) ([]*RPCRes, string, error) {
 				errors.Is(err, ErrMethodNotWhitelisted) {
 				return nil, "", err
 			}
+			if errors.Is(err, ErrBackendResponseTooLarge) {
+				return nil, servedBy, err
+			}
 			if errors.Is(err, ErrBackendOffline) {
 				log.Warn(
 					"skipping offline backend",
@@ -773,7 +801,6 @@ func (bg *BackendGroup) Forward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool) ([]*RPCRes, string, error) {
 			}
 		}

-		servedBy := fmt.Sprintf("%s/%s", bg.Name, back.Name)
 		return res, servedBy, nil
 	}
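
Note that the `servedBy` computation is hoisted above the forward call so the new `ErrBackendResponseTooLarge` path can report which backend produced the oversize response alongside the error.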
...
 package main

 import (
+	"net"
+	"net/http"
+	"net/http/pprof"
 	"os"
 	"os/signal"
+	"strconv"
 	"syscall"

 	"github.com/BurntSushi/toml"
@@ -52,6 +56,17 @@ func main() {
 		),
 	)

+	if config.Server.EnablePprof {
+		log.Info("starting pprof", "addr", "0.0.0.0", "port", "6060")
+		pprofSrv := StartPProf("0.0.0.0", 6060)
+		log.Info("started pprof server", "addr", pprofSrv.Addr)
+		defer func() {
+			if err := pprofSrv.Close(); err != nil {
+				log.Error("failed to stop pprof server", "err", err)
+			}
+		}()
+	}
+
 	_, shutdown, err := proxyd.Start(config)
 	if err != nil {
 		log.Crit("error starting proxyd", "err", err)
@@ -63,3 +78,25 @@ func main() {
 	log.Info("caught signal, shutting down", "signal", recvSig)
 	shutdown()
 }
+
+func StartPProf(hostname string, port int) *http.Server {
+	mux := http.NewServeMux()
+
+	// Register the handlers explicitly to support multiple servers,
+	// since the pprof import only registers on DefaultServeMux.
+	mux.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index))
+	mux.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline))
+	mux.Handle("/debug/pprof/profile", http.HandlerFunc(pprof.Profile))
+	mux.Handle("/debug/pprof/symbol", http.HandlerFunc(pprof.Symbol))
+	mux.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace))
+
+	addr := net.JoinHostPort(hostname, strconv.Itoa(port))
+
+	srv := &http.Server{
+		Handler: mux,
+		Addr:    addr,
+	}
+	go srv.ListenAndServe()
+
+	return srv
+}
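
Once `enable_pprof` is set, profiles can be pulled with the standard tooling, e.g. `go tool pprof http://<host>:6060/debug/pprof/heap` for a heap profile, or `/debug/pprof/profile` for a CPU profile (30 seconds by default).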
@@ -24,7 +24,7 @@ type ServerConfig struct {
 	EnableRequestLog      bool `toml:"enable_request_log"`
 	MaxRequestBodyLogLen  int  `toml:"max_request_body_log_len"`
+	EnablePprof           bool `toml:"enable_pprof"`
 	EnableXServedByHeader bool `toml:"enable_served_by_header"`
 }
...
This diff is collapsed.
@@ -20,6 +20,9 @@ func TestBatching(t *testing.T) {
 	ethAccountsResponse2 := `{"jsonrpc": "2.0", "result": [], "id": 2}`

+	backendResTooLargeResponse1 := `{"error":{"code":-32020,"message":"backend response too large"},"id":1,"jsonrpc":"2.0"}`
+	backendResTooLargeResponse2 := `{"error":{"code":-32020,"message":"backend response too large"},"id":2,"jsonrpc":"2.0"}`
+
 	type mockResult struct {
 		method string
 		id     string

@@ -40,6 +43,7 @@ func TestBatching(t *testing.T) {
 		expectedRes          string
 		maxUpstreamBatchSize int
 		numExpectedForwards  int
+		maxResponseSizeBytes int64
 	}{
 		{
 			name: "backend returns batches out of order",

@@ -128,11 +132,24 @@ func TestBatching(t *testing.T) {
 			maxUpstreamBatchSize: 2,
 			numExpectedForwards:  1,
 		},
+		{
+			name:  "large upstream response gets dropped",
+			mocks: []mockResult{chainIDMock1, chainIDMock2},
+			reqs: []*proxyd.RPCReq{
+				NewRPCReq("1", "eth_chainId", nil),
+				NewRPCReq("2", "eth_chainId", nil),
+			},
+			expectedRes:          asArray(backendResTooLargeResponse1, backendResTooLargeResponse2),
+			maxUpstreamBatchSize: 2,
+			numExpectedForwards:  1,
+			maxResponseSizeBytes: 1,
+		},
 	}

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			config.Server.MaxUpstreamBatchSize = tt.maxUpstreamBatchSize
+			config.BackendOptions.MaxResponseSizeBytes = tt.maxResponseSizeBytes
 			handler := tt.handler
 			if handler == nil {
...
+whitelist_error_message = "rpc method is not whitelisted custom message"
+
+[server]
+rpc_port = 8545
+max_request_body_size_bytes = 150
+
+[backend]
+response_timeout_seconds = 1
+max_response_size_bytes = 1
+
+[backends]
+
+[backends.good]
+rpc_url = "$GOOD_BACKEND_RPC_URL"
+ws_url = "$GOOD_BACKEND_RPC_URL"
+
+[backend_groups]
+[backend_groups.main]
+backends = ["good"]
+
+[rpc_method_mappings]
+eth_chainId = "main"
\ No newline at end of file
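
Note the deliberately tiny `max_response_size_bytes = 1`: even the smallest well-formed backend response exceeds it, which the second half of `TestSizeLimits` below relies on.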
...
 package integration_tests

 import (
+	"fmt"
 	"os"
 	"strings"
 	"testing"

@@ -227,6 +228,31 @@ func TestBatchRPCValidation(t *testing.T) {
 	}
 }

+func TestSizeLimits(t *testing.T) {
+	goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
+	defer goodBackend.Close()
+
+	require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
+
+	config := ReadConfig("size_limits")
+	client := NewProxydClient("http://127.0.0.1:8545")
+	_, shutdown, err := proxyd.Start(config)
+	require.NoError(t, err)
+	defer shutdown()
+
+	payload := strings.Repeat("barf", 1024*1024)
+	out, code, err := client.SendRequest([]byte(fmt.Sprintf(`{"jsonrpc": "2.0", "method": "eth_chainId", "params": [%s], "id": 1}`, payload)))
+	require.NoError(t, err)
+	require.Equal(t, `{"jsonrpc":"2.0","error":{"code":-32021,"message":"request body too large"},"id":null}`, strings.TrimSpace(string(out)))
+	require.Equal(t, 413, code)
+
+	// The default response is already over the size limit in size_limits.toml.
+	out, code, err = client.SendRequest([]byte(`{"jsonrpc": "2.0", "method": "eth_chainId", "params": [], "id": 1}`))
+	require.NoError(t, err)
+	require.Equal(t, `{"jsonrpc":"2.0","error":{"code":-32020,"message":"backend response too large"},"id":1}`, strings.TrimSpace(string(out)))
+	require.Equal(t, 500, code)
+}
+
 func asArray(in ...string) string {
 	return "[" + strings.Join(in, ",") + "]"
 }
...
+package proxyd
+
+import (
+	"errors"
+	"io"
+)
+
+var ErrLimitReaderOverLimit = errors.New("over read limit")
+
+func LimitReader(r io.Reader, n int64) io.Reader { return &LimitedReader{r, n} }
+
+// A LimitedReader reads from R but limits the amount of
+// data returned to just N bytes. Each call to Read
+// updates N to reflect the new amount remaining.
+// Unlike the standard library version, Read returns
+// ErrLimitReaderOverLimit when N <= 0.
+type LimitedReader struct {
+	R io.Reader // underlying reader
+	N int64     // max bytes remaining
+}
+
+func (l *LimitedReader) Read(p []byte) (int, error) {
+	if l.N <= 0 {
+		return 0, ErrLimitReaderOverLimit
+	}
+	if int64(len(p)) > l.N {
+		p = p[0:l.N]
+	}
+	n, err := l.R.Read(p)
+	l.N -= int64(n)
+	return n, err
+}
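
One subtlety of this `Read` implementation: it fails as soon as `N` is exhausted, so a body of exactly the configured size may also be rejected unless the underlying reader happens to return `io.EOF` together with its final bytes. In effect, the configured limit behaves as an exclusive bound.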
...
+package proxyd
+
+import (
+	"io"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestLimitReader(t *testing.T) {
+	data := "hellohellohellohello"
+	r := LimitReader(strings.NewReader(data), 3)
+	buf := make([]byte, 3)
+
+	// Buffer reads OK
+	n, err := r.Read(buf)
+	require.NoError(t, err)
+	require.Equal(t, 3, n)
+
+	// Buffer is over limit
+	n, err = r.Read(buf)
+	require.Equal(t, ErrLimitReaderOverLimit, err)
+	require.Equal(t, 0, n)
+
+	// Buffer on initial read is over size
+	buf = make([]byte, 16)
+	r = LimitReader(strings.NewReader(data), 3)
+	n, err = r.Read(buf)
+	require.NoError(t, err)
+	require.Equal(t, 3, n)
+
+	// Test with ReadAll where the limit is less than the data
+	r = LimitReader(strings.NewReader(data), 3)
+	out, err := io.ReadAll(r)
+	require.Equal(t, ErrLimitReaderOverLimit, err)
+	require.Equal(t, "hel", string(out))
+
+	// Test with ReadAll where the limit is more than the data
+	r = LimitReader(strings.NewReader(data), 21)
+	out, err = io.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, data, string(out))
+}
...
@@ -319,7 +319,13 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
 		"remote_ip", xff,
 	)

-	body, err := io.ReadAll(io.LimitReader(r.Body, s.maxBodySize))
+	body, err := io.ReadAll(LimitReader(r.Body, s.maxBodySize))
+	if errors.Is(err, ErrLimitReaderOverLimit) {
+		log.Error("request body too large", "req_id", GetReqID(ctx))
+		RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrRequestBodyTooLarge)
+		writeRPCError(ctx, w, nil, ErrRequestBodyTooLarge)
+		return
+	}
 	if err != nil {
 		log.Error("error reading request body", "err", err)
 		writeRPCError(ctx, w, nil, ErrInternal)
...