server.go 22.2 KB
Newer Older
Matthew Slipper's avatar
Matthew Slipper committed
1 2 3
package proxyd

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"github.com/ethereum/go-ethereum/common/hexutil"
	"github.com/ethereum/go-ethereum/core"
	"github.com/ethereum/go-ethereum/core/types"
	"github.com/ethereum/go-ethereum/log"
	"github.com/go-redis/redis/v8"
	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/rs/cors"
)

30
// Context keys under which populateContext stores per-request values.
const (
	ContextKeyAuth          = "authorization"
	ContextKeyReqID         = "req_id"
	ContextKeyXForwardedFor = "x_forwarded_for"
)

// MaxBatchRPCCallsHardLimit is the largest number of calls accepted in one
// JSON-RPC batch; it is also used when no batch size is configured.
const MaxBatchRPCCallsHardLimit = 100

const (
	// cacheStatusHdr is the response header that reports HIT or MISS.
	cacheStatusHdr = "X-Proxyd-Cache-Status"

	// defaultServerTimeout bounds request handling when no timeout is configured.
	defaultServerTimeout = 10 * time.Second

	// maxRequestBodyLogLen is the default truncation length for logged request bodies.
	maxRequestBodyLogLen = 2000

	// defaultMaxUpstreamBatchSize is the default number of calls per forwarded backend batch.
	defaultMaxUpstreamBatchSize = 10
)

// emptyArrayResponse is the canned result returned for eth_accounts.
var emptyArrayResponse = json.RawMessage("[]")

Matthew Slipper's avatar
Matthew Slipper committed
43
type Server struct {
44
	BackendGroups          map[string]*BackendGroup
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
	wsBackendGroup         *BackendGroup
	wsMethodWhitelist      *StringSet
	rpcMethodMappings      map[string]string
	maxBodySize            int64
	enableRequestLog       bool
	maxRequestBodyLogLen   int
	authenticatedPaths     map[string]string
	timeout                time.Duration
	maxUpstreamBatchSize   int
	maxBatchSize           int
	upgrader               *websocket.Upgrader
	mainLim                FrontendRateLimiter
	overrideLims           map[string]FrontendRateLimiter
	senderLim              FrontendRateLimiter
	limExemptOrigins       []*regexp.Regexp
	limExemptUserAgents    []*regexp.Regexp
	globallyLimitedMethods map[string]bool
	rpcServer              *http.Server
	wsServer               *http.Server
	cache                  RPCCache
	srvMu                  sync.Mutex
Matthew Slipper's avatar
Matthew Slipper committed
66 67
}

68 69
type limiterFunc func(method string) bool

70
func NewServer(
71 72 73 74
	backendGroups map[string]*BackendGroup,
	wsBackendGroup *BackendGroup,
	wsMethodWhitelist *StringSet,
	rpcMethodMappings map[string]string,
75
	maxBodySize int64,
76
	authenticatedPaths map[string]string,
77
	timeout time.Duration,
78
	maxUpstreamBatchSize int,
inphi's avatar
inphi committed
79
	cache RPCCache,
80
	rateLimitConfig RateLimitConfig,
81
	senderRateLimitConfig SenderRateLimitConfig,
82 83
	enableRequestLog bool,
	maxRequestBodyLogLen int,
84
	maxBatchSize int,
85
	redisClient *redis.Client,
86
) (*Server, error) {
inphi's avatar
inphi committed
87 88 89
	if cache == nil {
		cache = &NoopRPCCache{}
	}
90 91 92 93 94

	if maxBodySize == 0 {
		maxBodySize = math.MaxInt64
	}

95 96 97 98
	if timeout == 0 {
		timeout = defaultServerTimeout
	}

99 100 101 102
	if maxUpstreamBatchSize == 0 {
		maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
	}

103 104 105 106
	if maxBatchSize == 0 || maxBatchSize > MaxBatchRPCCallsHardLimit {
		maxBatchSize = MaxBatchRPCCallsHardLimit
	}

107 108 109
	limiterFactory := func(dur time.Duration, max int, prefix string) FrontendRateLimiter {
		if rateLimitConfig.UseRedis {
			return NewRedisFrontendRateLimiter(redisClient, dur, max, prefix)
110 111
		}

112 113 114 115
		return NewMemoryFrontendRateLimit(dur, max)
	}

	var mainLim FrontendRateLimiter
116 117
	limExemptOrigins := make([]*regexp.Regexp, 0)
	limExemptUserAgents := make([]*regexp.Regexp, 0)
118 119
	if rateLimitConfig.BaseRate > 0 {
		mainLim = limiterFactory(time.Duration(rateLimitConfig.BaseInterval), rateLimitConfig.BaseRate, "main")
120
		for _, origin := range rateLimitConfig.ExemptOrigins {
121 122 123 124 125
			pattern, err := regexp.Compile(origin)
			if err != nil {
				return nil, err
			}
			limExemptOrigins = append(limExemptOrigins, pattern)
126 127
		}
		for _, agent := range rateLimitConfig.ExemptUserAgents {
128 129 130 131 132
			pattern, err := regexp.Compile(agent)
			if err != nil {
				return nil, err
			}
			limExemptUserAgents = append(limExemptUserAgents, pattern)
133 134
		}
	} else {
135
		mainLim = NoopFrontendRateLimiter
136 137
	}

138
	overrideLims := make(map[string]FrontendRateLimiter)
139
	globalMethodLims := make(map[string]bool)
140 141
	for method, override := range rateLimitConfig.MethodOverrides {
		var err error
142
		overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method)
143 144 145
		if err != nil {
			return nil, err
		}
146 147 148 149

		if override.Global {
			globalMethodLims[method] = true
		}
150
	}
151 152 153 154
	var senderLim FrontendRateLimiter
	if senderRateLimitConfig.Enabled {
		senderLim = limiterFactory(time.Duration(senderRateLimitConfig.Interval), senderRateLimitConfig.Limit, "senders")
	}
155

Matthew Slipper's avatar
Matthew Slipper committed
156
	return &Server{
157
		BackendGroups:        backendGroups,
158 159 160 161 162 163 164 165
		wsBackendGroup:       wsBackendGroup,
		wsMethodWhitelist:    wsMethodWhitelist,
		rpcMethodMappings:    rpcMethodMappings,
		maxBodySize:          maxBodySize,
		authenticatedPaths:   authenticatedPaths,
		timeout:              timeout,
		maxUpstreamBatchSize: maxUpstreamBatchSize,
		cache:                cache,
166 167
		enableRequestLog:     enableRequestLog,
		maxRequestBodyLogLen: maxRequestBodyLogLen,
168
		maxBatchSize:         maxBatchSize,
169 170 171
		upgrader: &websocket.Upgrader{
			HandshakeTimeout: 5 * time.Second,
		},
172 173 174 175 176 177
		mainLim:                mainLim,
		overrideLims:           overrideLims,
		globallyLimitedMethods: globalMethodLims,
		senderLim:              senderLim,
		limExemptOrigins:       limExemptOrigins,
		limExemptUserAgents:    limExemptUserAgents,
178
	}, nil
Matthew Slipper's avatar
Matthew Slipper committed
179 180
}

181
func (s *Server) RPCListenAndServe(host string, port int) error {
182
	s.srvMu.Lock()
Matthew Slipper's avatar
Matthew Slipper committed
183 184
	hdlr := mux.NewRouter()
	hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
185 186
	hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
	hdlr.HandleFunc("/{authorization}", s.HandleRPC).Methods("POST")
187
	c := cors.New(cors.Options{
188 189
		AllowedOrigins: []string{"*"},
	})
Matthew Slipper's avatar
Matthew Slipper committed
190
	addr := fmt.Sprintf("%s:%d", host, port)
191
	s.rpcServer = &http.Server{
192
		Handler: instrumentedHdlr(c.Handler(hdlr)),
Matthew Slipper's avatar
Matthew Slipper committed
193 194 195
		Addr:    addr,
	}
	log.Info("starting HTTP server", "addr", addr)
196
	s.srvMu.Unlock()
197 198 199 200
	return s.rpcServer.ListenAndServe()
}

func (s *Server) WSListenAndServe(host string, port int) error {
201
	s.srvMu.Lock()
202 203 204 205 206 207 208 209 210 211 212 213
	hdlr := mux.NewRouter()
	hdlr.HandleFunc("/", s.HandleWS)
	hdlr.HandleFunc("/{authorization}", s.HandleWS)
	c := cors.New(cors.Options{
		AllowedOrigins: []string{"*"},
	})
	addr := fmt.Sprintf("%s:%d", host, port)
	s.wsServer = &http.Server{
		Handler: instrumentedHdlr(c.Handler(hdlr)),
		Addr:    addr,
	}
	log.Info("starting WS server", "addr", addr)
214
	s.srvMu.Unlock()
215
	return s.wsServer.ListenAndServe()
216 217 218
}

func (s *Server) Shutdown() {
219 220
	s.srvMu.Lock()
	defer s.srvMu.Unlock()
221
	if s.rpcServer != nil {
222
		_ = s.rpcServer.Shutdown(context.Background())
223 224
	}
	if s.wsServer != nil {
225
		_ = s.wsServer.Shutdown(context.Background())
226
	}
227
	for _, bg := range s.BackendGroups {
228
		bg.Shutdown()
229
	}
Matthew Slipper's avatar
Matthew Slipper committed
230 231 232
}

func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
233
	_, _ = w.Write([]byte("OK"))
Matthew Slipper's avatar
Matthew Slipper committed
234 235 236
}

func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
237
	ctx := s.populateContext(w, r)
238 239 240
	if ctx == nil {
		return
	}
241 242 243
	var cancel context.CancelFunc
	ctx, cancel = context.WithTimeout(ctx, s.timeout)
	defer cancel()
244

245 246 247 248
	origin := r.Header.Get("Origin")
	userAgent := r.Header.Get("User-Agent")
	// Use XFF in context since it will automatically be replaced by the remote IP
	xff := stripXFF(GetXForwardedFor(ctx))
249 250 251 252 253 254 255 256 257
	isUnlimitedOrigin := s.isUnlimitedOrigin(origin)
	isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent)

	if xff == "" {
		writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP"))
		return
	}

	isLimited := func(method string) bool {
258 259
		isGloballyLimitedMethod := s.isGlobalLimit(method)
		if !isGloballyLimitedMethod && (isUnlimitedOrigin || isUnlimitedUserAgent) {
260 261 262
			return false
		}

263
		var lim FrontendRateLimiter
264 265
		if method == "" {
			lim = s.mainLim
266
		} else {
267
			lim = s.overrideLims[method]
268
		}
269 270 271 272 273

		if lim == nil {
			return false
		}

274 275 276 277 278
		ok, err := lim.Take(ctx, xff)
		if err != nil {
			log.Warn("error taking rate limit", "err", err)
			return true
		}
279
		return !ok
280
	}
281 282

	if isLimited("") {
283
		RecordRPCError(ctx, BackendProxyd, "unknown", ErrOverRateLimit)
284 285 286 287 288 289 290 291
		log.Warn(
			"rate limited request",
			"req_id", GetReqID(ctx),
			"auth", GetAuthCtx(ctx),
			"user_agent", userAgent,
			"origin", origin,
			"remote_ip", xff,
		)
292
		writeRPCError(ctx, w, nil, ErrOverRateLimit)
293 294 295
		return
	}

296 297 298 299
	log.Info(
		"received RPC request",
		"req_id", GetReqID(ctx),
		"auth", GetAuthCtx(ctx),
300
		"user_agent", userAgent,
301 302
		"origin", origin,
		"remote_ip", xff,
303
	)
304

305
	body, err := io.ReadAll(io.LimitReader(r.Body, s.maxBodySize))
Matthew Slipper's avatar
Matthew Slipper committed
306
	if err != nil {
307 308 309 310 311 312
		log.Error("error reading request body", "err", err)
		writeRPCError(ctx, w, nil, ErrInternal)
		return
	}
	RecordRequestPayloadSize(ctx, len(body))

313 314 315 316 317 318 319
	if s.enableRequestLog {
		log.Info("Raw RPC request",
			"body", truncate(string(body), s.maxRequestBodyLogLen),
			"req_id", GetReqID(ctx),
			"auth", GetAuthCtx(ctx),
		)
	}
320

Matthew Slipper's avatar
Matthew Slipper committed
321
	if IsBatch(body) {
322 323 324 325 326 327 328 329
		reqs, err := ParseBatchRPCReq(body)
		if err != nil {
			log.Error("error parsing batch RPC request", "err", err)
			RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
			writeRPCError(ctx, w, nil, ErrParseErr)
			return
		}

330 331 332
		RecordBatchSize(len(reqs))

		if len(reqs) > s.maxBatchSize {
333 334 335 336 337 338 339 340 341 342
			RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrTooManyBatchRequests)
			writeRPCError(ctx, w, nil, ErrTooManyBatchRequests)
			return
		}

		if len(reqs) == 0 {
			writeRPCError(ctx, w, nil, ErrInvalidRequest("must specify at least one batch call"))
			return
		}

343
		batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, isLimited, true)
344 345 346 347
		if err == context.DeadlineExceeded {
			writeRPCError(ctx, w, nil, ErrGatewayTimeout)
			return
		}
348 349
		if errors.Is(err, ErrConsensusGetReceiptsCantBeBatched) ||
			errors.Is(err, ErrConsensusGetReceiptsInvalidTarget) {
350 351 352
			writeRPCError(ctx, w, nil, ErrInvalidRequest(err.Error()))
			return
		}
353 354 355
		if err != nil {
			writeRPCError(ctx, w, nil, ErrInternal)
			return
356 357
		}

358
		setCacheHeader(w, batchContainsCached)
359 360 361 362
		writeBatchRPCRes(ctx, w, batchRes)
		return
	}

363
	rawBody := json.RawMessage(body)
364
	backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false)
365
	if err != nil {
366 367 368 369 370
		if errors.Is(err, ErrConsensusGetReceiptsCantBeBatched) ||
			errors.Is(err, ErrConsensusGetReceiptsInvalidTarget) {
			writeRPCError(ctx, w, nil, ErrInvalidRequest(err.Error()))
			return
		}
371
		writeRPCError(ctx, w, nil, ErrInternal)
Matthew Slipper's avatar
Matthew Slipper committed
372 373
		return
	}
374
	setCacheHeader(w, cached)
375
	writeRPCRes(ctx, w, backendRes[0])
376 377
}

378
func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, error) {
379 380 381 382 383 384 385 386 387
	// A request set is transformed into groups of batches.
	// Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints)
	// A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's
	// forwarded to the backend. This is done to ensure that the order of JSON-RPC Responses match the Request order
	// as the backend MAY return Responses out of order.
	// NOTE: Duplicate request ids induces 1-sized JSON-RPC batches
	type batchGroup struct {
		groupID      int
		backendGroup string
388
	}
389

390 391 392
	responses := make([]*RPCRes, len(reqs))
	batches := make(map[batchGroup][]batchElem)
	ids := make(map[string]int, len(reqs))
393

394 395 396 397 398 399 400
	for i := range reqs {
		parsedReq, err := ParseRPCReq(reqs[i])
		if err != nil {
			log.Info("error parsing RPC call", "source", "rpc", "err", err)
			responses[i] = NewRPCErrorRes(nil, err)
			continue
		}
401

402 403 404 405 406 407
		if err := ValidateRPCReq(parsedReq); err != nil {
			RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
			responses[i] = NewRPCErrorRes(nil, err)
			continue
		}

408 409 410 411 412 413
		if parsedReq.Method == "eth_accounts" {
			RecordRPCForward(ctx, BackendProxyd, "eth_accounts", RPCRequestSourceHTTP)
			responses[i] = NewRPCRes(parsedReq.ID, emptyArrayResponse)
			continue
		}

414 415 416 417 418 419 420
		group := s.rpcMethodMappings[parsedReq.Method]
		if group == "" {
			// use unknown below to prevent DOS vector that fills up memory
			// with arbitrary method names.
			log.Info(
				"blocked request for non-whitelisted method",
				"source", "rpc",
421
				"req_id", GetReqID(ctx),
422
				"method", parsedReq.Method,
423
			)
424 425 426 427 428
			RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
			responses[i] = NewRPCErrorRes(parsedReq.ID, ErrMethodNotWhitelisted)
			continue
		}

429 430 431 432 433 434 435 436 437 438 439
		// Take rate limit for specific methods.
		// NOTE: eventually, this should apply to all batch requests. However,
		// since we don't have data right now on the size of each batch, we
		// only apply this to the methods that have an additional rate limit.
		if _, ok := s.overrideLims[parsedReq.Method]; ok && isLimited(parsedReq.Method) {
			log.Info(
				"rate limited specific RPC",
				"source", "rpc",
				"req_id", GetReqID(ctx),
				"method", parsedReq.Method,
			)
440 441
			RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit)
			responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit)
442 443 444
			continue
		}

445 446 447 448 449 450 451 452 453 454 455
		// Apply a sender-based rate limit if it is enabled. Note that sender-based rate
		// limits apply regardless of origin or user-agent. As such, they don't use the
		// isLimited method.
		if parsedReq.Method == "eth_sendRawTransaction" && s.senderLim != nil {
			if err := s.rateLimitSender(ctx, parsedReq); err != nil {
				RecordRPCError(ctx, BackendProxyd, parsedReq.Method, err)
				responses[i] = NewRPCErrorRes(parsedReq.ID, err)
				continue
			}
		}

456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493
		id := string(parsedReq.ID)
		// If this is a duplicate Request ID, move the Request to a new batchGroup
		ids[id]++
		batchGroupID := ids[id]
		batchGroup := batchGroup{groupID: batchGroupID, backendGroup: group}
		batches[batchGroup] = append(batches[batchGroup], batchElem{parsedReq, i})
	}

	var cached bool
	for group, batch := range batches {
		var cacheMisses []batchElem

		for _, req := range batch {
			backendRes, _ := s.cache.GetRPC(ctx, req.Req)
			if backendRes != nil {
				responses[req.Index] = backendRes
				cached = true
			} else {
				cacheMisses = append(cacheMisses, req)
			}
		}

		// Create minibatches - each minibatch must be no larger than the maxUpstreamBatchSize
		numBatches := int(math.Ceil(float64(len(cacheMisses)) / float64(s.maxUpstreamBatchSize)))
		for i := 0; i < numBatches; i++ {
			if ctx.Err() == context.DeadlineExceeded {
				log.Info("short-circuiting batch RPC",
					"req_id", GetReqID(ctx),
					"auth", GetAuthCtx(ctx),
					"batch_index", i,
				)
				batchRPCShortCircuitsTotal.Inc()
				return nil, false, context.DeadlineExceeded
			}

			start := i * s.maxUpstreamBatchSize
			end := int(math.Min(float64(start+s.maxUpstreamBatchSize), float64(len(cacheMisses))))
			elems := cacheMisses[start:end]
494
			res, err := s.BackendGroups[group.backendGroup].Forward(ctx, createBatchRequest(elems), isBatch)
495
			if err != nil {
496 497
				if errors.Is(err, ErrConsensusGetReceiptsCantBeBatched) ||
					errors.Is(err, ErrConsensusGetReceiptsInvalidTarget) {
498 499
					return nil, false, err
				}
500 501 502 503
				log.Error(
					"error forwarding RPC batch",
					"batch_size", len(elems),
					"backend_group", group,
504
					"req_id", GetReqID(ctx),
505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526
					"err", err,
				)
				res = nil
				for _, elem := range elems {
					res = append(res, NewRPCErrorRes(elem.Req.ID, err))
				}
			}

			for i := range elems {
				responses[elems[i].Index] = res[i]

				// TODO(inphi): batch put these
				if res[i].Error == nil && res[i].Result != nil {
					if err := s.cache.PutRPC(ctx, elems[i].Req, res[i]); err != nil {
						log.Warn(
							"cache put error",
							"req_id", GetReqID(ctx),
							"err", err,
						)
					}
				}
			}
527 528 529
		}
	}

530
	return responses, cached, nil
531 532 533
}

func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
534
	ctx := s.populateContext(w, r)
535 536 537 538
	if ctx == nil {
		return
	}

539 540
	log.Info("received WS connection", "req_id", GetReqID(ctx))

541
	clientConn, err := s.upgrader.Upgrade(w, r, nil)
Matthew Slipper's avatar
Matthew Slipper committed
542
	if err != nil {
543
		log.Error("error upgrading client conn", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
Matthew Slipper's avatar
Matthew Slipper committed
544 545 546
		return
	}

547
	proxier, err := s.wsBackendGroup.ProxyWS(ctx, clientConn, s.wsMethodWhitelist)
Matthew Slipper's avatar
Matthew Slipper committed
548
	if err != nil {
549
		if errors.Is(err, ErrNoBackends) {
550
			RecordUnserviceableRequest(ctx, RPCRequestSourceWS)
Matthew Slipper's avatar
Matthew Slipper committed
551
		}
552
		log.Error("error dialing ws backend", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
553
		clientConn.Close()
Matthew Slipper's avatar
Matthew Slipper committed
554 555 556
		return
	}

557
	activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Inc()
558 559
	go func() {
		// Below call blocks so run it in a goroutine.
560
		if err := proxier.Proxy(ctx); err != nil {
561
			log.Error("error proxying websocket", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
562
		}
563
		activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Dec()
564
	}()
565 566

	log.Info("accepted WS connection", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx))
Matthew Slipper's avatar
Matthew Slipper committed
567 568
}

569
func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context {
570 571
	vars := mux.Vars(r)
	authorization := vars["authorization"]
572 573 574 575 576 577 578 579
	xff := r.Header.Get("X-Forwarded-For")
	if xff == "" {
		ipPort := strings.Split(r.RemoteAddr, ":")
		if len(ipPort) == 2 {
			xff = ipPort[0]
		}
	}
	ctx := context.WithValue(r.Context(), ContextKeyXForwardedFor, xff) // nolint:staticcheck
580

581
	if len(s.authenticatedPaths) > 0 {
582 583 584 585 586
		if authorization == "" || s.authenticatedPaths[authorization] == "" {
			log.Info("blocked unauthorized request", "authorization", authorization)
			httpResponseCodesTotal.WithLabelValues("401").Inc()
			w.WriteHeader(401)
			return nil
587
		}
588

589
		ctx = context.WithValue(ctx, ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck
590 591
	}

592 593
	return context.WithValue(
		ctx,
594
		ContextKeyReqID, // nolint:staticcheck
595 596
		randStr(10),
	)
597 598
}

599 600 601 602 603 604 605 606
// randStr returns l bytes from crypto/rand, hex-encoded (2*l lowercase hex
// characters). It panics if the secure random source cannot be read.
func randStr(l int) string {
	raw := make([]byte, l)
	if _, err := io.ReadFull(rand.Reader, raw); err != nil {
		panic(err)
	}
	return hex.EncodeToString(raw)
}

607
func (s *Server) isUnlimitedOrigin(origin string) bool {
608 609 610 611 612 613 614
	for _, pat := range s.limExemptOrigins {
		if pat.MatchString(origin) {
			return true
		}
	}

	return false
615 616 617
}

func (s *Server) isUnlimitedUserAgent(origin string) bool {
618 619 620 621 622 623
	for _, pat := range s.limExemptUserAgents {
		if pat.MatchString(origin) {
			return true
		}
	}
	return false
624 625
}

626 627 628 629
// isGlobalLimit reports whether method was configured as globally limited.
// Such methods are rate limited even for exempt origins and user agents.
func (s *Server) isGlobalLimit(method string) bool {
	global, found := s.globallyLimitedMethods[method]
	return found && global
}

630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658
func (s *Server) rateLimitSender(ctx context.Context, req *RPCReq) error {
	var params []string
	if err := json.Unmarshal(req.Params, &params); err != nil {
		log.Debug("error unmarshaling raw transaction params", "err", err, "req_Id", GetReqID(ctx))
		return ErrParseErr
	}

	if len(params) != 1 {
		log.Debug("raw transaction request has invalid number of params", "req_id", GetReqID(ctx))
		// The error below is identical to the one Geth responds with.
		return ErrInvalidParams("missing value for required argument 0")
	}

	var data hexutil.Bytes
	if err := data.UnmarshalText([]byte(params[0])); err != nil {
		log.Debug("error decoding raw tx data", "err", err, "req_id", GetReqID(ctx))
		// Geth returns the raw error from UnmarshalText.
		return ErrInvalidParams(err.Error())
	}

	// Inflates a types.Transaction object from the transaction's raw bytes.
	tx := new(types.Transaction)
	if err := tx.UnmarshalBinary(data); err != nil {
		log.Debug("could not unmarshal transaction", "err", err, "req_id", GetReqID(ctx))
		return ErrInvalidParams(err.Error())
	}

	// Convert the transaction into a Message object so that we can get the
	// sender. This method performs an ecrecover, which can be expensive.
659
	msg, err := core.TransactionToMessage(tx, types.LatestSignerForChainID(tx.ChainId()), nil)
660 661 662 663
	if err != nil {
		log.Debug("could not get message from transaction", "err", err, "req_id", GetReqID(ctx))
		return ErrInvalidParams(err.Error())
	}
664
	ok, err := s.senderLim.Take(ctx, fmt.Sprintf("%s:%d", msg.From.Hex(), tx.Nonce()))
665 666 667 668 669
	if err != nil {
		log.Error("error taking from sender limiter", "err", err, "req_id", GetReqID(ctx))
		return ErrInternal
	}
	if !ok {
670
		log.Debug("sender rate limit exceeded", "sender", msg.From.Hex(), "req_id", GetReqID(ctx))
671 672 673 674 675 676
		return ErrOverSenderRateLimit
	}

	return nil
}

677 678 679 680 681 682 683 684
// setCacheHeader sets the cache status response header to "HIT" when any part
// of the response came from the cache, and to "MISS" otherwise.
func setCacheHeader(w http.ResponseWriter, cached bool) {
	status := "MISS"
	if cached {
		status = "HIT"
	}
	w.Header().Set(cacheStatusHdr, status)
}

685
func writeRPCError(ctx context.Context, w http.ResponseWriter, id json.RawMessage, err error) {
686
	var res *RPCRes
687
	if r, ok := err.(*RPCErr); ok {
688
		res = NewRPCErrorRes(id, r)
689
	} else {
690
		res = NewRPCErrorRes(id, ErrInternal)
Matthew Slipper's avatar
Matthew Slipper committed
691
	}
692
	writeRPCRes(ctx, w, res)
693 694
}

695
func writeRPCRes(ctx context.Context, w http.ResponseWriter, res *RPCRes) {
696 697 698 699
	statusCode := 200
	if res.IsError() && res.Error.HTTPErrorCode != 0 {
		statusCode = res.Error.HTTPErrorCode
	}
700

Matthew Slipper's avatar
Matthew Slipper committed
701
	w.Header().Set("content-type", "application/json")
702
	w.WriteHeader(statusCode)
703 704
	ww := &recordLenWriter{Writer: w}
	enc := json.NewEncoder(ww)
705 706
	if err := enc.Encode(res); err != nil {
		log.Error("error writing rpc response", "err", err)
707 708
		RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
		return
Matthew Slipper's avatar
Matthew Slipper committed
709
	}
710
	httpResponseCodesTotal.WithLabelValues(strconv.Itoa(statusCode)).Inc()
711
	RecordResponsePayloadSize(ctx, ww.Len)
Matthew Slipper's avatar
Matthew Slipper committed
712 713
}

714
func writeBatchRPCRes(ctx context.Context, w http.ResponseWriter, res []*RPCRes) {
Matthew Slipper's avatar
Matthew Slipper committed
715
	w.Header().Set("content-type", "application/json")
716 717 718 719 720 721 722 723 724 725 726
	w.WriteHeader(200)
	ww := &recordLenWriter{Writer: w}
	enc := json.NewEncoder(ww)
	if err := enc.Encode(res); err != nil {
		log.Error("error writing batch rpc response", "err", err)
		RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
		return
	}
	RecordResponsePayloadSize(ctx, ww.Len)
}

Matthew Slipper's avatar
Matthew Slipper committed
727 728
func instrumentedHdlr(h http.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
729
		respTimer := prometheus.NewTimer(httpRequestDurationSumm)
Matthew Slipper's avatar
Matthew Slipper committed
730
		h.ServeHTTP(w, r)
731
		respTimer.ObserveDuration()
Matthew Slipper's avatar
Matthew Slipper committed
732 733
	}
}
734 735 736 737 738 739 740 741 742

// GetAuthCtx returns the auth label populateContext stored in ctx for an
// authenticated path, or "none" if ctx carries no string auth value.
func GetAuthCtx(ctx context.Context) string {
	if authUser, found := ctx.Value(ContextKeyAuth).(string); found {
		return authUser
	}
	return "none"
}
743 744 745 746 747 748 749 750

// GetReqID returns the request ID stored in ctx, or "" when there is none.
func GetReqID(ctx context.Context) string {
	// A failed assertion leaves reqID as the zero value "".
	reqID, _ := ctx.Value(ContextKeyReqID).(string)
	return reqID
}
751 752 753 754 755 756 757 758

// GetXForwardedFor returns the client address stored in ctx by
// populateContext, or "" when there is none.
func GetXForwardedFor(ctx context.Context) string {
	if xff, found := ctx.Value(ContextKeyXForwardedFor).(string); found {
		return xff
	}
	return ""
}
759 760 761 762 763 764 765 766 767 768 769

// recordLenWriter wraps an io.Writer and keeps a running total of the bytes
// written through it; it is used to measure response payload sizes.
type recordLenWriter struct {
	io.Writer
	Len int // cumulative byte count reported by the wrapped Writer
}

// Write forwards p to the wrapped Writer and adds the byte count it reports
// to Len, even when it also returns an error.
func (w *recordLenWriter) Write(p []byte) (int, error) {
	written, err := w.Writer.Write(p)
	w.Len += written
	return written, err
}
inphi's avatar
inphi committed
770 771 772 773 774 775 776 777 778 779

// NoopRPCCache is a cache that stores nothing: every lookup misses and every
// store succeeds without effect. NewServer substitutes it for a nil cache.
type NoopRPCCache struct{}

// GetRPC always reports a miss (nil response, nil error).
func (*NoopRPCCache) GetRPC(_ context.Context, _ *RPCReq) (*RPCRes, error) {
	return nil, nil
}

// PutRPC discards the response and reports success.
func (*NoopRPCCache) PutRPC(_ context.Context, _ *RPCReq, _ *RPCRes) error {
	return nil
}
780

781 782 783 784 785 786 787
func truncate(str string, maxLen int) string {
	if maxLen == 0 {
		maxLen = maxRequestBodyLogLen
	}

	if len(str) > maxLen {
		return str[:maxLen] + "..."
788 789 790 791
	} else {
		return str
	}
}
792 793 794 795 796 797 798 799 800 801 802 803 804

// batchElem pairs a parsed RPC request with its position in the original
// request set, so a forwarded response can be written back to the right slot.
type batchElem struct {
	Req   *RPCReq
	Index int
}

// createBatchRequest extracts the requests from elems, in order, producing the
// batch to forward upstream. The result is non-nil even when elems is empty.
func createBatchRequest(elems []batchElem) []*RPCReq {
	reqs := make([]*RPCReq, 0, len(elems))
	for _, elem := range elems {
		reqs = append(reqs, elem.Req)
	}
	return reqs
}