Commit f48bd87d authored by Murphy Law, committed by GitHub

go/proxyd: Proxy requests using batch JSON-RPC (#2480)

* go/proxyd: Proxy requests as batched RPC

We forward several RPC request objects to upstreams using the JSON-RPC
batch functionality. This should be more efficient than serialized RPC
request proxying, since the round-trip latencies of all but the first
request object are eliminated.

A new server config, `max_upstream_batch_size`, is introduced to limit
the number of RPC request objects in a single batch request. This avoids
overloading upstreams, since proxyd may accept a large number of request
objects in a single request when the `max_body_size_bytes` config value
is large.
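
For illustration, a sketch of the relevant server config (the key names come from this change; the port and timeout values are just the ones used in the test configs, and 10 mirrors the default applied when the option is unset):

[server]
rpc_port = 8545
timeout_seconds = 1
# Cap each upstream JSON-RPC batch at 10 request objects. Larger client
# batches are split into minibatches of at most this size.
max_upstream_batch_size = 10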

* remove flakes: no more SequencedResponseHandler
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent 84e4e2c6
@@ -12,6 +12,7 @@ import (
     "math"
     "math/rand"
     "net/http"
+    "sort"
     "strconv"
     "strings"
     "time"
@@ -199,13 +200,13 @@ func NewBackend(
     return backend
 }

-func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
+func (b *Backend) Forward(ctx context.Context, reqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
     if !b.Online() {
-        RecordRPCError(ctx, b.Name, req.Method, ErrBackendOffline)
+        RecordBatchRPCError(ctx, b.Name, reqs, ErrBackendOffline)
         return nil, ErrBackendOffline
     }

     if b.IsRateLimited() {
-        RecordRPCError(ctx, b.Name, req.Method, ErrBackendOverCapacity)
+        RecordBatchRPCError(ctx, b.Name, reqs, ErrBackendOverCapacity)
         return nil, ErrBackendOverCapacity
     }
@@ -213,9 +214,20 @@ func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
     // <= to account for the first attempt not technically being
     // a retry
     for i := 0; i <= b.maxRetries; i++ {
-        RecordRPCForward(ctx, b.Name, req.Method, RPCRequestSourceHTTP)
-        respTimer := prometheus.NewTimer(rpcBackendRequestDurationSumm.WithLabelValues(b.Name, req.Method))
-        res, err := b.doForward(ctx, req)
+        RecordBatchRPCForward(ctx, b.Name, reqs, RPCRequestSourceHTTP)
+        metricLabelMethod := reqs[0].Method
+        if isBatch {
+            metricLabelMethod = "<batch>"
+        }
+        timer := prometheus.NewTimer(
+            rpcBackendRequestDurationSumm.WithLabelValues(
+                b.Name,
+                metricLabelMethod,
+                strconv.FormatBool(isBatch),
+            ),
+        )
+
+        res, err := b.doForward(ctx, reqs, isBatch)
         if err != nil {
             lastError = err
             log.Warn(
@@ -224,31 +236,14 @@ func (b *Backend) Forward(ctx context.Context, req *RPCReq) (*RPCRes, error) {
                 "req_id", GetReqID(ctx),
                 "err", err,
             )
-            respTimer.ObserveDuration()
-            RecordRPCError(ctx, b.Name, req.Method, err)
+            timer.ObserveDuration()
+            RecordBatchRPCError(ctx, b.Name, reqs, err)
             sleepContext(ctx, calcBackoff(i))
             continue
         }
-        respTimer.ObserveDuration()
-        if res.IsError() {
-            RecordRPCError(ctx, b.Name, req.Method, res.Error)
-            log.Info(
-                "backend responded with RPC error",
-                "backend", b.Name,
-                "code", res.Error.Code,
-                "msg", res.Error.Message,
-                "req_id", GetReqID(ctx),
-                "source", "rpc",
-                "auth", GetAuthCtx(ctx),
-            )
-        } else {
-            log.Info("forwarded RPC request",
-                "backend", b.Name,
-                "method", req.Method,
-                "auth", GetAuthCtx(ctx),
-                "req_id", GetReqID(ctx),
-            )
-        }
+        timer.ObserveDuration()
+
+        MaybeRecordErrorsInRPCRes(ctx, b.Name, reqs, res)
         return res, nil
     }
@@ -337,8 +332,8 @@ func (b *Backend) setOffline() {
     }
 }

-func (b *Backend) doForward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, error) {
-    body := mustMarshalJSON(rpcReq)
+func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
+    body := mustMarshalJSON(rpcReqs)

     httpReq, err := http.NewRequestWithContext(ctx, "POST", b.rpcURL, bytes.NewReader(body))
     if err != nil {
@@ -367,11 +362,16 @@ func (b *Backend) doForward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, error
         return nil, wrapErr(err, "error in backend request")
     }

+    metricLabelMethod := rpcReqs[0].Method
+    if isBatch {
+        metricLabelMethod = "<batch>"
+    }
     rpcBackendHTTPResponseCodesTotal.WithLabelValues(
         GetAuthCtx(ctx),
         b.Name,
-        rpcReq.Method,
+        metricLabelMethod,
         strconv.Itoa(httpRes.StatusCode),
+        strconv.FormatBool(isBatch),
     ).Inc()

     // Alchemy returns a 400 on bad JSONs, so handle that case
@@ -385,30 +385,60 @@ func (b *Backend) doForward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, error
         return nil, wrapErr(err, "error reading response body")
     }

-    res := new(RPCRes)
-    if err := json.Unmarshal(resB, res); err != nil {
+    var res []*RPCRes
+    if err := json.Unmarshal(resB, &res); err != nil {
+        return nil, ErrBackendBadResponse
+    }
+
+    // Alas! Certain node providers (Infura) always return a single JSON object for some types of errors
+    if len(rpcReqs) != len(res) {
         return nil, ErrBackendBadResponse
     }

     // capture the HTTP status code in the response. this will only
     // ever be 400 given the status check on line 318 above.
     if httpRes.StatusCode != 200 {
-        res.Error.HTTPErrorCode = httpRes.StatusCode
+        for _, res := range res {
+            res.Error.HTTPErrorCode = httpRes.StatusCode
+        }
     }

+    sortBatchRPCResponse(rpcReqs, res)
     return res, nil
 }
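
To make the length guard concrete: when, say, two results are expected, a provider that answers with a single error object like the payload below (this exact payload appears in the test cases further down) fails the `len(rpcReqs) != len(res)` check and surfaces as ErrBackendBadResponse:

{"jsonrpc":"2.0","error":{"code":-32001,"message":"internal server error"},"id":1}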
+
+// sortBatchRPCResponse sorts the RPCRes slice according to the position of its corresponding ID in the RPCReq slice
+func sortBatchRPCResponse(req []*RPCReq, res []*RPCRes) {
+    pos := make(map[string]int, len(req))
+    for i, r := range req {
+        key := string(r.ID)
+        if _, ok := pos[key]; ok {
+            panic("bug! detected requests with duplicate IDs")
+        }
+        pos[key] = i
+    }
+
+    sort.Slice(res, func(i, j int) bool {
+        l := res[i].ID
+        r := res[j].ID
+        return pos[string(l)] < pos[string(r)]
+    })
+}
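
A minimal sketch of the reordering, using struct literals with only the fields that matter here (the values are hypothetical):

reqs := []*RPCReq{
    {Method: "eth_chainId", ID: []byte("1")},
    {Method: "net_version", ID: []byte("2")},
}
// the backend answered out of order
res := []*RPCRes{
    {Result: "1", ID: []byte("2")},
    {Result: "0xa", ID: []byte("1")},
}
sortBatchRPCResponse(reqs, res)
// res now holds ID "1" first, then ID "2", matching the request order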
 type BackendGroup struct {
     Name     string
     Backends []*Backend
 }

-func (b *BackendGroup) Forward(ctx context.Context, rpcReq *RPCReq) (*RPCRes, error) {
+func (b *BackendGroup) Forward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
+    if len(rpcReqs) == 0 {
+        return nil, nil
+    }
+
     rpcRequestsTotal.Inc()

     for _, back := range b.Backends {
-        res, err := back.Forward(ctx, rpcReq)
+        res, err := back.Forward(ctx, rpcReqs, isBatch)
         if errors.Is(err, ErrMethodNotWhitelisted) {
             return nil, err
         }
@@ -712,3 +742,44 @@ func (c *LimitedHTTPClient) DoLimited(req *http.Request) (*http.Response, error)
     defer c.sem.Release(1)
     return c.Do(req)
 }
+
+func RecordBatchRPCError(ctx context.Context, backendName string, reqs []*RPCReq, err error) {
+    for _, req := range reqs {
+        RecordRPCError(ctx, backendName, req.Method, err)
+    }
+}
+
+func MaybeRecordErrorsInRPCRes(ctx context.Context, backendName string, reqs []*RPCReq, resBatch []*RPCRes) {
+    log.Info("forwarded RPC request",
+        "backend", backendName,
+        "auth", GetAuthCtx(ctx),
+        "req_id", GetReqID(ctx),
+        "batch_size", len(reqs),
+    )
+
+    var lastError *RPCErr
+    for i, res := range resBatch {
+        if res.IsError() {
+            lastError = res.Error
+            RecordRPCError(ctx, backendName, reqs[i].Method, res.Error)
+        }
+    }
+
+    if lastError != nil {
+        log.Info(
+            "backend responded with RPC error",
+            "backend", backendName,
+            "last_error_code", lastError.Code,
+            "last_error_msg", lastError.Message,
+            "req_id", GetReqID(ctx),
+            "source", "rpc",
+            "auth", GetAuthCtx(ctx),
+        )
+    }
+}
+
+func RecordBatchRPCForward(ctx context.Context, backendName string, reqs []*RPCReq, source string) {
+    for _, req := range reqs {
+        RecordRPCForward(ctx, backendName, req.Method, source)
+    }
+}
@@ -16,6 +16,8 @@ type ServerConfig struct {
     // TimeoutSeconds specifies the maximum time spent serving an HTTP request. Note that isn't used for websocket connections
     TimeoutSeconds int `toml:"timeout_seconds"`
+
+    MaxUpstreamBatchSize int `toml:"max_upstream_batch_size"`
 }

 type CacheConfig struct {
......
@@ -29,7 +29,7 @@ func TestBatchTimeout(t *testing.T) {
     slowBackend.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
         // check the config. The sleep duration should be at least double the server.timeout_seconds config to prevent flakes
         time.Sleep(time.Second * 2)
-        SingleResponseHandler(200, goodResponse)(w, r)
+        BatchedResponseHandler(200, goodResponse)(w, r)
     }))

     res, statusCode, err := client.SendBatchRPC(
         NewRPCReq("1", "eth_chainId", nil),
......
package integration_tests

import (
    "net/http"
    "os"
    "testing"

    "github.com/ethereum-optimism/optimism/go/proxyd"
    "github.com/stretchr/testify/require"
)

func TestBatching(t *testing.T) {
    config := ReadConfig("batching")

    chainIDResponse1 := `{"jsonrpc": "2.0", "result": "hello1", "id": 1}`
    chainIDResponse2 := `{"jsonrpc": "2.0", "result": "hello2", "id": 2}`
    chainIDResponse3 := `{"jsonrpc": "2.0", "result": "hello3", "id": 3}`
    netVersionResponse1 := `{"jsonrpc": "2.0", "result": "1.0", "id": 1}`
    callResponse1 := `{"jsonrpc": "2.0", "result": "ekans1", "id": 1}`

    type mockResult struct {
        method string
        id     string
        result interface{}
    }

    chainIDMock1 := mockResult{"eth_chainId", "1", "hello1"}
    chainIDMock2 := mockResult{"eth_chainId", "2", "hello2"}
    chainIDMock3 := mockResult{"eth_chainId", "3", "hello3"}
    netVersionMock1 := mockResult{"net_version", "1", "1.0"}
    callMock1 := mockResult{"eth_call", "1", "ekans1"}

    tests := []struct {
        name                string
        handler             http.Handler
        mocks               []mockResult
        reqs                []*proxyd.RPCReq
        expectedRes         string
        maxBatchSize        int
        numExpectedForwards int
    }{
        {
            name:  "backend returns batches out of order",
            mocks: []mockResult{chainIDMock1, chainIDMock2, chainIDMock3},
            reqs: []*proxyd.RPCReq{
                NewRPCReq("1", "eth_chainId", nil),
                NewRPCReq("2", "eth_chainId", nil),
                NewRPCReq("3", "eth_chainId", nil),
            },
            expectedRes:         asArray(chainIDResponse1, chainIDResponse2, chainIDResponse3),
            maxBatchSize:        2,
            numExpectedForwards: 2,
        },
        {
            // infura behavior
            name:    "backend returns single RPC response object as error",
            handler: SingleResponseHandler(500, `{"jsonrpc":"2.0","error":{"code":-32001,"message":"internal server error"},"id":1}`),
            reqs: []*proxyd.RPCReq{
                NewRPCReq("1", "eth_chainId", nil),
                NewRPCReq("2", "eth_chainId", nil),
            },
            expectedRes: asArray(
                `{"error":{"code":-32011,"message":"no backends available for method"},"id":1,"jsonrpc":"2.0"}`,
                `{"error":{"code":-32011,"message":"no backends available for method"},"id":2,"jsonrpc":"2.0"}`,
            ),
            maxBatchSize:        10,
            numExpectedForwards: 1,
        },
        {
            name:    "backend returns single RPC response object for minibatches",
            handler: SingleResponseHandler(500, `{"jsonrpc":"2.0","error":{"code":-32001,"message":"internal server error"},"id":1}`),
            reqs: []*proxyd.RPCReq{
                NewRPCReq("1", "eth_chainId", nil),
                NewRPCReq("2", "eth_chainId", nil),
            },
            expectedRes: asArray(
                `{"error":{"code":-32011,"message":"no backends available for method"},"id":1,"jsonrpc":"2.0"}`,
                `{"error":{"code":-32011,"message":"no backends available for method"},"id":2,"jsonrpc":"2.0"}`,
            ),
            maxBatchSize:        1,
            numExpectedForwards: 2,
        },
        {
            name: "duplicate request ids are on distinct batches",
            mocks: []mockResult{
                netVersionMock1,
                chainIDMock2,
                chainIDMock1,
                callMock1,
            },
            reqs: []*proxyd.RPCReq{
                NewRPCReq("1", "net_version", nil),
                NewRPCReq("2", "eth_chainId", nil),
                NewRPCReq("1", "eth_chainId", nil),
                NewRPCReq("1", "eth_call", nil),
            },
            expectedRes:         asArray(netVersionResponse1, chainIDResponse2, chainIDResponse1, callResponse1),
            maxBatchSize:        2,
            numExpectedForwards: 3,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            config.Server.MaxUpstreamBatchSize = tt.maxBatchSize

            handler := tt.handler
            if handler == nil {
                router := NewBatchRPCResponseRouter()
                for _, mock := range tt.mocks {
                    router.SetRoute(mock.method, mock.id, mock.result)
                }
                handler = router
            }

            goodBackend := NewMockBackend(handler)
            defer goodBackend.Close()
            require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))

            client := NewProxydClient("http://127.0.0.1:8545")
            shutdown, err := proxyd.Start(config)
            require.NoError(t, err)
            defer shutdown()

            res, statusCode, err := client.SendBatchRPC(tt.reqs...)
            require.NoError(t, err)
            require.Equal(t, http.StatusOK, statusCode)
            RequireEqualJSON(t, []byte(tt.expectedRes), res)

            if tt.numExpectedForwards != 0 {
                require.Equal(t, tt.numExpectedForwards, len(goodBackend.Requests()))
            }

            if handler, ok := handler.(*BatchRPCResponseRouter); ok {
                for i, mock := range tt.mocks {
                    require.Equal(t, 1, handler.GetNumCalls(mock.method, mock.id), i)
                }
            }
        })
    }
}
@@ -17,13 +17,17 @@ func TestCaching(t *testing.T) {
     require.NoError(t, err)
     defer redis.Close()

-    hdlr := NewRPCResponseHandler(map[string]interface{}{
-        "eth_chainId":          "0x420",
-        "net_version":          "0x1234",
-        "eth_blockNumber":      "0x64",
-        "eth_getBlockByNumber": "dummy_block",
-        "eth_call":             "dummy_call",
-    })
+    hdlr := NewBatchRPCResponseRouter()
+    hdlr.SetRoute("eth_chainId", "999", "0x420")
+    hdlr.SetRoute("net_version", "999", "0x1234")
+    hdlr.SetRoute("eth_blockNumber", "999", "0x64")
+    hdlr.SetRoute("eth_getBlockByNumber", "999", "dummy_block")
+    hdlr.SetRoute("eth_call", "999", "dummy_call")
+
+    // mock LVC requests
+    hdlr.SetFallbackRoute("eth_blockNumber", "0x64")
+    hdlr.SetFallbackRoute("eth_gasPrice", "0x420")
+
     backend := NewMockBackend(hdlr)
     defer backend.Close()

@@ -125,7 +129,7 @@ func TestCaching(t *testing.T) {
     }

     t.Run("block numbers update", func(t *testing.T) {
-        hdlr.SetResponse("eth_blockNumber", "0x100")
+        hdlr.SetFallbackRoute("eth_blockNumber", "0x100")
         time.Sleep(1500 * time.Millisecond)
         resRaw, _, err := client.SendRPC("eth_blockNumber", nil)
         require.NoError(t, err)

@@ -134,7 +138,7 @@ func TestCaching(t *testing.T) {
     })

     t.Run("nil responses should not be cached", func(t *testing.T) {
-        hdlr.SetResponse("eth_getBlockByNumber", nil)
+        hdlr.SetRoute("eth_getBlockByNumber", "999", nil)
         resRaw, _, err := client.SendRPC("eth_getBlockByNumber", []interface{}{"0x123"})
         require.NoError(t, err)
         resCache, _, err := client.SendRPC("eth_getBlockByNumber", []interface{}{"0x123"})

@@ -145,6 +149,61 @@ func TestCaching(t *testing.T) {
     })
 }
+
+func TestBatchCaching(t *testing.T) {
+    redis, err := miniredis.Run()
+    require.NoError(t, err)
+    defer redis.Close()
+
+    hdlr := NewBatchRPCResponseRouter()
+    hdlr.SetRoute("eth_chainId", "1", "0x420")
+    hdlr.SetRoute("net_version", "1", "0x1234")
+    hdlr.SetRoute("eth_call", "1", "dummy_call")
+
+    // mock LVC requests
+    hdlr.SetFallbackRoute("eth_blockNumber", "0x64")
+    hdlr.SetFallbackRoute("eth_gasPrice", "0x420")
+
+    backend := NewMockBackend(hdlr)
+    defer backend.Close()
+
+    require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
+    require.NoError(t, os.Setenv("REDIS_URL", fmt.Sprintf("redis://127.0.0.1:%s", redis.Port())))
+
+    config := ReadConfig("caching")
+    client := NewProxydClient("http://127.0.0.1:8545")
+    shutdown, err := proxyd.Start(config)
+    require.NoError(t, err)
+    defer shutdown()
+
+    // allow time for the block number fetcher to fire
+    time.Sleep(1500 * time.Millisecond)
+
+    goodChainIdResponse := "{\"jsonrpc\": \"2.0\", \"result\": \"0x420\", \"id\": 1}"
+    goodNetVersionResponse := "{\"jsonrpc\": \"2.0\", \"result\": \"0x1234\", \"id\": 1}"
+    goodEthCallResponse := "{\"jsonrpc\": \"2.0\", \"result\": \"dummy_call\", \"id\": 1}"
+
+    res, _, err := client.SendBatchRPC(
+        NewRPCReq("1", "eth_chainId", nil),
+        NewRPCReq("1", "net_version", nil),
+    )
+    require.NoError(t, err)
+    RequireEqualJSON(t, []byte(asArray(goodChainIdResponse, goodNetVersionResponse)), res)
+    require.Equal(t, 1, countRequests(backend, "eth_chainId"))
+    require.Equal(t, 1, countRequests(backend, "net_version"))
+
+    backend.Reset()
+
+    res, _, err = client.SendBatchRPC(
+        NewRPCReq("1", "eth_chainId", nil),
+        NewRPCReq("1", "eth_call", []interface{}{`{"to":"0x1234"}`, "pending"}),
+        NewRPCReq("1", "net_version", nil),
+    )
+    require.NoError(t, err)
+    RequireEqualJSON(t, []byte(asArray(goodChainIdResponse, goodEthCallResponse, goodNetVersionResponse)), res)
+    require.Equal(t, 0, countRequests(backend, "eth_chainId"))
+    require.Equal(t, 0, countRequests(backend, "net_version"))
+    require.Equal(t, 1, countRequests(backend, "eth_call"))
+}
 func countRequests(backend *MockBackend, name string) int {
     var count int
     for _, req := range backend.Requests() {
......
@@ -18,7 +18,7 @@ const (
 )

 func TestFailover(t *testing.T) {
-    goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
+    goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
     defer goodBackend.Close()
     badBackend := NewMockBackend(nil)
     defer badBackend.Close()

@@ -88,7 +88,7 @@ func TestFailover(t *testing.T) {
     t.Run("backend times out and falls back to another", func(t *testing.T) {
         badBackend.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
             time.Sleep(2 * time.Second)
-            _, _ = w.Write([]byte("{}"))
+            _, _ = w.Write([]byte("[{}]"))
         }))
         res, statusCode, err := client.SendRPC("eth_chainId", nil)
         require.NoError(t, err)

@@ -101,23 +101,26 @@ func TestFailover(t *testing.T) {
     })

     t.Run("works with a batch request", func(t *testing.T) {
+        goodBackend.SetHandler(BatchedResponseHandler(200, goodResponse, goodResponse))
         badBackend.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
             w.WriteHeader(500)
         }))
         res, statusCode, err := client.SendBatchRPC(
             NewRPCReq("1", "eth_chainId", nil),
-            NewRPCReq("1", "eth_chainId", nil),
+            NewRPCReq("2", "eth_chainId", nil),
         )
         require.NoError(t, err)
         require.Equal(t, 200, statusCode)
         RequireEqualJSON(t, []byte(asArray(goodResponse, goodResponse)), res)
-        require.Equal(t, 2, len(badBackend.Requests()))
-        require.Equal(t, 2, len(goodBackend.Requests()))
+        require.Equal(t, 1, len(badBackend.Requests()))
+        require.Equal(t, 1, len(goodBackend.Requests()))
+
+        goodBackend.Reset()
+        badBackend.Reset()
     })
 }
 func TestRetries(t *testing.T) {
-    backend := NewMockBackend(SingleResponseHandler(200, goodResponse))
+    backend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
     defer backend.Close()

     require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))

@@ -134,7 +137,7 @@ func TestRetries(t *testing.T) {
             w.WriteHeader(500)
             return
         }
-        _, _ = w.Write([]byte(goodResponse))
+        BatchedResponseHandler(200, goodResponse)(w, r)
     }))

     // test case where request eventually succeeds

@@ -155,7 +158,8 @@ func TestRetries(t *testing.T) {
 }

 func TestOutOfServiceInterval(t *testing.T) {
-    goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
+    okHandler := BatchedResponseHandler(200, goodResponse)
+    goodBackend := NewMockBackend(okHandler)
     defer goodBackend.Close()
     badBackend := NewMockBackend(nil)
     defer badBackend.Close()

@@ -169,13 +173,9 @@ func TestOutOfServiceInterval(t *testing.T) {
     require.NoError(t, err)
     defer shutdown()

-    okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-        _, _ = w.Write([]byte(goodResponse))
-    })
     badBackend.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
         w.WriteHeader(503)
     }))
-    goodBackend.SetHandler(okHandler)

     res, statusCode, err := client.SendRPC("eth_chainId", nil)
     require.NoError(t, err)

@@ -210,3 +210,33 @@ func TestOutOfServiceInterval(t *testing.T) {
     require.Equal(t, 3, len(badBackend.Requests()))
     require.Equal(t, 4, len(goodBackend.Requests()))
 }
+
+func TestBatchWithPartialFailover(t *testing.T) {
+    config := ReadConfig("failover")
+    config.Server.MaxUpstreamBatchSize = 2
+
+    goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse))
+    defer goodBackend.Close()
+    badBackend := NewMockBackend(SingleResponseHandler(200, "this data is not JSON!"))
+    defer badBackend.Close()
+
+    require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
+    require.NoError(t, os.Setenv("BAD_BACKEND_RPC_URL", badBackend.URL()))
+
+    client := NewProxydClient("http://127.0.0.1:8545")
+    shutdown, err := proxyd.Start(config)
+    require.NoError(t, err)
+    defer shutdown()
+
+    res, statusCode, err := client.SendBatchRPC(
+        NewRPCReq("1", "eth_chainId", nil),
+        NewRPCReq("2", "eth_chainId", nil),
+        NewRPCReq("3", "eth_chainId", nil),
+        NewRPCReq("4", "eth_chainId", nil),
+    )
+    require.NoError(t, err)
+    require.Equal(t, 200, statusCode)
+    RequireEqualJSON(t, []byte(asArray(goodResponse, goodResponse, goodResponse, goodResponse)), res)
+    require.Equal(t, 2, len(badBackend.Requests()))
+    require.Equal(t, 2, len(goodBackend.Requests()))
+}
@@ -27,7 +27,7 @@ func TestMaxConcurrentRPCs(t *testing.T) {
         mu.Unlock()
         time.Sleep(time.Second * 2)
-        SingleResponseHandler(200, goodResponse)(w, r)
+        BatchedResponseHandler(200, goodResponse)(w, r)
         mu.Lock()
         concurrentRPCs--
......
@@ -32,52 +32,163 @@ func SingleResponseHandler(code int, response string) http.HandlerFunc {
     }
 }

-type RPCResponseHandler struct {
-    mtx          sync.RWMutex
-    rpcResponses map[string]interface{}
-}
+func BatchedResponseHandler(code int, responses ...string) http.HandlerFunc {
+    // all proxyd upstream requests are batched
+    return func(w http.ResponseWriter, r *http.Request) {
+        var body string
+        body += "["
+        for i, response := range responses {
+            body += response
+            if i+1 < len(responses) {
+                body += ","
+            }
+        }
+        body += "]"
+        SingleResponseHandler(code, body)(w, r)
+    }
+}
+
+type responseMapping struct {
+    result interface{}
+    calls  int
+}
+
+type BatchRPCResponseRouter struct {
+    m        map[string]map[string]*responseMapping
+    fallback map[string]interface{}
+    mtx      sync.Mutex
+}

-func NewRPCResponseHandler(rpcResponses map[string]interface{}) *RPCResponseHandler {
-    return &RPCResponseHandler{
-        rpcResponses: rpcResponses,
-    }
-}
+func NewBatchRPCResponseRouter() *BatchRPCResponseRouter {
+    return &BatchRPCResponseRouter{
+        m:        make(map[string]map[string]*responseMapping),
+        fallback: make(map[string]interface{}),
+    }
+}
+
+func (h *BatchRPCResponseRouter) SetRoute(method string, id string, result interface{}) {
+    h.mtx.Lock()
+    defer h.mtx.Unlock()
+
+    switch result.(type) {
+    case string:
+    case nil:
+        break
+    default:
+        panic("invalid result type")
+    }
+
+    m := h.m[method]
+    if m == nil {
+        m = make(map[string]*responseMapping)
+    }
+    m[id] = &responseMapping{result: result}
+    h.m[method] = m
+}

-func (h *RPCResponseHandler) SetResponse(method string, response interface{}) {
+func (h *BatchRPCResponseRouter) SetFallbackRoute(method string, result interface{}) {
     h.mtx.Lock()
     defer h.mtx.Unlock()

-    switch response.(type) {
+    switch result.(type) {
     case string:
     case nil:
         break
     default:
-        panic("invalid response type")
+        panic("invalid result type")
     }

-    h.rpcResponses[method] = response
+    h.fallback[method] = result
 }

-func (h *RPCResponseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+func (h *BatchRPCResponseRouter) GetNumCalls(method string, id string) int {
+    h.mtx.Lock()
+    defer h.mtx.Unlock()
+
+    if m := h.m[method]; m != nil {
+        if rm := m[id]; rm != nil {
+            return rm.calls
+        }
+    }
+    return 0
+}
+
+func (h *BatchRPCResponseRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+    h.mtx.Lock()
+    defer h.mtx.Unlock()
+
     body, err := ioutil.ReadAll(r.Body)
     if err != nil {
         panic(err)
     }

+    if proxyd.IsBatch(body) {
+        batch, err := proxyd.ParseBatchRPCReq(body)
+        if err != nil {
+            panic(err)
+        }
+        out := make([]*proxyd.RPCRes, len(batch))
+        for i := range batch {
+            req, err := proxyd.ParseRPCReq(batch[i])
+            if err != nil {
+                panic(err)
+            }
+
+            var result interface{}
+            var resultHasValue bool
+            if mappings, exists := h.m[req.Method]; exists {
+                if rm := mappings[string(req.ID)]; rm != nil {
+                    result = rm.result
+                    resultHasValue = true
+                    rm.calls++
+                }
+            }
+            if !resultHasValue {
+                result, resultHasValue = h.fallback[req.Method]
+            }
+            if !resultHasValue {
+                w.WriteHeader(400)
+                return
+            }
+
+            out[i] = &proxyd.RPCRes{
+                JSONRPC: proxyd.JSONRPCVersion,
+                Result:  result,
+                ID:      req.ID,
+            }
+        }
+        if err := json.NewEncoder(w).Encode(out); err != nil {
+            panic(err)
+        }
+        return
+    }
+
     req, err := proxyd.ParseRPCReq(body)
     if err != nil {
         panic(err)
     }

-    h.mtx.RLock()
-    res := h.rpcResponses[req.Method]
-    h.mtx.RUnlock()
-    if res == "" {
+    var result interface{}
+    var resultHasValue bool
+
+    if mappings, exists := h.m[req.Method]; exists {
+        if rm := mappings[string(req.ID)]; rm != nil {
+            result = rm.result
+            resultHasValue = true
+            rm.calls++
+        }
+    }
+    if !resultHasValue {
+        result, resultHasValue = h.fallback[req.Method]
+    }
+    if !resultHasValue {
         w.WriteHeader(400)
         return
     }

     out := &proxyd.RPCRes{
         JSONRPC: proxyd.JSONRPCVersion,
-        Result:  res,
+        Result:  result,
         ID:      req.ID,
     }
     enc := json.NewEncoder(w)
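
For orientation, BatchedResponseHandler simply wraps its arguments into a JSON array; a sketch of its effect (the response strings here are hypothetical):

// BatchedResponseHandler(200, `{"id":1}`, `{"id":2}`) writes the body:
// [{"id":1},{"id":2}]
// so a mock backend built on it answers every request with that batch:
backend := NewMockBackend(BatchedResponseHandler(200, goodResponse, goodResponse))
defer backend.Close()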
......
@@ -14,7 +14,7 @@ type resWithCode struct {
 }

 func TestMaxRPSLimit(t *testing.T) {
-    goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
+    goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
     defer goodBackend.Close()

     require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
......
 [server]
 rpc_port = 8545
 timeout_seconds = 1
+max_upstream_batch_size = 1

 [backend]
 response_timeout_seconds = 1
......
[server]
rpc_port = 8545

[backend]
response_timeout_seconds = 1

[backends]

[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"

[backend_groups]

[backend_groups.main]
backends = ["good"]

[rpc_method_mappings]
eth_chainId = "main"
net_version = "main"
eth_call = "main"
@@ -6,10 +6,12 @@ import (
     "fmt"
     "io/ioutil"
     "net/http"
+    "os"
     "testing"

     "github.com/BurntSushi/toml"
     "github.com/ethereum-optimism/optimism/go/proxyd"
+    "github.com/ethereum/go-ethereum/log"
     "github.com/stretchr/testify/require"
 )
@@ -95,3 +97,13 @@ func NewRPCReq(id string, method string, params []interface{}) *proxyd.RPCReq {
         ID:     []byte(id),
     }
 }
+
+func InitLogger() {
+    log.Root().SetHandler(
+        log.LvlFilterHandler(log.LvlDebug,
+            log.StreamHandler(
+                os.Stdout,
+                log.TerminalFormat(false),
+            )),
+    )
+}
@@ -19,7 +19,7 @@ const (
 )

 func TestSingleRPCValidation(t *testing.T) {
-    goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
+    goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
     defer goodBackend.Close()
     require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))

@@ -103,7 +103,7 @@ func TestSingleRPCValidation(t *testing.T) {
 }

 func TestBatchRPCValidation(t *testing.T) {
-    goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
+    goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
     defer goodBackend.Close()
     require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
......
@@ -51,6 +51,7 @@ var (
         "backend_name",
         "method_name",
         "status_code",
+        "batched",
     })

     rpcErrorsTotal = promauto.NewCounterVec(prometheus.CounterOpts{

@@ -83,6 +84,7 @@ var (
     }, []string{
         "backend_name",
         "method_name",
+        "batched",
     })

     activeClientWsConnsGauge = promauto.NewGaugeVec(prometheus.GaugeOpts{
......
@@ -220,6 +220,7 @@ func Start(config *Config) (func(), error) {
         config.Server.MaxBodySizeBytes,
         resolvedAuth,
         secondsToDuration(config.Server.TimeoutSeconds),
+        config.Server.MaxUpstreamBatchSize,
         rpcCache,
     )
......
@@ -21,27 +21,29 @@ import (
 )

 const (
     ContextKeyAuth              = "authorization"
     ContextKeyReqID             = "req_id"
     ContextKeyXForwardedFor     = "x_forwarded_for"
     MaxBatchRPCCalls            = 100
     cacheStatusHdr              = "X-Proxyd-Cache-Status"
     defaultServerTimeout        = time.Second * 10
     maxLogLength                = 2000
+    defaultMaxUpstreamBatchSize = 10
 )

 type Server struct {
     backendGroups        map[string]*BackendGroup
     wsBackendGroup       *BackendGroup
     wsMethodWhitelist    *StringSet
     rpcMethodMappings    map[string]string
     maxBodySize          int64
     authenticatedPaths   map[string]string
     timeout              time.Duration
+    maxUpstreamBatchSize int
     upgrader             *websocket.Upgrader
     rpcServer            *http.Server
     wsServer             *http.Server
     cache                RPCCache
 }
 func NewServer(

@@ -52,6 +54,7 @@ func NewServer(
     maxBodySize int64,
     authenticatedPaths map[string]string,
     timeout time.Duration,
+    maxUpstreamBatchSize int,
     cache RPCCache,
 ) *Server {
     if cache == nil {

@@ -66,15 +69,20 @@ func NewServer(
         timeout = defaultServerTimeout
     }

+    if maxUpstreamBatchSize == 0 {
+        maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
+    }
+
     return &Server{
         backendGroups:        backendGroups,
         wsBackendGroup:       wsBackendGroup,
         wsMethodWhitelist:    wsMethodWhitelist,
         rpcMethodMappings:    rpcMethodMappings,
         maxBodySize:          maxBodySize,
         authenticatedPaths:   authenticatedPaths,
         timeout:              timeout,
+        maxUpstreamBatchSize: maxUpstreamBatchSize,
         cache:                cache,
         upgrader: &websocket.Upgrader{
             HandshakeTimeout: 5 * time.Second,
         },
@@ -177,34 +185,14 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
             return
         }

-        batchRes := make([]*RPCRes, len(reqs))
-        var batchContainsCached bool
-        for i := 0; i < len(reqs); i++ {
-            if ctx.Err() == context.DeadlineExceeded {
-                log.Info(
-                    "short-circuiting batch RPC",
-                    "req_id", GetReqID(ctx),
-                    "auth", GetAuthCtx(ctx),
-                    "index", i,
-                    "batch_size", len(reqs),
-                )
-                batchRPCShortCircuitsTotal.Inc()
-                writeRPCError(ctx, w, nil, ErrGatewayTimeout)
-                return
-            }
-
-            req, err := ParseRPCReq(reqs[i])
-            if err != nil {
-                log.Info("error parsing RPC call", "source", "rpc", "err", err)
-                batchRes[i] = NewRPCErrorRes(nil, err)
-                continue
-            }
-
-            var cached bool
-            batchRes[i], cached = s.handleSingleRPC(ctx, req)
-            if cached {
-                batchContainsCached = true
-            }
-        }
+        batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, true)
+        if err == context.DeadlineExceeded {
+            writeRPCError(ctx, w, nil, ErrGatewayTimeout)
+            return
+        }
+        if err != nil {
+            writeRPCError(ctx, w, nil, ErrInternal)
+            return
+        }

         setCacheHeader(w, batchContainsCached)

@@ -212,73 +200,131 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
         return
     }

-    req, err := ParseRPCReq(body)
+    rawBody := json.RawMessage(body)
+    backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, false)
     if err != nil {
-        log.Info("error parsing RPC call", "source", "rpc", "err", err)
-        writeRPCError(ctx, w, nil, err)
+        writeRPCError(ctx, w, nil, ErrInternal)
         return
     }

-    backendRes, cached := s.handleSingleRPC(ctx, req)
     setCacheHeader(w, cached)
-    writeRPCRes(ctx, w, backendRes)
+    writeRPCRes(ctx, w, backendRes[0])
 }
-func (s *Server) handleSingleRPC(ctx context.Context, req *RPCReq) (*RPCRes, bool) {
-    if err := ValidateRPCReq(req); err != nil {
-        RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
-        return NewRPCErrorRes(nil, err), false
-    }
-
-    group := s.rpcMethodMappings[req.Method]
-    if group == "" {
-        // use unknown below to prevent DOS vector that fills up memory
-        // with arbitrary method names.
-        log.Info(
-            "blocked request for non-whitelisted method",
-            "source", "rpc",
-            "req_id", GetReqID(ctx),
-            "method", req.Method,
-        )
-        RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
-        return NewRPCErrorRes(req.ID, ErrMethodNotWhitelisted), false
-    }
-
-    var backendRes *RPCRes
-    backendRes, err := s.cache.GetRPC(ctx, req)
-    if err != nil {
-        log.Warn(
-            "cache lookup error",
-            "req_id", GetReqID(ctx),
-            "err", err,
-        )
-    }
-    if backendRes != nil {
-        return backendRes, true
-    }
-
-    backendRes, err = s.backendGroups[group].Forward(ctx, req)
-    if err != nil {
-        log.Error(
-            "error forwarding RPC request",
-            "method", req.Method,
-            "req_id", GetReqID(ctx),
-            "err", err,
-        )
-        return NewRPCErrorRes(req.ID, err), false
-    }
-
-    if backendRes.Error == nil && backendRes.Result != nil {
-        if err = s.cache.PutRPC(ctx, req, backendRes); err != nil {
-            log.Warn(
-                "cache put error",
-                "req_id", GetReqID(ctx),
-                "err", err,
-            )
-        }
-    }
-
-    return backendRes, false
-}
+func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isBatch bool) ([]*RPCRes, bool, error) {
+    // A request set is transformed into groups of batches.
+    // Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints)
+    // A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's
+    // forwarded to the backend. This is done to ensure that the order of JSON-RPC Responses match the Request order
+    // as the backend MAY return Responses out of order.
+    // NOTE: Duplicate request ids induces 1-sized JSON-RPC batches
+    type batchGroup struct {
+        groupID      int
+        backendGroup string
+    }
+
+    responses := make([]*RPCRes, len(reqs))
+    batches := make(map[batchGroup][]batchElem)
+    ids := make(map[string]int, len(reqs))
+
+    for i := range reqs {
+        parsedReq, err := ParseRPCReq(reqs[i])
+        if err != nil {
+            log.Info("error parsing RPC call", "source", "rpc", "err", err)
+            responses[i] = NewRPCErrorRes(nil, err)
+            continue
+        }
+
+        if err := ValidateRPCReq(parsedReq); err != nil {
+            RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
+            responses[i] = NewRPCErrorRes(nil, err)
+            continue
+        }
+
+        group := s.rpcMethodMappings[parsedReq.Method]
+        if group == "" {
+            // use unknown below to prevent DOS vector that fills up memory
+            // with arbitrary method names.
+            log.Info(
+                "blocked request for non-whitelisted method",
+                "source", "rpc",
+                "req_id", GetReqID(ctx),
+                "method", parsedReq.Method,
+            )
+            RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
+            responses[i] = NewRPCErrorRes(parsedReq.ID, ErrMethodNotWhitelisted)
+            continue
+        }
+
+        id := string(parsedReq.ID)
+        // If this is a duplicate Request ID, move the Request to a new batchGroup
+        ids[id]++
+        batchGroupID := ids[id]
+        batchGroup := batchGroup{groupID: batchGroupID, backendGroup: group}
+        batches[batchGroup] = append(batches[batchGroup], batchElem{parsedReq, i})
+    }
+
+    var cached bool
+    for group, batch := range batches {
+        var cacheMisses []batchElem
+
+        for _, req := range batch {
+            backendRes, _ := s.cache.GetRPC(ctx, req.Req)
+            if backendRes != nil {
+                responses[req.Index] = backendRes
+                cached = true
+            } else {
+                cacheMisses = append(cacheMisses, req)
+            }
+        }
+
+        // Create minibatches - each minibatch must be no larger than the maxUpstreamBatchSize
+        numBatches := int(math.Ceil(float64(len(cacheMisses)) / float64(s.maxUpstreamBatchSize)))
+        for i := 0; i < numBatches; i++ {
+            if ctx.Err() == context.DeadlineExceeded {
+                log.Info("short-circuiting batch RPC",
+                    "req_id", GetReqID(ctx),
+                    "auth", GetAuthCtx(ctx),
+                    "batch_index", i,
+                )
+                batchRPCShortCircuitsTotal.Inc()
+                return nil, false, context.DeadlineExceeded
+            }
+
+            start := i * s.maxUpstreamBatchSize
+            end := int(math.Min(float64(start+s.maxUpstreamBatchSize), float64(len(cacheMisses))))
+            elems := cacheMisses[start:end]
+            res, err := s.backendGroups[group.backendGroup].Forward(ctx, createBatchRequest(elems), isBatch)
+            if err != nil {
+                log.Error(
+                    "error forwarding RPC batch",
+                    "batch_size", len(elems),
+                    "backend_group", group,
+                    "err", err,
+                )
+                res = nil
+                for _, elem := range elems {
+                    res = append(res, NewRPCErrorRes(elem.Req.ID, err))
+                }
+            }
+
+            for i := range elems {
+                responses[elems[i].Index] = res[i]
+
+                // TODO(inphi): batch put these
+                if res[i].Error == nil && res[i].Result != nil {
+                    if err := s.cache.PutRPC(ctx, elems[i].Req, res[i]); err != nil {
+                        log.Warn(
+                            "cache put error",
+                            "req_id", GetReqID(ctx),
+                            "err", err,
+                        )
+                    }
+                }
+            }
+        }
+    }
+
+    return responses, cached, nil
+}
 func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {

@@ -472,3 +518,16 @@ func truncate(str string) string {
         return str
     }
 }
+
+type batchElem struct {
+    Req   *RPCReq
+    Index int
+}
+
+func createBatchRequest(elems []batchElem) []*RPCReq {
+    batch := make([]*RPCReq, len(elems))
+    for i := range elems {
+        batch[i] = elems[i].Req
+    }
+    return batch
+}
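
To make the minibatch arithmetic above concrete, a self-contained sketch (the sizes are hypothetical) of how cache misses are sliced into upstream batches:

package main

import (
    "fmt"
    "math"
)

func main() {
    // With 5 cache misses and maxUpstreamBatchSize = 2, proxyd forwards
    // 3 minibatches covering [0:2), [2:4), and [4:5).
    misses, maxSize := 5, 2
    numBatches := int(math.Ceil(float64(misses) / float64(maxSize)))
    for i := 0; i < numBatches; i++ {
        start := i * maxSize
        end := int(math.Min(float64(start+maxSize), float64(misses)))
        fmt.Printf("minibatch %d covers cache misses [%d:%d)\n", i, start, end)
    }
}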