Commit 3079a26a authored by Matthew Slipper's avatar Matthew Slipper Committed by GitHub

Merge pull request #2002 from ethereum-optimism/inphi/proxyd-cache

feat(proxyd): Support immutable RPC caching 
parents 6749ab3d 3123faed
---
'@eth-optimism/proxyd': minor
---
cache immutable RPC responses in proxyd
package proxyd
import (
"context"
"encoding/json"
"github.com/go-redis/redis/v8"
"github.com/golang/snappy"
lru "github.com/hashicorp/golang-lru"
)
type Cache interface {
Get(ctx context.Context, key string) (string, error)
Put(ctx context.Context, key string, value string) error
}
// assuming an average RPCRes size of 3 KB
const (
memoryCacheLimit = 4096
numBlockConfirmations = 50
)
type cache struct {
lru *lru.Cache
}
func newMemoryCache() *cache {
rep, _ := lru.New(memoryCacheLimit)
return &cache{rep}
}
func (c *cache) Get(ctx context.Context, key string) (string, error) {
if val, ok := c.lru.Get(key); ok {
return val.(string), nil
}
return "", nil
}
func (c *cache) Put(ctx context.Context, key string, value string) error {
c.lru.Add(key, value)
return nil
}
type redisCache struct {
rdb *redis.Client
}
func newRedisCache(url string) (*redisCache, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
rdb := redis.NewClient(opts)
if err := rdb.Ping(context.Background()).Err(); err != nil {
return nil, wrapErr(err, "error connecting to redis")
}
return &redisCache{rdb}, nil
}
func (c *redisCache) Get(ctx context.Context, key string) (string, error) {
val, err := c.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return "", nil
} else if err != nil {
return "", err
}
return val, nil
}
func (c *redisCache) Put(ctx context.Context, key string, value string) error {
err := c.rdb.Set(ctx, key, value, 0).Err()
return err
}
type GetLatestBlockNumFn func(ctx context.Context) (uint64, error)
type RPCCache interface {
GetRPC(ctx context.Context, req *RPCReq) (*RPCRes, error)
PutRPC(ctx context.Context, req *RPCReq, res *RPCRes) error
}
type rpcCache struct {
cache Cache
getLatestBlockNumFn GetLatestBlockNumFn
handlers map[string]RPCMethodHandler
}
func newRPCCache(cache Cache, getLatestBlockNumFn GetLatestBlockNumFn) RPCCache {
handlers := map[string]RPCMethodHandler{
"eth_chainId": &StaticRPCMethodHandler{"eth_chainId"},
"net_version": &StaticRPCMethodHandler{"net_version"},
"eth_getBlockByNumber": &EthGetBlockByNumberMethod{getLatestBlockNumFn},
"eth_getBlockRange": &EthGetBlockRangeMethod{getLatestBlockNumFn},
}
return &rpcCache{cache: cache, getLatestBlockNumFn: getLatestBlockNumFn, handlers: handlers}
}
func (c *rpcCache) GetRPC(ctx context.Context, req *RPCReq) (*RPCRes, error) {
handler := c.handlers[req.Method]
if handler == nil {
return nil, nil
}
cacheable, err := handler.IsCacheable(req)
if err != nil {
return nil, err
}
if !cacheable {
return nil, nil
}
key := handler.CacheKey(req)
encodedVal, err := c.cache.Get(ctx, key)
if err != nil {
return nil, err
}
if encodedVal == "" {
return nil, nil
}
val, err := snappy.Decode(nil, []byte(encodedVal))
if err != nil {
return nil, err
}
res := new(RPCRes)
err = json.Unmarshal(val, res)
if err != nil {
return nil, err
}
res.ID = req.ID
return res, nil
}
func (c *rpcCache) PutRPC(ctx context.Context, req *RPCReq, res *RPCRes) error {
handler := c.handlers[req.Method]
if handler == nil {
return nil
}
cacheable, err := handler.IsCacheable(req)
if err != nil {
return err
}
if !cacheable {
return nil
}
requiresConfirmations, err := handler.RequiresUnconfirmedBlocks(ctx, req)
if err != nil {
return err
}
if requiresConfirmations {
return nil
}
key := handler.CacheKey(req)
val := mustMarshalJSON(res)
encodedVal := snappy.Encode(nil, val)
return c.cache.Put(ctx, key, string(encodedVal))
}
package proxyd
import (
"context"
"math"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestRPCCacheWhitelist(t *testing.T) {
const blockHead = math.MaxUint64
ctx := context.Background()
fn := func(ctx context.Context) (uint64, error) {
return blockHead, nil
}
cache := newRPCCache(newMemoryCache(), fn)
ID := []byte(strconv.Itoa(1))
rpcs := []struct {
req *RPCReq
res *RPCRes
name string
}{
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_chainId",
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: "0xff",
ID: ID,
},
name: "eth_chainId",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "net_version",
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: "9999",
ID: ID,
},
name: "net_version",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockByNumber",
Params: []byte(`["0x1", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `{"difficulty": "0x1", "number": "0x1"}`,
ID: ID,
},
name: "eth_getBlockByNumber",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockByNumber",
Params: []byte(`["earliest", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `{"difficulty": "0x1", "number": "0x1"}`,
ID: ID,
},
name: "eth_getBlockByNumber earliest",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockRange",
Params: []byte(`["0x1", "0x2", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `[{"number": "0x1"}, {"number": "0x2"}]`,
ID: ID,
},
name: "eth_getBlockRange",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockRange",
Params: []byte(`["earliest", "0x2", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `[{"number": "0x1"}, {"number": "0x2"}]`,
ID: ID,
},
name: "eth_getBlockRange earliest",
},
}
for _, rpc := range rpcs {
t.Run(rpc.name, func(t *testing.T) {
err := cache.PutRPC(ctx, rpc.req, rpc.res)
require.NoError(t, err)
cachedRes, err := cache.GetRPC(ctx, rpc.req)
require.NoError(t, err)
require.Equal(t, rpc.res, cachedRes)
})
}
}
func TestRPCCacheUnsupportedMethod(t *testing.T) {
const blockHead = math.MaxUint64
ctx := context.Background()
fn := func(ctx context.Context) (uint64, error) {
return blockHead, nil
}
cache := newRPCCache(newMemoryCache(), fn)
ID := []byte(strconv.Itoa(1))
req := &RPCReq{
JSONRPC: "2.0",
Method: "eth_blockNumber",
ID: ID,
}
res := &RPCRes{
JSONRPC: "2.0",
Result: `0x1000`,
ID: ID,
}
err := cache.PutRPC(ctx, req, res)
require.NoError(t, err)
cachedRes, err := cache.GetRPC(ctx, req)
require.NoError(t, err)
require.Nil(t, cachedRes)
}
func TestRPCCacheEthGetBlockByNumberForRecentBlocks(t *testing.T) {
ctx := context.Background()
var blockHead uint64 = 2
fn := func(ctx context.Context) (uint64, error) {
return blockHead, nil
}
cache := newRPCCache(newMemoryCache(), fn)
ID := []byte(strconv.Itoa(1))
rpcs := []struct {
req *RPCReq
res *RPCRes
name string
}{
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockByNumber",
Params: []byte(`["0x1", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `{"difficulty": "0x1", "number": "0x1"}`,
ID: ID,
},
name: "recent block num",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockByNumber",
Params: []byte(`["latest", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `{"difficulty": "0x1", "number": "0x1"}`,
ID: ID,
},
name: "latest block",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockByNumber",
Params: []byte(`["pending", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `{"difficulty": "0x1", "number": "0x1"}`,
ID: ID,
},
name: "pending block",
},
}
for _, rpc := range rpcs {
t.Run(rpc.name, func(t *testing.T) {
err := cache.PutRPC(ctx, rpc.req, rpc.res)
require.NoError(t, err)
cachedRes, err := cache.GetRPC(ctx, rpc.req)
require.NoError(t, err)
require.Nil(t, cachedRes)
})
}
}
func TestRPCCacheEthGetBlockByNumberInvalidRequest(t *testing.T) {
ctx := context.Background()
const blockHead = math.MaxUint64
fn := func(ctx context.Context) (uint64, error) {
return blockHead, nil
}
cache := newRPCCache(newMemoryCache(), fn)
ID := []byte(strconv.Itoa(1))
req := &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockByNumber",
Params: []byte(`["0x1"]`), // missing required boolean param
ID: ID,
}
res := &RPCRes{
JSONRPC: "2.0",
Result: `{"difficulty": "0x1", "number": "0x1"}`,
ID: ID,
}
err := cache.PutRPC(ctx, req, res)
require.Error(t, err)
cachedRes, err := cache.GetRPC(ctx, req)
require.Error(t, err)
require.Nil(t, cachedRes)
}
func TestRPCCacheEthGetBlockRangeForRecentBlocks(t *testing.T) {
ctx := context.Background()
var blockHead uint64 = 0x1000
fn := func(ctx context.Context) (uint64, error) {
return blockHead, nil
}
cache := newRPCCache(newMemoryCache(), fn)
ID := []byte(strconv.Itoa(1))
rpcs := []struct {
req *RPCReq
res *RPCRes
name string
}{
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockRange",
Params: []byte(`["0x1", "0x1000", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `[{"number": "0x1"}, {"number": "0x2"}]`,
ID: ID,
},
name: "recent block num",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockRange",
Params: []byte(`["0x1", "latest", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `[{"number": "0x1"}, {"number": "0x2"}]`,
ID: ID,
},
name: "latest block",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockRange",
Params: []byte(`["0x1", "pending", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `[{"number": "0x1"}, {"number": "0x2"}]`,
ID: ID,
},
name: "pending block",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockRange",
Params: []byte(`["latest", "0x1000", false]`),
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `[{"number": "0x1"}, {"number": "0x2"}]`,
ID: ID,
},
name: "latest block 2",
},
}
for _, rpc := range rpcs {
t.Run(rpc.name, func(t *testing.T) {
err := cache.PutRPC(ctx, rpc.req, rpc.res)
require.NoError(t, err)
cachedRes, err := cache.GetRPC(ctx, rpc.req)
require.NoError(t, err)
require.Nil(t, cachedRes)
})
}
}
func TestRPCCacheEthGetBlockRangeInvalidRequest(t *testing.T) {
ctx := context.Background()
const blockHead = math.MaxUint64
fn := func(ctx context.Context) (uint64, error) {
return blockHead, nil
}
cache := newRPCCache(newMemoryCache(), fn)
ID := []byte(strconv.Itoa(1))
rpcs := []struct {
req *RPCReq
res *RPCRes
name string
}{
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockRange",
Params: []byte(`["0x1", "0x2"]`), // missing required boolean param
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `[{"number": "0x1"}, {"number": "0x2"}]`,
ID: ID,
},
name: "missing boolean param",
},
{
req: &RPCReq{
JSONRPC: "2.0",
Method: "eth_getBlockRange",
Params: []byte(`["abc", "0x2", true]`), // invalid block hex
ID: ID,
},
res: &RPCRes{
JSONRPC: "2.0",
Result: `[{"number": "0x1"}, {"number": "0x2"}]`,
ID: ID,
},
name: "invalid block hex",
},
}
for _, rpc := range rpcs {
t.Run(rpc.name, func(t *testing.T) {
err := cache.PutRPC(ctx, rpc.req, rpc.res)
require.Error(t, err)
cachedRes, err := cache.GetRPC(ctx, rpc.req)
require.Error(t, err)
require.Nil(t, cachedRes)
})
}
}
......@@ -14,6 +14,11 @@ type ServerConfig struct {
MaxBodySizeBytes int64 `toml:"max_body_size_bytes"`
}
type CacheConfig struct {
Enabled bool `toml:"enabled"`
BlockSyncRPCURL string `toml:"block_sync_rpc_url"`
}
type RedisConfig struct {
URL string `toml:"url"`
}
......@@ -57,6 +62,7 @@ type MethodMappingsConfig map[string]string
type Config struct {
WSBackendGroup string `toml:"ws_backend_group"`
Server *ServerConfig `toml:"server"`
Cache *CacheConfig `toml:"cache"`
Redis *RedisConfig `toml:"redis"`
Metrics *MetricsConfig `toml:"metrics"`
BackendOptions *BackendOptions `toml:"backend"`
......
......@@ -6,8 +6,11 @@ require (
github.com/BurntSushi/toml v0.4.1
github.com/ethereum/go-ethereum v1.10.11
github.com/go-redis/redis/v8 v8.11.4
github.com/golang/snappy v0.0.4
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2
github.com/hashicorp/golang-lru v0.5.5-0.20210104140557-80c98217689d
github.com/prometheus/client_golang v1.11.0
github.com/rs/cors v1.8.0
github.com/stretchr/testify v1.7.0
)
This diff is collapsed.
package proxyd
import (
"context"
"sync"
"time"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/log"
)
const blockHeadSyncPeriod = 1 * time.Second
type LatestBlockHead struct {
url string
client *ethclient.Client
quit chan struct{}
done chan struct{}
mutex sync.RWMutex
blockNum uint64
}
func newLatestBlockHead(url string) (*LatestBlockHead, error) {
client, err := ethclient.Dial(url)
if err != nil {
return nil, err
}
return &LatestBlockHead{
url: url,
client: client,
quit: make(chan struct{}),
done: make(chan struct{}),
}, nil
}
func (h *LatestBlockHead) Start() {
go func() {
ticker := time.NewTicker(blockHeadSyncPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
blockNum, err := h.getBlockNum()
if err != nil {
log.Error("error retrieving latest block number", "error", err)
continue
}
log.Trace("polling block number", "blockNum", blockNum)
h.mutex.Lock()
h.blockNum = blockNum
h.mutex.Unlock()
case <-h.quit:
close(h.done)
return
}
}
}()
}
func (h *LatestBlockHead) getBlockNum() (uint64, error) {
const maxRetries = 5
var err error
for i := 0; i <= maxRetries; i++ {
var blockNum uint64
blockNum, err = h.client.BlockNumber(context.Background())
if err != nil {
backoff := calcBackoff(i)
log.Warn("http operation failed. retrying...", "error", err, "backoff", backoff)
time.Sleep(backoff)
continue
}
return blockNum, nil
}
return 0, wrapErr(err, "exceeded retries")
}
func (h *LatestBlockHead) Stop() {
close(h.quit)
<-h.done
h.client.Close()
}
func (h *LatestBlockHead) GetBlockNum() uint64 {
h.mutex.RLock()
defer h.mutex.RUnlock()
return h.blockNum
}
package proxyd
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/common/hexutil"
)
var errInvalidRPCParams = errors.New("invalid RPC params")
type RPCMethodHandler interface {
CacheKey(req *RPCReq) string
IsCacheable(req *RPCReq) (bool, error)
RequiresUnconfirmedBlocks(ctx context.Context, req *RPCReq) (bool, error)
}
type StaticRPCMethodHandler struct {
method string
}
func (s *StaticRPCMethodHandler) CacheKey(req *RPCReq) string {
return fmt.Sprintf("method:%s", s.method)
}
func (s *StaticRPCMethodHandler) IsCacheable(*RPCReq) (bool, error) { return true, nil }
func (s *StaticRPCMethodHandler) RequiresUnconfirmedBlocks(context.Context, *RPCReq) (bool, error) {
return false, nil
}
type EthGetBlockByNumberMethod struct {
getLatestBlockNumFn GetLatestBlockNumFn
}
func (e *EthGetBlockByNumberMethod) CacheKey(req *RPCReq) string {
input, includeTx, err := decodeGetBlockByNumberParams(req.Params)
if err != nil {
return ""
}
return fmt.Sprintf("method:eth_getBlockByNumber:%s:%t", input, includeTx)
}
func (e *EthGetBlockByNumberMethod) IsCacheable(req *RPCReq) (bool, error) {
blockNum, _, err := decodeGetBlockByNumberParams(req.Params)
if err != nil {
return false, err
}
return !isBlockDependentParam(blockNum), nil
}
func (e *EthGetBlockByNumberMethod) RequiresUnconfirmedBlocks(ctx context.Context, req *RPCReq) (bool, error) {
curBlock, err := e.getLatestBlockNumFn(ctx)
if err != nil {
return false, err
}
blockInput, _, err := decodeGetBlockByNumberParams(req.Params)
if err != nil {
return false, err
}
if isBlockDependentParam(blockInput) {
return true, nil
}
if blockInput == "earliest" {
return false, nil
}
blockNum, err := decodeBlockInput(blockInput)
if err != nil {
return false, err
}
return curBlock <= blockNum+numBlockConfirmations, nil
}
type EthGetBlockRangeMethod struct {
getLatestBlockNumFn GetLatestBlockNumFn
}
func (e *EthGetBlockRangeMethod) CacheKey(req *RPCReq) string {
start, end, includeTx, err := decodeGetBlockRangeParams(req.Params)
if err != nil {
return ""
}
return fmt.Sprintf("method:eth_getBlockRange:%s:%s:%t", start, end, includeTx)
}
func (e *EthGetBlockRangeMethod) IsCacheable(req *RPCReq) (bool, error) {
start, end, _, err := decodeGetBlockRangeParams(req.Params)
if err != nil {
return false, err
}
return !isBlockDependentParam(start) && !isBlockDependentParam(end), nil
}
func (e *EthGetBlockRangeMethod) RequiresUnconfirmedBlocks(ctx context.Context, req *RPCReq) (bool, error) {
curBlock, err := e.getLatestBlockNumFn(ctx)
if err != nil {
return false, err
}
start, end, _, err := decodeGetBlockRangeParams(req.Params)
if err != nil {
return false, err
}
if isBlockDependentParam(start) || isBlockDependentParam(end) {
return true, nil
}
if start == "earliest" && end == "earliest" {
return false, nil
}
if start != "earliest" {
startNum, err := decodeBlockInput(start)
if err != nil {
return false, err
}
if curBlock <= startNum+numBlockConfirmations {
return true, nil
}
}
if end != "earliest" {
endNum, err := decodeBlockInput(end)
if err != nil {
return false, err
}
if curBlock <= endNum+numBlockConfirmations {
return true, nil
}
}
return false, nil
}
func isBlockDependentParam(s string) bool {
return s == "latest" || s == "pending"
}
func decodeGetBlockByNumberParams(params json.RawMessage) (string, bool, error) {
var list []interface{}
if err := json.Unmarshal(params, &list); err != nil {
return "", false, err
}
if len(list) != 2 {
return "", false, errInvalidRPCParams
}
blockNum, ok := list[0].(string)
if !ok {
return "", false, errInvalidRPCParams
}
includeTx, ok := list[1].(bool)
if !ok {
return "", false, errInvalidRPCParams
}
if !validBlockInput(blockNum) {
return "", false, errInvalidRPCParams
}
return blockNum, includeTx, nil
}
func decodeGetBlockRangeParams(params json.RawMessage) (string, string, bool, error) {
var list []interface{}
if err := json.Unmarshal(params, &list); err != nil {
return "", "", false, err
}
if len(list) != 3 {
return "", "", false, errInvalidRPCParams
}
startBlockNum, ok := list[0].(string)
if !ok {
return "", "", false, errInvalidRPCParams
}
endBlockNum, ok := list[1].(string)
if !ok {
return "", "", false, errInvalidRPCParams
}
includeTx, ok := list[2].(bool)
if !ok {
return "", "", false, errInvalidRPCParams
}
if !validBlockInput(startBlockNum) || !validBlockInput(endBlockNum) {
return "", "", false, errInvalidRPCParams
}
return startBlockNum, endBlockNum, includeTx, nil
}
func decodeBlockInput(input string) (uint64, error) {
return hexutil.DecodeUint64(input)
}
func validBlockInput(input string) bool {
if input == "earliest" || input == "pending" || input == "latest" {
return true
}
_, err := decodeBlockInput(input)
return err == nil
}
package proxyd
import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/log"
"github.com/prometheus/client_golang/prometheus/promhttp"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func Start(config *Config) error {
......@@ -153,6 +155,35 @@ func Start(config *Config) error {
}
}
var rpcCache RPCCache
if config.Cache != nil && config.Cache.Enabled {
var cache Cache
if config.Redis != nil {
if cache, err = newRedisCache(config.Redis.URL); err != nil {
return err
}
} else {
log.Warn("redis is not configured, using in-memory cache")
cache = newMemoryCache()
}
var getLatestBlockNumFn GetLatestBlockNumFn
if config.Cache.BlockSyncRPCURL == "" {
return fmt.Errorf("block sync node required for caching")
}
latestHead, err := newLatestBlockHead(config.Cache.BlockSyncRPCURL)
if err != nil {
return err
}
latestHead.Start()
defer latestHead.Stop()
getLatestBlockNumFn = func(ctx context.Context) (uint64, error) {
return latestHead.GetBlockNum(), nil
}
rpcCache = newRPCCache(cache, getLatestBlockNumFn)
}
srv := NewServer(
backendGroups,
wsBackendGroup,
......@@ -160,9 +191,10 @@ func Start(config *Config) error {
config.RPCMethodMappings,
config.Server.MaxBodySizeBytes,
resolvedAuth,
rpcCache,
)
if config.Metrics.Enabled {
if config.Metrics != nil && config.Metrics.Enabled {
addr := fmt.Sprintf("%s:%d", config.Metrics.Host, config.Metrics.Port)
log.Info("starting metrics server", "addr", addr)
go http.ListenAndServe(addr, promhttp.Handler())
......
......@@ -34,6 +34,7 @@ type Server struct {
upgrader *websocket.Upgrader
rpcServer *http.Server
wsServer *http.Server
cache RPCCache
}
func NewServer(
......@@ -43,7 +44,11 @@ func NewServer(
rpcMethodMappings map[string]string,
maxBodySize int64,
authenticatedPaths map[string]string,
cache RPCCache,
) *Server {
if cache == nil {
cache = &NoopRPCCache{}
}
return &Server{
backendGroups: backendGroups,
wsBackendGroup: wsBackendGroup,
......@@ -51,6 +56,7 @@ func NewServer(
rpcMethodMappings: rpcMethodMappings,
maxBodySize: maxBodySize,
authenticatedPaths: authenticatedPaths,
cache: cache,
upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second,
},
......@@ -141,7 +147,21 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return
}
backendRes, err := s.backendGroups[group].Forward(ctx, req)
var backendRes *RPCRes
backendRes, err = s.cache.GetRPC(ctx, req)
if err == nil && backendRes != nil {
writeRPCRes(ctx, w, backendRes)
return
}
if err != nil {
log.Warn(
"cache lookup error",
"req_id", GetReqID(ctx),
"err", err,
)
}
backendRes, err = s.backendGroups[group].Forward(ctx, req)
if err != nil {
log.Error(
"error forwarding RPC request",
......@@ -153,6 +173,16 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return
}
if backendRes.Error == nil {
if err = s.cache.PutRPC(ctx, req, backendRes); err != nil {
log.Warn(
"cache put error",
"req_id", GetReqID(ctx),
"err", err,
)
}
}
writeRPCRes(ctx, w, backendRes)
}
......@@ -318,3 +348,13 @@ func (w *recordLenWriter) Write(p []byte) (n int, err error) {
w.Len += n
return
}
type NoopRPCCache struct{}
func (n *NoopRPCCache) GetRPC(context.Context, *RPCReq) (*RPCRes, error) {
return nil, nil
}
func (n *NoopRPCCache) PutRPC(context.Context, *RPCReq, *RPCRes) error {
return nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment