Commit 42ae2a6b authored by Matthew Slipper's avatar Matthew Slipper Committed by GitHub

Merge pull request #2034 from mslipper/feat/proxyd-batching

go/proxyd: Add integration tests and batch support
parents a8763df7 90d9c3ac
---
'@eth-optimism/proxyd': minor
---
Add integration tests and batching
name: proxyd unit tests
on:
push:
branches:
- 'master'
- 'develop'
pull_request:
workflow_dispatch:
defaults:
run:
working-directory: ./go/proxyd
jobs:
test:
steps:
- name: Install Go
uses: actions/setup-go@v2
with:
go-version: 1.15.x
- name: Checkout code
uses: actions/checkout@v2
- name: Build
run: make proxyd
- name: Lint
run: make lint
- name: Test
run: make test
......@@ -11,3 +11,11 @@ fmt:
go mod tidy
gofmt -w .
.PHONY: fmt
test:
go test -race -v ./...
.PHONY: test
lint:
go vet ./...
.PHONY: test
\ No newline at end of file
......@@ -62,6 +62,10 @@ var (
Message: "backend returned an invalid response",
HTTPErrorCode: 500,
}
ErrTooManyBatchRequests = &RPCErr{
Code: JSONRPCErrorInternal - 14,
Message: "too many RPC calls in batch request",
}
)
func ErrInvalidRequest(msg string) *RPCErr {
......@@ -631,7 +635,7 @@ func (w *WSProxier) close() {
}
func (w *WSProxier) prepareClientMsg(msg []byte) (*RPCReq, error) {
req, err := ParseRPCReq(bytes.NewReader(msg))
req, err := ParseRPCReq(msg)
if err != nil {
return nil, err
}
......
......@@ -62,6 +62,7 @@ func (c *redisCache) Get(ctx context.Context, key string) (string, error) {
if err == redis.Nil {
return "", nil
} else if err != nil {
RecordRedisError("CacheGet")
return "", err
}
return val, nil
......@@ -69,6 +70,9 @@ func (c *redisCache) Get(ctx context.Context, key string) (string, error) {
func (c *redisCache) Put(ctx context.Context, key string, value string) error {
err := c.rdb.Set(ctx, key, value, 0).Err()
if err != nil {
RecordRedisError("CacheSet")
}
return err
}
......@@ -105,6 +109,7 @@ func (c *rpcCache) GetRPC(ctx context.Context, req *RPCReq) (*RPCRes, error) {
return nil, err
}
if !cacheable {
RecordCacheMiss(req.Method)
return nil, nil
}
......@@ -114,6 +119,7 @@ func (c *rpcCache) GetRPC(ctx context.Context, req *RPCReq) (*RPCRes, error) {
return nil, err
}
if encodedVal == "" {
RecordCacheMiss(req.Method)
return nil, nil
}
val, err := snappy.Decode(nil, []byte(encodedVal))
......@@ -121,6 +127,7 @@ func (c *rpcCache) GetRPC(ctx context.Context, req *RPCReq) (*RPCRes, error) {
return nil, err
}
RecordCacheHit(req.Method)
res := new(RPCRes)
err = json.Unmarshal(val, res)
if err != nil {
......
......@@ -2,6 +2,8 @@ package main
import (
"os"
"os/signal"
"syscall"
"github.com/BurntSushi/toml"
"github.com/ethereum-optimism/optimism/go/proxyd"
......@@ -35,7 +37,14 @@ func main() {
log.Crit("error reading config file", "err", err)
}
if err := proxyd.Start(config); err != nil {
shutdown, err := proxyd.Start(config)
if err != nil {
log.Crit("error starting proxyd", "err", err)
}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
recvSig := <-sig
log.Info("caught signal, shutting down", "signal", recvSig)
shutdown()
}
......@@ -61,11 +61,11 @@ type MethodMappingsConfig map[string]string
type Config struct {
WSBackendGroup string `toml:"ws_backend_group"`
Server *ServerConfig `toml:"server"`
Cache *CacheConfig `toml:"cache"`
Redis *RedisConfig `toml:"redis"`
Metrics *MetricsConfig `toml:"metrics"`
BackendOptions *BackendOptions `toml:"backend"`
Server ServerConfig `toml:"server"`
Cache CacheConfig `toml:"cache"`
Redis RedisConfig `toml:"redis"`
Metrics MetricsConfig `toml:"metrics"`
BackendOptions BackendOptions `toml:"backend"`
Backends BackendsConfig `toml:"backends"`
Authentication map[string]string `toml:"authentication"`
BackendGroups BackendGroupsConfig `toml:"backend_groups"`
......
......@@ -4,13 +4,18 @@ go 1.16
require (
github.com/BurntSushi/toml v0.4.1
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/alicebob/miniredis v2.5.0+incompatible
github.com/ethereum/go-ethereum v1.10.11
github.com/go-redis/redis/v8 v8.11.4
github.com/golang/snappy v0.0.4
github.com/gomodule/redigo v1.8.8 // indirect
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
github.com/prometheus/client_golang v1.11.0
github.com/rs/cors v1.8.0
github.com/stretchr/testify v1.7.0
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
......@@ -48,6 +48,10 @@ github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis v2.5.0+incompatible h1:yBHoLpsyjupjz3NL3MhKMVkR41j82Yjf3KFv7ApYzUI=
github.com/alicebob/miniredis v2.5.0+incompatible/go.mod h1:8HZjEj4yU0dwhYHky+DxYx+6BMjkBbe5ONFIF1MXffk=
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
github.com/apache/arrow/go/arrow v0.0.0-20191024131854-af6fa24be0db/go.mod h1:VTxUBvSJ3s3eHAg65PNgrsn5BtqCRPdmyXh6rAfdxN0=
......@@ -185,6 +189,8 @@ github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golangci/lint-1 v0.0.0-20181222135242-d2cdd8c08219/go.mod h1:/X8TswGSh1pIozq4ZwCfxS0WA5JGXguxk94ar/4c87Y=
github.com/gomodule/redigo v1.8.8 h1:f6cXq6RRfiyrOJEV7p3JhLDlmawGBVBBP1MggY8Mo4E=
github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
......@@ -427,6 +433,8 @@ github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+
github.com/willf/bitset v1.1.3/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4=
github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 h1:k/gmLsJDWwWqbLCur2yWnJzwQEKRcAHXo6seXGuSwWw=
github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
......@@ -520,6 +528,7 @@ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
......@@ -679,8 +688,9 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
......
package integration_tests
import (
"bytes"
"fmt"
"github.com/alicebob/miniredis"
"github.com/ethereum-optimism/optimism/go/proxyd"
"github.com/stretchr/testify/require"
"os"
"testing"
"time"
)
func TestCaching(t *testing.T) {
redis, err := miniredis.Run()
require.NoError(t, err)
defer redis.Close()
backend := NewMockBackend(RPCResponseHandler(map[string]string{
"eth_chainId": "0x420",
"net_version": "0x1234",
"eth_blockNumber": "0x64",
"eth_getBlockByNumber": "dummy_block",
}))
defer backend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
require.NoError(t, os.Setenv("REDIS_URL", fmt.Sprintf("redis://127.0.0.1:%s", redis.Port())))
config := ReadConfig("caching")
client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
// allow time for the block number fetcher to fire
time.Sleep(1500 * time.Millisecond)
tests := []struct {
method string
params []interface{}
response string
}{
{
"eth_chainId",
nil,
"{\"jsonrpc\": \"2.0\", \"result\": \"0x420\", \"id\": 999}",
},
{
"net_version",
nil,
"{\"jsonrpc\": \"2.0\", \"result\": \"0x1234\", \"id\": 999}",
},
{
"eth_getBlockByNumber",
[]interface{}{
"0x1",
true,
},
"{\"jsonrpc\": \"2.0\", \"result\": \"dummy_block\", \"id\": 999}",
},
}
for _, tt := range tests {
t.Run(tt.method, func(t *testing.T) {
_, _, err := client.SendRPC(tt.method, tt.params)
require.NoError(t, err)
res, _, err := client.SendRPC(tt.method, tt.params)
require.NoError(t, err)
RequireEqualJSON(t, []byte(tt.response), res)
var count int
for _, req := range backend.Requests() {
if bytes.Contains(req.Body, []byte(tt.method)) {
count++
}
}
require.Equal(t, 1, count)
backend.Reset()
})
}
}
package integration_tests
import (
"fmt"
"github.com/ethereum-optimism/optimism/go/proxyd"
"github.com/stretchr/testify/require"
"net/http"
"os"
"sync/atomic"
"testing"
"time"
)
const (
goodResponse = `{"jsonrpc": "2.0", "result": "hello", "id": 999}`
noBackendsResponse = `{"error":{"code":-32011,"message":"no backends available for method"},"id":999,"jsonrpc":"2.0"}`
)
func TestFailover(t *testing.T) {
goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
defer goodBackend.Close()
badBackend := NewMockBackend(nil)
defer badBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
require.NoError(t, os.Setenv("BAD_BACKEND_RPC_URL", badBackend.URL()))
config := ReadConfig("failover")
client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
tests := []struct {
name string
handler http.Handler
}{
{
"backend responds 200 with non-JSON response",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
w.Write([]byte("this data is not JSON!"))
}),
},
{
"backend responds with no body",
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}),
},
}
codes := []int{
300,
301,
302,
401,
403,
429,
500,
503,
}
for _, code := range codes {
tests = append(tests, struct {
name string
handler http.Handler
}{
fmt.Sprintf("backend %d", code),
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(code)
}),
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
badBackend.SetHandler(tt.handler)
res, statusCode, err := client.SendRPC("eth_chainId", nil)
require.NoError(t, err)
require.Equal(t, 200, statusCode)
RequireEqualJSON(t, []byte(goodResponse), res)
require.Equal(t, 1, len(badBackend.Requests()))
require.Equal(t, 1, len(goodBackend.Requests()))
badBackend.Reset()
goodBackend.Reset()
})
}
t.Run("backend times out and falls back to another", func(t *testing.T) {
badBackend.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2 * time.Second)
w.Write([]byte("{}"))
}))
res, statusCode, err := client.SendRPC("eth_chainId", nil)
require.NoError(t, err)
require.Equal(t, 200, statusCode)
RequireEqualJSON(t, []byte(goodResponse), res)
require.Equal(t, 1, len(badBackend.Requests()))
require.Equal(t, 1, len(goodBackend.Requests()))
goodBackend.Reset()
badBackend.Reset()
})
t.Run("works with a batch request", func(t *testing.T) {
badBackend.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
}))
res, statusCode, err := client.SendBatchRPC(
NewRPCReq("1", "eth_chainId", nil),
NewRPCReq("1", "eth_chainId", nil),
)
require.NoError(t, err)
require.Equal(t, 200, statusCode)
RequireEqualJSON(t, []byte(asArray(goodResponse, goodResponse)), res)
require.Equal(t, 2, len(badBackend.Requests()))
require.Equal(t, 2, len(goodBackend.Requests()))
})
}
func TestRetries(t *testing.T) {
backend := NewMockBackend(SingleResponseHandler(200, goodResponse))
defer backend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", backend.URL()))
config := ReadConfig("retries")
client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
attempts := int32(0)
backend.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
incremented := atomic.AddInt32(&attempts, 1)
if incremented != 2 {
w.WriteHeader(500)
return
}
w.Write([]byte(goodResponse))
}))
// test case where request eventually succeeds
res, statusCode, err := client.SendRPC("eth_chainId", nil)
require.NoError(t, err)
require.Equal(t, 200, statusCode)
RequireEqualJSON(t, []byte(goodResponse), res)
require.Equal(t, 2, len(backend.Requests()))
// test case where it does not
backend.Reset()
attempts = -10
res, statusCode, err = client.SendRPC("eth_chainId", nil)
require.NoError(t, err)
require.Equal(t, 503, statusCode)
RequireEqualJSON(t, []byte(noBackendsResponse), res)
require.Equal(t, 4, len(backend.Requests()))
}
func TestOutOfServiceInterval(t *testing.T) {
goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
defer goodBackend.Close()
badBackend := NewMockBackend(nil)
defer badBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
require.NoError(t, os.Setenv("BAD_BACKEND_RPC_URL", badBackend.URL()))
config := ReadConfig("out_of_service_interval")
client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(goodResponse))
})
badBackend.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(503)
}))
goodBackend.SetHandler(okHandler)
res, statusCode, err := client.SendRPC("eth_chainId", nil)
require.NoError(t, err)
require.Equal(t, 200, statusCode)
RequireEqualJSON(t, []byte(goodResponse), res)
require.Equal(t, 2, len(badBackend.Requests()))
require.Equal(t, 1, len(goodBackend.Requests()))
res, statusCode, err = client.SendRPC("eth_chainId", nil)
require.NoError(t, err)
require.Equal(t, 200, statusCode)
RequireEqualJSON(t, []byte(goodResponse), res)
require.Equal(t, 2, len(badBackend.Requests()))
require.Equal(t, 2, len(goodBackend.Requests()))
res, statusCode, err = client.SendBatchRPC(
NewRPCReq("1", "eth_chainId", nil),
NewRPCReq("1", "eth_chainId", nil),
)
require.Equal(t, 2, len(badBackend.Requests()))
require.Equal(t, 4, len(goodBackend.Requests()))
time.Sleep(time.Second)
badBackend.SetHandler(okHandler)
res, statusCode, err = client.SendRPC("eth_chainId", nil)
require.NoError(t, err)
require.Equal(t, 200, statusCode)
RequireEqualJSON(t, []byte(goodResponse), res)
require.Equal(t, 3, len(badBackend.Requests()))
require.Equal(t, 4, len(goodBackend.Requests()))
}
package integration_tests
import (
"bytes"
"context"
"encoding/json"
"github.com/ethereum-optimism/optimism/go/proxyd"
"io/ioutil"
"net/http"
"net/http/httptest"
"sync"
)
type RecordedRequest struct {
Method string
Headers http.Header
Body []byte
}
type MockBackend struct {
handler http.Handler
server *httptest.Server
mtx sync.RWMutex
requests []*RecordedRequest
}
func SingleResponseHandler(code int, response string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(code)
w.Write([]byte(response))
}
}
func RPCResponseHandler(rpcResponses map[string]string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
panic(err)
}
req, err := proxyd.ParseRPCReq(body)
if err != nil {
panic(err)
}
res := rpcResponses[req.Method]
if res == "" {
w.WriteHeader(400)
return
}
out := &proxyd.RPCRes{
JSONRPC: proxyd.JSONRPCVersion,
Result: res,
ID: req.ID,
}
enc := json.NewEncoder(w)
if err := enc.Encode(out); err != nil {
panic(err)
}
}
}
func NewMockBackend(handler http.Handler) *MockBackend {
mb := &MockBackend{
handler: handler,
}
mb.server = httptest.NewServer(http.HandlerFunc(mb.wrappedHandler))
return mb
}
func (m *MockBackend) URL() string {
return m.server.URL
}
func (m *MockBackend) Close() {
m.server.Close()
}
func (m *MockBackend) SetHandler(handler http.Handler) {
m.mtx.Lock()
m.handler = handler
m.mtx.Unlock()
}
func (m *MockBackend) Reset() {
m.mtx.Lock()
m.requests = nil
m.mtx.Unlock()
}
func (m *MockBackend) Requests() []*RecordedRequest {
m.mtx.RLock()
defer m.mtx.RUnlock()
out := make([]*RecordedRequest, len(m.requests))
for i := 0; i < len(m.requests); i++ {
out[i] = m.requests[i]
}
return out
}
func (m *MockBackend) wrappedHandler(w http.ResponseWriter, r *http.Request) {
m.mtx.Lock()
body, err := ioutil.ReadAll(r.Body)
if err != nil {
panic(err)
}
clone := r.Clone(context.Background())
clone.Body = ioutil.NopCloser(bytes.NewReader(body))
m.requests = append(m.requests, &RecordedRequest{
Method: r.Method,
Headers: r.Header.Clone(),
Body: body,
})
m.handler.ServeHTTP(w, clone)
m.mtx.Unlock()
}
package integration_tests
import (
"github.com/ethereum-optimism/optimism/go/proxyd"
"github.com/stretchr/testify/require"
"os"
"testing"
)
type resWithCode struct {
code int
res []byte
}
func TestMaxRPSLimit(t *testing.T) {
goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
defer goodBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
config := ReadConfig("rate_limit")
client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
resCh := make(chan *resWithCode)
for i := 0; i < 3; i++ {
go func() {
res, code, err := client.SendRPC("eth_chainId", nil)
require.NoError(t, err)
resCh <- &resWithCode{
code: code,
res: res,
}
}()
}
codes := make(map[int]int)
var limitedRes []byte
for i := 0; i < 3; i++ {
res := <-resCh
code := res.code
if codes[code] == 0 {
codes[code] = 1
} else {
codes[code] += 1
}
// 503 because there's only one backend available
if code == 503 {
limitedRes = res.res
}
}
require.Equal(t, 2, codes[200])
require.Equal(t, 1, codes[503])
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
}
[server]
rpc_port = 8545
[backend]
response_timeout_seconds = 1
[redis]
url = "$REDIS_URL"
[cache]
enabled = true
block_sync_rpc_url = "$GOOD_BACKEND_RPC_URL"
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
[backend_groups]
[backend_groups.main]
backends = ["good"]
[rpc_method_mappings]
eth_chainId = "main"
net_version = "main"
eth_getBlockByNumber = "main"
[server]
rpc_port = 8545
[backend]
response_timeout_seconds = 1
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
[backends.bad]
rpc_url = "$BAD_BACKEND_RPC_URL"
ws_url = "$BAD_BACKEND_RPC_URL"
[backend_groups]
[backend_groups.main]
backends = ["bad", "good"]
[rpc_method_mappings]
eth_chainId = "main"
\ No newline at end of file
[server]
rpc_port = 8545
[backend]
response_timeout_seconds = 1
max_retries = 1
out_of_service_seconds = 1
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
[backends.bad]
rpc_url = "$BAD_BACKEND_RPC_URL"
ws_url = "$BAD_BACKEND_RPC_URL"
[backend_groups]
[backend_groups.main]
backends = ["bad", "good"]
[rpc_method_mappings]
eth_chainId = "main"
\ No newline at end of file
[server]
rpc_port = 8545
[backend]
response_timeout_seconds = 1
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
max_rps = 2
[backend_groups]
[backend_groups.main]
backends = ["good"]
[rpc_method_mappings]
eth_chainId = "main"
\ No newline at end of file
[server]
rpc_port = 8545
[backend]
response_timeout_seconds = 1
max_retries = 3
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
[backend_groups]
[backend_groups.main]
backends = ["good"]
[rpc_method_mappings]
eth_chainId = "main"
\ No newline at end of file
[server]
rpc_port = 8545
[backend]
response_timeout_seconds = 1
[backends]
[backends.good]
rpc_url = "$GOOD_BACKEND_RPC_URL"
ws_url = "$GOOD_BACKEND_RPC_URL"
[backend_groups]
[backend_groups.main]
backends = ["good"]
[rpc_method_mappings]
eth_chainId = "main"
\ No newline at end of file
package integration_tests
import (
"bytes"
"encoding/json"
"fmt"
"github.com/BurntSushi/toml"
"github.com/ethereum-optimism/optimism/go/proxyd"
"github.com/stretchr/testify/require"
"io/ioutil"
"net/http"
"testing"
)
type ProxydClient struct {
url string
}
func NewProxydClient(url string) *ProxydClient {
return &ProxydClient{url: url}
}
func (p *ProxydClient) SendRPC(method string, params []interface{}) ([]byte, int, error) {
rpcReq := NewRPCReq("999", method, params)
body, err := json.Marshal(rpcReq)
if err != nil {
panic(err)
}
return p.SendRequest(body)
}
func (p *ProxydClient) SendBatchRPC(reqs ...*proxyd.RPCReq) ([]byte, int, error) {
body, err := json.Marshal(reqs)
if err != nil {
panic(err)
}
return p.SendRequest(body)
}
func (p *ProxydClient) SendRequest(body []byte) ([]byte, int, error) {
res, err := http.Post(p.url, "application/json", bytes.NewReader(body))
if err != nil {
return nil, -1, err
}
defer res.Body.Close()
code := res.StatusCode
resBody, err := ioutil.ReadAll(res.Body)
if err != nil {
panic(err)
}
return resBody, code, nil
}
func RequireEqualJSON(t *testing.T, expected []byte, actual []byte) {
expJSON := canonicalizeJSON(t, expected)
actJSON := canonicalizeJSON(t, actual)
require.Equal(t, string(expJSON), string(actJSON))
}
func canonicalizeJSON(t *testing.T, in []byte) []byte {
var any interface{}
if in[0] == '[' {
any = make([]interface{}, 0)
} else {
any = make(map[string]interface{})
}
err := json.Unmarshal(in, &any)
require.NoError(t, err)
out, err := json.Marshal(any)
require.NoError(t, err)
return out
}
func ReadConfig(name string) *proxyd.Config {
config := new(proxyd.Config)
_, err := toml.DecodeFile(fmt.Sprintf("testdata/%s.toml", name), config)
if err != nil {
panic(err)
}
return config
}
func NewRPCReq(id string, method string, params []interface{}) *proxyd.RPCReq {
jsonParams, err := json.Marshal(params)
if err != nil {
panic(err)
}
return &proxyd.RPCReq{
JSONRPC: proxyd.JSONRPCVersion,
Method: method,
Params: jsonParams,
ID: []byte(id),
}
}
package integration_tests
import (
"github.com/ethereum-optimism/optimism/go/proxyd"
"github.com/stretchr/testify/require"
"os"
"strings"
"testing"
)
const (
notWhitelistedResponse = `{"jsonrpc":"2.0","error":{"code":-32001,"message":"rpc method is not whitelisted"},"id":999}`
parseErrResponse = `{"jsonrpc":"2.0","error":{"code":-32700,"message":"parse error"},"id":null}`
invalidJSONRPCVersionResponse = `{"error":{"code":-32601,"message":"invalid JSON-RPC version"},"id":null,"jsonrpc":"2.0"}`
invalidIDResponse = `{"error":{"code":-32601,"message":"invalid ID"},"id":null,"jsonrpc":"2.0"}`
invalidMethodResponse = `{"error":{"code":-32601,"message":"no method specified"},"id":null,"jsonrpc":"2.0"}`
invalidBatchLenResponse = `{"error":{"code":-32601,"message":"must specify at least one batch call"},"id":null,"jsonrpc":"2.0"}`
)
func TestSingleRPCValidation(t *testing.T) {
goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
defer goodBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
config := ReadConfig("whitelist")
client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
tests := []struct {
name string
body string
res string
code int
}{
{
"body not JSON",
"this ain't an RPC call",
parseErrResponse,
400,
},
{
"body not RPC",
"{\"not\": \"rpc\"}",
invalidJSONRPCVersionResponse,
400,
},
{
"body missing RPC ID",
"{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23]}",
invalidIDResponse,
400,
},
{
"body has array ID",
"{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": []}",
invalidIDResponse,
400,
},
{
"body has object ID",
"{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": {}}",
invalidIDResponse,
400,
},
{
"bad method",
"{\"jsonrpc\": \"2.0\", \"method\": 7, \"params\": [42, 23], \"id\": 1}",
parseErrResponse,
400,
},
{
"bad JSON-RPC",
"{\"jsonrpc\": \"1.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": 1}",
invalidJSONRPCVersionResponse,
400,
},
{
"omitted method",
"{\"jsonrpc\": \"2.0\", \"params\": [42, 23], \"id\": 1}",
invalidMethodResponse,
400,
},
{
"not whitelisted method",
"{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": 999}",
notWhitelistedResponse,
403,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
res, code, err := client.SendRequest([]byte(tt.body))
require.NoError(t, err)
RequireEqualJSON(t, []byte(tt.res), res)
require.Equal(t, tt.code, code)
require.Equal(t, 0, len(goodBackend.Requests()))
})
}
}
func TestBatchRPCValidation(t *testing.T) {
goodBackend := NewMockBackend(SingleResponseHandler(200, goodResponse))
defer goodBackend.Close()
require.NoError(t, os.Setenv("GOOD_BACKEND_RPC_URL", goodBackend.URL()))
config := ReadConfig("whitelist")
client := NewProxydClient("http://127.0.0.1:8545")
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()
tests := []struct {
name string
body string
res string
code int
reqCount int
}{
{
"empty batch",
"[]",
invalidBatchLenResponse,
400,
0,
},
{
"bad json",
"[{,]",
parseErrResponse,
400,
0,
},
{
"not object in batch",
"[123]",
asArray(parseErrResponse),
200,
0,
},
{
"body not RPC",
"[{\"not\": \"rpc\"}]",
asArray(invalidJSONRPCVersionResponse),
200,
0,
},
{
"body missing RPC ID",
"[{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23]}]",
asArray(invalidIDResponse),
200,
0,
},
{
"body has array ID",
"[{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": []}]",
asArray(invalidIDResponse),
200,
0,
},
{
"body has object ID",
"[{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": {}}]",
asArray(invalidIDResponse),
200,
0,
},
// this happens because we can't deserialize the method into a non
// string value, and it blows up the parsing for the whole request.
{
"bad method",
"[{\"error\":{\"code\":-32600,\"message\":\"invalid request\"},\"id\":null,\"jsonrpc\":\"2.0\"}]",
asArray(invalidMethodResponse),
200,
0,
},
{
"bad JSON-RPC",
"[{\"jsonrpc\": \"1.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": 1}]",
asArray(invalidJSONRPCVersionResponse),
200,
0,
},
{
"omitted method",
"[{\"jsonrpc\": \"2.0\", \"params\": [42, 23], \"id\": 1}]",
asArray(invalidMethodResponse),
200,
0,
},
{
"not whitelisted method",
"[{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": 999}]",
asArray(notWhitelistedResponse),
200,
0,
},
{
"mixed",
asArray(
"{\"jsonrpc\": \"2.0\", \"method\": \"subtract\", \"params\": [42, 23], \"id\": 999}",
"{\"jsonrpc\": \"2.0\", \"method\": \"eth_chainId\", \"params\": [], \"id\": 123}",
"123",
),
asArray(
notWhitelistedResponse,
goodResponse,
parseErrResponse,
),
200,
1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
res, code, err := client.SendRequest([]byte(tt.body))
require.NoError(t, err)
RequireEqualJSON(t, []byte(tt.res), res)
require.Equal(t, tt.code, code)
require.Equal(t, tt.reqCount, len(goodBackend.Requests()))
})
}
}
func asArray(in ...string) string {
return "[" + strings.Join(in, ",") + "]"
}
......@@ -149,7 +149,6 @@ var (
Buckets: PayloadSizeBuckets,
}, []string{
"auth",
"method_name",
})
responsePayloadSizesGauge = promauto.NewHistogramVec(prometheus.HistogramOpts{
......@@ -161,6 +160,22 @@ var (
"auth",
})
cacheHitsTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "cache_hits_total",
Help: "Number of cache hits.",
}, []string{
"method",
})
cacheMissesTotal = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: MetricsNamespace,
Name: "cache_misses_total",
Help: "Number of cache misses.",
}, []string{
"method",
})
rpcSpecialErrors = []string{
"nonce too low",
"gas price too high",
......@@ -208,10 +223,18 @@ func MaybeRecordSpecialRPCError(ctx context.Context, backendName, method string,
}
}
func RecordRequestPayloadSize(ctx context.Context, method string, payloadSize int) {
requestPayloadSizesGauge.WithLabelValues(GetAuthCtx(ctx), method).Observe(float64(payloadSize))
func RecordRequestPayloadSize(ctx context.Context, payloadSize int) {
requestPayloadSizesGauge.WithLabelValues(GetAuthCtx(ctx)).Observe(float64(payloadSize))
}
func RecordResponsePayloadSize(ctx context.Context, payloadSize int) {
responsePayloadSizesGauge.WithLabelValues(GetAuthCtx(ctx)).Observe(float64(payloadSize))
}
func RecordCacheHit(method string) {
cacheHitsTotal.WithLabelValues(method).Inc()
}
func RecordCacheMiss(method string) {
cacheMissesTotal.WithLabelValues(method).Inc()
}
......@@ -7,40 +7,47 @@ import (
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func Start(config *Config) error {
func Start(config *Config) (func(), error) {
if len(config.Backends) == 0 {
return errors.New("must define at least one backend")
return nil, errors.New("must define at least one backend")
}
if len(config.BackendGroups) == 0 {
return errors.New("must define at least one backend group")
return nil, errors.New("must define at least one backend group")
}
if len(config.RPCMethodMappings) == 0 {
return errors.New("must define at least one RPC method mapping")
return nil, errors.New("must define at least one RPC method mapping")
}
for authKey := range config.Authentication {
if authKey == "none" {
return errors.New("cannot use none as an auth key")
return nil, errors.New("cannot use none as an auth key")
}
}
var redisURL string
if config.Redis.URL != "" {
rURL, err := ReadFromEnvOrConfig(config.Redis.URL)
if err != nil {
return nil, err
}
redisURL = rURL
}
var lim RateLimiter
var err error
if config.Redis == nil {
if redisURL == "" {
log.Warn("redis is not configured, using local rate limiter")
lim = NewLocalRateLimiter()
} else {
lim, err = NewRedisRateLimiter(config.Redis.URL)
lim, err = NewRedisRateLimiter(redisURL)
if err != nil {
return err
return nil, err
}
}
......@@ -51,17 +58,17 @@ func Start(config *Config) error {
rpcURL, err := ReadFromEnvOrConfig(cfg.RPCURL)
if err != nil {
return err
return nil, err
}
wsURL, err := ReadFromEnvOrConfig(cfg.WSURL)
if err != nil {
return err
return nil, err
}
if rpcURL == "" {
return fmt.Errorf("must define an RPC URL for backend %s", name)
return nil, fmt.Errorf("must define an RPC URL for backend %s", name)
}
if wsURL == "" {
return fmt.Errorf("must define a WS URL for backend %s", name)
return nil, fmt.Errorf("must define a WS URL for backend %s", name)
}
if config.BackendOptions.ResponseTimeoutSeconds != 0 {
......@@ -86,13 +93,13 @@ func Start(config *Config) error {
if cfg.Password != "" {
passwordVal, err := ReadFromEnvOrConfig(cfg.Password)
if err != nil {
return err
return nil, err
}
opts = append(opts, WithBasicAuth(cfg.Username, passwordVal))
}
tlsConfig, err := configureBackendTLS(cfg)
if err != nil {
return err
return nil, err
}
if tlsConfig != nil {
log.Info("using custom TLS config for backend", "name", name)
......@@ -113,7 +120,7 @@ func Start(config *Config) error {
backends := make([]*Backend, 0)
for _, bName := range bg.Backends {
if backendsByName[bName] == nil {
return fmt.Errorf("backend %s is not defined", bName)
return nil, fmt.Errorf("backend %s is not defined", bName)
}
backends = append(backends, backendsByName[bName])
}
......@@ -128,17 +135,17 @@ func Start(config *Config) error {
if config.WSBackendGroup != "" {
wsBackendGroup = backendGroups[config.WSBackendGroup]
if wsBackendGroup == nil {
return fmt.Errorf("ws backend group %s does not exist", config.WSBackendGroup)
return nil, fmt.Errorf("ws backend group %s does not exist", config.WSBackendGroup)
}
}
if wsBackendGroup == nil && config.Server.WSPort != 0 {
return fmt.Errorf("a ws port was defined, but no ws group was defined")
return nil, fmt.Errorf("a ws port was defined, but no ws group was defined")
}
for _, bg := range config.RPCMethodMappings {
if backendGroups[bg] == nil {
return fmt.Errorf("undefined backend group %s", bg)
return nil, fmt.Errorf("undefined backend group %s", bg)
}
}
......@@ -149,34 +156,39 @@ func Start(config *Config) error {
for secret, alias := range config.Authentication {
resolvedSecret, err := ReadFromEnvOrConfig(secret)
if err != nil {
return err
return nil, err
}
resolvedAuth[resolvedSecret] = alias
}
}
var rpcCache RPCCache
if config.Cache != nil && config.Cache.Enabled {
var latestHead *LatestBlockHead
if config.Cache.Enabled {
var getLatestBlockNumFn GetLatestBlockNumFn
if config.Cache.BlockSyncRPCURL == "" {
return nil, fmt.Errorf("block sync node required for caching")
}
blockSyncRPCURL, err := ReadFromEnvOrConfig(config.Cache.BlockSyncRPCURL)
if err != nil {
return nil, err
}
var cache Cache
if config.Redis != nil {
if cache, err = newRedisCache(config.Redis.URL); err != nil {
return err
if redisURL != "" {
if cache, err = newRedisCache(redisURL); err != nil {
return nil, err
}
} else {
log.Warn("redis is not configured, using in-memory cache")
cache = newMemoryCache()
}
var getLatestBlockNumFn GetLatestBlockNumFn
if config.Cache.BlockSyncRPCURL == "" {
return fmt.Errorf("block sync node required for caching")
}
latestHead, err := newLatestBlockHead(config.Cache.BlockSyncRPCURL)
latestHead, err = newLatestBlockHead(blockSyncRPCURL)
if err != nil {
return err
return nil, err
}
latestHead.Start()
defer latestHead.Stop()
getLatestBlockNumFn = func(ctx context.Context) (uint64, error) {
return latestHead.GetBlockNum(), nil
......@@ -194,12 +206,17 @@ func Start(config *Config) error {
rpcCache,
)
if config.Metrics != nil && config.Metrics.Enabled {
if config.Metrics.Enabled {
addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)
log.Info("starting metrics server", "addr", addr)
go http.ListenAndServe(addr, promhttp.Handler())
}
// To allow integration tests to cleanly come up, wait
// 10ms to give the below goroutines enough time to
// encounter an error creating their servers
errTimer := time.NewTimer(10 * time.Millisecond)
if config.Server.RPCPort != 0 {
go func() {
if err := srv.RPCListenAndServe(config.Server.RPCHost, config.Server.RPCPort); err != nil {
......@@ -224,15 +241,20 @@ func Start(config *Config) error {
}()
}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
recvSig := <-sig
log.Info("caught signal, shutting down", "signal", recvSig)
<-errTimer.C
log.Info("started proxyd")
return func() {
log.Info("shutting down proxyd")
if latestHead != nil {
latestHead.Stop()
}
srv.Shutdown()
if err := lim.FlushBackendWSConns(backendNames); err != nil {
log.Error("error flushing backend ws conns", "err", err)
}
return nil
log.Info("goodbye")
}, nil
}
func secondsToDuration(seconds int) time.Duration {
......
......@@ -46,30 +46,22 @@ func IsValidID(id json.RawMessage) bool {
return len(id) > 0 && id[0] != '{' && id[0] != '['
}
func ParseRPCReq(r io.Reader) (*RPCReq, error) {
body, err := ioutil.ReadAll(r)
if err != nil {
return nil, wrapErr(err, "error reading request body")
}
func ParseRPCReq(body []byte) (*RPCReq, error) {
req := new(RPCReq)
if err := json.Unmarshal(body, req); err != nil {
return nil, ErrParseErr
}
if req.JSONRPC != JSONRPCVersion {
return nil, ErrInvalidRequest("invalid JSON-RPC version")
}
if req.Method == "" {
return nil, ErrInvalidRequest("no method specified")
}
return req, nil
}
if !IsValidID(req.ID) {
return nil, ErrInvalidRequest("invalid ID")
func ParseBatchRPCReq(body []byte) ([]json.RawMessage, error) {
batch := make([]json.RawMessage, 0)
if err := json.Unmarshal(body, &batch); err != nil {
return nil, err
}
return req, nil
return batch, nil
}
func ParseRPCRes(r io.Reader) (*RPCRes, error) {
......@@ -86,6 +78,22 @@ func ParseRPCRes(r io.Reader) (*RPCRes, error) {
return res, nil
}
func ValidateRPCReq(req *RPCReq) error {
if req.JSONRPC != JSONRPCVersion {
return ErrInvalidRequest("invalid JSON-RPC version")
}
if req.Method == "" {
return ErrInvalidRequest("no method specified")
}
if !IsValidID(req.ID) {
return ErrInvalidRequest("invalid ID")
}
return nil
}
func NewRPCErrorRes(id json.RawMessage, err error) *RPCRes {
var rpcErr *RPCErr
if rr, ok := err.(*RPCErr); ok {
......@@ -103,3 +111,14 @@ func NewRPCErrorRes(id json.RawMessage, err error) *RPCRes {
ID: id,
}
}
func IsBatch(raw []byte) bool {
for _, c := range raw {
// skip insignificant whitespace (http://www.ietf.org/rfc/rfc4627.txt)
if c == 0x20 || c == 0x09 || c == 0x0a || c == 0x0d {
continue
}
return c == '['
}
return false
}
\ No newline at end of file
......@@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net/http"
"strconv"
"strings"
......@@ -22,6 +24,7 @@ const (
ContextKeyAuth = "authorization"
ContextKeyReqID = "req_id"
ContextKeyXForwardedFor = "x_forwarded_for"
MaxBatchRPCCalls = 100
)
type Server struct {
......@@ -49,6 +52,11 @@ func NewServer(
if cache == nil {
cache = &NoopRPCCache{}
}
if maxBodySize == 0 {
maxBodySize = math.MaxInt64
}
return &Server{
backendGroups: backendGroups,
wsBackendGroup: wsBackendGroup,
......@@ -122,15 +130,66 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
"user_agent", r.Header.Get("user-agent"),
)
bodyReader := &recordLenReader{Reader: io.LimitReader(r.Body, s.maxBodySize)}
req, err := ParseRPCReq(bodyReader)
body, err := ioutil.ReadAll(io.LimitReader(r.Body, s.maxBodySize))
if err != nil {
log.Error("error reading request body", "err", err)
writeRPCError(ctx, w, nil, ErrInternal)
return
}
RecordRequestPayloadSize(ctx, len(body))
if IsBatch(body) {
reqs, err := ParseBatchRPCReq(body)
if err != nil {
log.Info("rejected request with bad rpc request", "source", "rpc", "err", err)
log.Error("error parsing batch RPC request", "err", err)
RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
writeRPCError(ctx, w, nil, ErrParseErr)
return
}
if len(reqs) > MaxBatchRPCCalls {
RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrTooManyBatchRequests)
writeRPCError(ctx, w, nil, ErrTooManyBatchRequests)
return
}
if len(reqs) == 0 {
writeRPCError(ctx, w, nil, ErrInvalidRequest("must specify at least one batch call"))
return
}
batchRes := make([]*RPCRes, len(reqs), len(reqs))
for i := 0; i < len(reqs); i++ {
req, err := ParseRPCReq(reqs[i])
if err != nil {
log.Info("error parsing RPC call", "source", "rpc", "err", err)
batchRes[i] = NewRPCErrorRes(nil, err)
continue
}
batchRes[i] = s.handleSingleRPC(ctx, req)
}
writeBatchRPCRes(ctx, w, batchRes)
return
}
req, err := ParseRPCReq(body)
if err != nil {
log.Info("error parsing RPC call", "source", "rpc", "err", err)
writeRPCError(ctx, w, nil, err)
return
}
RecordRequestPayloadSize(ctx, req.Method, bodyReader.Len)
backendRes := s.handleSingleRPC(ctx, req)
writeRPCRes(ctx, w, backendRes)
}
func (s *Server) handleSingleRPC(ctx context.Context, req *RPCReq) *RPCRes {
if err := ValidateRPCReq(req); err != nil {
RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
return NewRPCErrorRes(nil, err)
}
group := s.rpcMethodMappings[req.Method]
if group == "" {
......@@ -143,16 +202,11 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
"method", req.Method,
)
RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
writeRPCError(ctx, w, req.ID, ErrMethodNotWhitelisted)
return
return NewRPCErrorRes(req.ID, ErrMethodNotWhitelisted)
}
var backendRes *RPCRes
backendRes, err = s.cache.GetRPC(ctx, req)
if err == nil && backendRes != nil {
writeRPCRes(ctx, w, backendRes)
return
}
backendRes, err := s.cache.GetRPC(ctx, req)
if err != nil {
log.Warn(
"cache lookup error",
......@@ -160,6 +214,9 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
"err", err,
)
}
if backendRes != nil {
return backendRes
}
backendRes, err = s.backendGroups[group].Forward(ctx, req)
if err != nil {
......@@ -169,8 +226,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
"req_id", GetReqID(ctx),
"err", err,
)
writeRPCError(ctx, w, req.ID, err)
return
return NewRPCErrorRes(req.ID, err)
}
if backendRes.Error == nil {
......@@ -183,7 +239,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
}
}
writeRPCRes(ctx, w, backendRes)
return backendRes
}
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
......@@ -282,6 +338,7 @@ func writeRPCRes(ctx context.Context, w http.ResponseWriter, res *RPCRes) {
statusCode = res.Error.HTTPErrorCode
}
w.Header().Set("content-type", "application/json")
w.WriteHeader(statusCode)
ww := &recordLenWriter{Writer: w}
enc := json.NewEncoder(ww)
......@@ -294,6 +351,19 @@ func writeRPCRes(ctx context.Context, w http.ResponseWriter, res *RPCRes) {
RecordResponsePayloadSize(ctx, ww.Len)
}
func writeBatchRPCRes(ctx context.Context, w http.ResponseWriter, res []*RPCRes) {
w.Header().Set("content-type", "application/json")
w.WriteHeader(200)
ww := &recordLenWriter{Writer: w}
enc := json.NewEncoder(ww)
if err := enc.Encode(res); err != nil {
log.Error("error writing batch rpc response", "err", err)
RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
return
}
RecordResponsePayloadSize(ctx, ww.Len)
}
func instrumentedHdlr(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
respTimer := prometheus.NewTimer(httpRequestDurationSumm)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment