server.go 21.7 KB
Newer Older
Matthew Slipper's avatar
Matthew Slipper committed
1 2 3
package proxyd

import (
4 5 6 7
	"context"
	"encoding/json"
	"errors"
	"fmt"
8
	"io"
9
	"math"
10
	"net/http"
11
	"regexp"
12 13
	"strconv"
	"strings"
14
	"sync"
15 16
	"time"

17 18 19
	"github.com/ethereum/go-ethereum/common/hexutil"
	"github.com/ethereum/go-ethereum/core/types"

20
	"github.com/ethereum/go-ethereum/log"
21
	"github.com/go-redis/redis/v8"
22 23 24 25
	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/rs/cors"
Matthew Slipper's avatar
Matthew Slipper committed
26 27
)

28
// Context keys under which populateContext stores per-request values.
// NOTE(review): these are untyped string constants rather than a private key
// type, which is why the context.WithValue call sites carry nolint:staticcheck.
const (
	ContextKeyAuth          = "authorization"
	ContextKeyReqID         = "req_id"
	ContextKeyXForwardedFor = "x_forwarded_for"
)

// Request-handling limits and defaults.
const (
	// MaxBatchRPCCallsHardLimit is the largest JSON-RPC batch accepted;
	// NewServer clamps the configured maxBatchSize to it.
	MaxBatchRPCCallsHardLimit = 100

	// cacheStatusHdr reports whether any part of a response came from the
	// RPC cache ("HIT") or not ("MISS").
	cacheStatusHdr = "X-Proxyd-Cache-Status"

	// defaultServerTimeout is the per-request timeout used when none is configured.
	defaultServerTimeout = 10 * time.Second

	// maxRequestBodyLogLen is the fallback truncation length for logged request bodies.
	maxRequestBodyLogLen = 2000

	// defaultMaxUpstreamBatchSize is the upstream batch size used when none is configured.
	defaultMaxUpstreamBatchSize = 10
)

// emptyArrayResponse is the result proxyd answers eth_accounts with locally.
var emptyArrayResponse = json.RawMessage("[]")

Matthew Slipper's avatar
Matthew Slipper committed
41
type Server struct {
42
	BackendGroups          map[string]*BackendGroup
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
	wsBackendGroup         *BackendGroup
	wsMethodWhitelist      *StringSet
	rpcMethodMappings      map[string]string
	maxBodySize            int64
	enableRequestLog       bool
	maxRequestBodyLogLen   int
	authenticatedPaths     map[string]string
	timeout                time.Duration
	maxUpstreamBatchSize   int
	maxBatchSize           int
	upgrader               *websocket.Upgrader
	mainLim                FrontendRateLimiter
	overrideLims           map[string]FrontendRateLimiter
	senderLim              FrontendRateLimiter
	limExemptOrigins       []*regexp.Regexp
	limExemptUserAgents    []*regexp.Regexp
	globallyLimitedMethods map[string]bool
	rpcServer              *http.Server
	wsServer               *http.Server
	cache                  RPCCache
	srvMu                  sync.Mutex
Matthew Slipper's avatar
Matthew Slipper committed
64 65
}

66 67
type limiterFunc func(method string) bool

68
func NewServer(
69 70 71 72
	backendGroups map[string]*BackendGroup,
	wsBackendGroup *BackendGroup,
	wsMethodWhitelist *StringSet,
	rpcMethodMappings map[string]string,
73
	maxBodySize int64,
74
	authenticatedPaths map[string]string,
75
	timeout time.Duration,
76
	maxUpstreamBatchSize int,
inphi's avatar
inphi committed
77
	cache RPCCache,
78
	rateLimitConfig RateLimitConfig,
79
	senderRateLimitConfig SenderRateLimitConfig,
80 81
	enableRequestLog bool,
	maxRequestBodyLogLen int,
82
	maxBatchSize int,
83
	redisClient *redis.Client,
84
) (*Server, error) {
inphi's avatar
inphi committed
85 86 87
	if cache == nil {
		cache = &NoopRPCCache{}
	}
88 89 90 91 92

	if maxBodySize == 0 {
		maxBodySize = math.MaxInt64
	}

93 94 95 96
	if timeout == 0 {
		timeout = defaultServerTimeout
	}

97 98 99 100
	if maxUpstreamBatchSize == 0 {
		maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
	}

101 102 103 104
	if maxBatchSize == 0 || maxBatchSize > MaxBatchRPCCallsHardLimit {
		maxBatchSize = MaxBatchRPCCallsHardLimit
	}

105 106 107
	limiterFactory := func(dur time.Duration, max int, prefix string) FrontendRateLimiter {
		if rateLimitConfig.UseRedis {
			return NewRedisFrontendRateLimiter(redisClient, dur, max, prefix)
108 109
		}

110 111 112 113
		return NewMemoryFrontendRateLimit(dur, max)
	}

	var mainLim FrontendRateLimiter
114 115
	limExemptOrigins := make([]*regexp.Regexp, 0)
	limExemptUserAgents := make([]*regexp.Regexp, 0)
116 117
	if rateLimitConfig.BaseRate > 0 {
		mainLim = limiterFactory(time.Duration(rateLimitConfig.BaseInterval), rateLimitConfig.BaseRate, "main")
118
		for _, origin := range rateLimitConfig.ExemptOrigins {
119 120 121 122 123
			pattern, err := regexp.Compile(origin)
			if err != nil {
				return nil, err
			}
			limExemptOrigins = append(limExemptOrigins, pattern)
124 125
		}
		for _, agent := range rateLimitConfig.ExemptUserAgents {
126 127 128 129 130
			pattern, err := regexp.Compile(agent)
			if err != nil {
				return nil, err
			}
			limExemptUserAgents = append(limExemptUserAgents, pattern)
131 132
		}
	} else {
133
		mainLim = NoopFrontendRateLimiter
134 135
	}

136
	overrideLims := make(map[string]FrontendRateLimiter)
137
	globalMethodLims := make(map[string]bool)
138 139
	for method, override := range rateLimitConfig.MethodOverrides {
		var err error
140
		overrideLims[method] = limiterFactory(time.Duration(override.Interval), override.Limit, method)
141 142 143
		if err != nil {
			return nil, err
		}
144 145 146 147

		if override.Global {
			globalMethodLims[method] = true
		}
148
	}
149 150 151 152
	var senderLim FrontendRateLimiter
	if senderRateLimitConfig.Enabled {
		senderLim = limiterFactory(time.Duration(senderRateLimitConfig.Interval), senderRateLimitConfig.Limit, "senders")
	}
153

Matthew Slipper's avatar
Matthew Slipper committed
154
	return &Server{
155
		BackendGroups:        backendGroups,
156 157 158 159 160 161 162 163
		wsBackendGroup:       wsBackendGroup,
		wsMethodWhitelist:    wsMethodWhitelist,
		rpcMethodMappings:    rpcMethodMappings,
		maxBodySize:          maxBodySize,
		authenticatedPaths:   authenticatedPaths,
		timeout:              timeout,
		maxUpstreamBatchSize: maxUpstreamBatchSize,
		cache:                cache,
164 165
		enableRequestLog:     enableRequestLog,
		maxRequestBodyLogLen: maxRequestBodyLogLen,
166
		maxBatchSize:         maxBatchSize,
167 168 169
		upgrader: &websocket.Upgrader{
			HandshakeTimeout: 5 * time.Second,
		},
170 171 172 173 174 175
		mainLim:                mainLim,
		overrideLims:           overrideLims,
		globallyLimitedMethods: globalMethodLims,
		senderLim:              senderLim,
		limExemptOrigins:       limExemptOrigins,
		limExemptUserAgents:    limExemptUserAgents,
176
	}, nil
Matthew Slipper's avatar
Matthew Slipper committed
177 178
}

179
func (s *Server) RPCListenAndServe(host string, port int) error {
180
	s.srvMu.Lock()
Matthew Slipper's avatar
Matthew Slipper committed
181 182
	hdlr := mux.NewRouter()
	hdlr.HandleFunc("/healthz", s.HandleHealthz).Methods("GET")
183 184
	hdlr.HandleFunc("/", s.HandleRPC).Methods("POST")
	hdlr.HandleFunc("/{authorization}", s.HandleRPC).Methods("POST")
185
	c := cors.New(cors.Options{
186 187
		AllowedOrigins: []string{"*"},
	})
Matthew Slipper's avatar
Matthew Slipper committed
188
	addr := fmt.Sprintf("%s:%d", host, port)
189
	s.rpcServer = &http.Server{
190
		Handler: instrumentedHdlr(c.Handler(hdlr)),
Matthew Slipper's avatar
Matthew Slipper committed
191 192 193
		Addr:    addr,
	}
	log.Info("starting HTTP server", "addr", addr)
194
	s.srvMu.Unlock()
195 196 197 198
	return s.rpcServer.ListenAndServe()
}

func (s *Server) WSListenAndServe(host string, port int) error {
199
	s.srvMu.Lock()
200 201 202 203 204 205 206 207 208 209 210 211
	hdlr := mux.NewRouter()
	hdlr.HandleFunc("/", s.HandleWS)
	hdlr.HandleFunc("/{authorization}", s.HandleWS)
	c := cors.New(cors.Options{
		AllowedOrigins: []string{"*"},
	})
	addr := fmt.Sprintf("%s:%d", host, port)
	s.wsServer = &http.Server{
		Handler: instrumentedHdlr(c.Handler(hdlr)),
		Addr:    addr,
	}
	log.Info("starting WS server", "addr", addr)
212
	s.srvMu.Unlock()
213
	return s.wsServer.ListenAndServe()
214 215 216
}

func (s *Server) Shutdown() {
217 218
	s.srvMu.Lock()
	defer s.srvMu.Unlock()
219
	if s.rpcServer != nil {
220
		_ = s.rpcServer.Shutdown(context.Background())
221 222
	}
	if s.wsServer != nil {
223
		_ = s.wsServer.Shutdown(context.Background())
224
	}
Matthew Slipper's avatar
Matthew Slipper committed
225 226 227
}

func (s *Server) HandleHealthz(w http.ResponseWriter, r *http.Request) {
228
	_, _ = w.Write([]byte("OK"))
Matthew Slipper's avatar
Matthew Slipper committed
229 230 231
}

func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
232
	ctx := s.populateContext(w, r)
233 234 235
	if ctx == nil {
		return
	}
236 237 238
	var cancel context.CancelFunc
	ctx, cancel = context.WithTimeout(ctx, s.timeout)
	defer cancel()
239

240 241 242 243
	origin := r.Header.Get("Origin")
	userAgent := r.Header.Get("User-Agent")
	// Use XFF in context since it will automatically be replaced by the remote IP
	xff := stripXFF(GetXForwardedFor(ctx))
244 245 246 247 248 249 250 251 252
	isUnlimitedOrigin := s.isUnlimitedOrigin(origin)
	isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent)

	if xff == "" {
		writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP"))
		return
	}

	isLimited := func(method string) bool {
253 254
		isGloballyLimitedMethod := s.isGlobalLimit(method)
		if !isGloballyLimitedMethod && (isUnlimitedOrigin || isUnlimitedUserAgent) {
255 256 257
			return false
		}

258
		var lim FrontendRateLimiter
259 260
		if method == "" {
			lim = s.mainLim
261
		} else {
262
			lim = s.overrideLims[method]
263
		}
264 265 266 267 268

		if lim == nil {
			return false
		}

269 270 271 272 273
		ok, err := lim.Take(ctx, xff)
		if err != nil {
			log.Warn("error taking rate limit", "err", err)
			return true
		}
274
		return !ok
275
	}
276 277

	if isLimited("") {
278
		RecordRPCError(ctx, BackendProxyd, "unknown", ErrOverRateLimit)
279 280 281 282 283 284 285 286
		log.Warn(
			"rate limited request",
			"req_id", GetReqID(ctx),
			"auth", GetAuthCtx(ctx),
			"user_agent", userAgent,
			"origin", origin,
			"remote_ip", xff,
		)
287
		writeRPCError(ctx, w, nil, ErrOverRateLimit)
288 289 290
		return
	}

291 292 293 294
	log.Info(
		"received RPC request",
		"req_id", GetReqID(ctx),
		"auth", GetAuthCtx(ctx),
295
		"user_agent", userAgent,
296 297
		"origin", origin,
		"remote_ip", xff,
298
	)
299

300
	body, err := io.ReadAll(io.LimitReader(r.Body, s.maxBodySize))
Matthew Slipper's avatar
Matthew Slipper committed
301
	if err != nil {
302 303 304 305 306 307
		log.Error("error reading request body", "err", err)
		writeRPCError(ctx, w, nil, ErrInternal)
		return
	}
	RecordRequestPayloadSize(ctx, len(body))

308 309 310 311 312 313 314
	if s.enableRequestLog {
		log.Info("Raw RPC request",
			"body", truncate(string(body), s.maxRequestBodyLogLen),
			"req_id", GetReqID(ctx),
			"auth", GetAuthCtx(ctx),
		)
	}
315

Matthew Slipper's avatar
Matthew Slipper committed
316
	if IsBatch(body) {
317 318 319 320 321 322 323 324
		reqs, err := ParseBatchRPCReq(body)
		if err != nil {
			log.Error("error parsing batch RPC request", "err", err)
			RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
			writeRPCError(ctx, w, nil, ErrParseErr)
			return
		}

325 326 327
		RecordBatchSize(len(reqs))

		if len(reqs) > s.maxBatchSize {
328 329 330 331 332 333 334 335 336 337
			RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrTooManyBatchRequests)
			writeRPCError(ctx, w, nil, ErrTooManyBatchRequests)
			return
		}

		if len(reqs) == 0 {
			writeRPCError(ctx, w, nil, ErrInvalidRequest("must specify at least one batch call"))
			return
		}

338
		batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, isLimited, true)
339 340 341 342 343 344 345
		if err == context.DeadlineExceeded {
			writeRPCError(ctx, w, nil, ErrGatewayTimeout)
			return
		}
		if err != nil {
			writeRPCError(ctx, w, nil, ErrInternal)
			return
346 347
		}

348
		setCacheHeader(w, batchContainsCached)
349 350 351 352
		writeBatchRPCRes(ctx, w, batchRes)
		return
	}

353
	rawBody := json.RawMessage(body)
354
	backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false)
355
	if err != nil {
356
		writeRPCError(ctx, w, nil, ErrInternal)
Matthew Slipper's avatar
Matthew Slipper committed
357 358
		return
	}
359
	setCacheHeader(w, cached)
360
	writeRPCRes(ctx, w, backendRes[0])
361 362
}

363
func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, error) {
364 365 366 367 368 369 370 371 372
	// A request set is transformed into groups of batches.
	// Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints)
	// A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's
	// forwarded to the backend. This is done to ensure that the order of JSON-RPC Responses match the Request order
	// as the backend MAY return Responses out of order.
	// NOTE: Duplicate request ids induces 1-sized JSON-RPC batches
	type batchGroup struct {
		groupID      int
		backendGroup string
373
	}
374

375 376 377
	responses := make([]*RPCRes, len(reqs))
	batches := make(map[batchGroup][]batchElem)
	ids := make(map[string]int, len(reqs))
378

379 380 381 382 383 384 385
	for i := range reqs {
		parsedReq, err := ParseRPCReq(reqs[i])
		if err != nil {
			log.Info("error parsing RPC call", "source", "rpc", "err", err)
			responses[i] = NewRPCErrorRes(nil, err)
			continue
		}
386

387 388 389 390 391 392
		if err := ValidateRPCReq(parsedReq); err != nil {
			RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
			responses[i] = NewRPCErrorRes(nil, err)
			continue
		}

393 394 395 396 397 398
		if parsedReq.Method == "eth_accounts" {
			RecordRPCForward(ctx, BackendProxyd, "eth_accounts", RPCRequestSourceHTTP)
			responses[i] = NewRPCRes(parsedReq.ID, emptyArrayResponse)
			continue
		}

399 400 401 402 403 404 405
		group := s.rpcMethodMappings[parsedReq.Method]
		if group == "" {
			// use unknown below to prevent DOS vector that fills up memory
			// with arbitrary method names.
			log.Info(
				"blocked request for non-whitelisted method",
				"source", "rpc",
406
				"req_id", GetReqID(ctx),
407
				"method", parsedReq.Method,
408
			)
409 410 411 412 413
			RecordRPCError(ctx, BackendProxyd, MethodUnknown, ErrMethodNotWhitelisted)
			responses[i] = NewRPCErrorRes(parsedReq.ID, ErrMethodNotWhitelisted)
			continue
		}

414 415 416 417 418 419 420 421 422 423 424
		// Take rate limit for specific methods.
		// NOTE: eventually, this should apply to all batch requests. However,
		// since we don't have data right now on the size of each batch, we
		// only apply this to the methods that have an additional rate limit.
		if _, ok := s.overrideLims[parsedReq.Method]; ok && isLimited(parsedReq.Method) {
			log.Info(
				"rate limited specific RPC",
				"source", "rpc",
				"req_id", GetReqID(ctx),
				"method", parsedReq.Method,
			)
425 426
			RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit)
			responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit)
427 428 429
			continue
		}

430 431 432 433 434 435 436 437 438 439 440
		// Apply a sender-based rate limit if it is enabled. Note that sender-based rate
		// limits apply regardless of origin or user-agent. As such, they don't use the
		// isLimited method.
		if parsedReq.Method == "eth_sendRawTransaction" && s.senderLim != nil {
			if err := s.rateLimitSender(ctx, parsedReq); err != nil {
				RecordRPCError(ctx, BackendProxyd, parsedReq.Method, err)
				responses[i] = NewRPCErrorRes(parsedReq.ID, err)
				continue
			}
		}

441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478
		id := string(parsedReq.ID)
		// If this is a duplicate Request ID, move the Request to a new batchGroup
		ids[id]++
		batchGroupID := ids[id]
		batchGroup := batchGroup{groupID: batchGroupID, backendGroup: group}
		batches[batchGroup] = append(batches[batchGroup], batchElem{parsedReq, i})
	}

	var cached bool
	for group, batch := range batches {
		var cacheMisses []batchElem

		for _, req := range batch {
			backendRes, _ := s.cache.GetRPC(ctx, req.Req)
			if backendRes != nil {
				responses[req.Index] = backendRes
				cached = true
			} else {
				cacheMisses = append(cacheMisses, req)
			}
		}

		// Create minibatches - each minibatch must be no larger than the maxUpstreamBatchSize
		numBatches := int(math.Ceil(float64(len(cacheMisses)) / float64(s.maxUpstreamBatchSize)))
		for i := 0; i < numBatches; i++ {
			if ctx.Err() == context.DeadlineExceeded {
				log.Info("short-circuiting batch RPC",
					"req_id", GetReqID(ctx),
					"auth", GetAuthCtx(ctx),
					"batch_index", i,
				)
				batchRPCShortCircuitsTotal.Inc()
				return nil, false, context.DeadlineExceeded
			}

			start := i * s.maxUpstreamBatchSize
			end := int(math.Min(float64(start+s.maxUpstreamBatchSize), float64(len(cacheMisses))))
			elems := cacheMisses[start:end]
479
			res, err := s.BackendGroups[group.backendGroup].Forward(ctx, createBatchRequest(elems), isBatch)
480 481 482 483 484
			if err != nil {
				log.Error(
					"error forwarding RPC batch",
					"batch_size", len(elems),
					"backend_group", group,
485
					"req_id", GetReqID(ctx),
486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507
					"err", err,
				)
				res = nil
				for _, elem := range elems {
					res = append(res, NewRPCErrorRes(elem.Req.ID, err))
				}
			}

			for i := range elems {
				responses[elems[i].Index] = res[i]

				// TODO(inphi): batch put these
				if res[i].Error == nil && res[i].Result != nil {
					if err := s.cache.PutRPC(ctx, elems[i].Req, res[i]); err != nil {
						log.Warn(
							"cache put error",
							"req_id", GetReqID(ctx),
							"err", err,
						)
					}
				}
			}
508 509 510
		}
	}

511
	return responses, cached, nil
512 513 514
}

func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
515
	ctx := s.populateContext(w, r)
516 517 518 519
	if ctx == nil {
		return
	}

520 521
	log.Info("received WS connection", "req_id", GetReqID(ctx))

522
	clientConn, err := s.upgrader.Upgrade(w, r, nil)
Matthew Slipper's avatar
Matthew Slipper committed
523
	if err != nil {
524
		log.Error("error upgrading client conn", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
Matthew Slipper's avatar
Matthew Slipper committed
525 526 527
		return
	}

528
	proxier, err := s.wsBackendGroup.ProxyWS(ctx, clientConn, s.wsMethodWhitelist)
Matthew Slipper's avatar
Matthew Slipper committed
529
	if err != nil {
530
		if errors.Is(err, ErrNoBackends) {
531
			RecordUnserviceableRequest(ctx, RPCRequestSourceWS)
Matthew Slipper's avatar
Matthew Slipper committed
532
		}
533
		log.Error("error dialing ws backend", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
534
		clientConn.Close()
Matthew Slipper's avatar
Matthew Slipper committed
535 536 537
		return
	}

538
	activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Inc()
539 540
	go func() {
		// Below call blocks so run it in a goroutine.
541
		if err := proxier.Proxy(ctx); err != nil {
542
			log.Error("error proxying websocket", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx), "err", err)
543
		}
544
		activeClientWsConnsGauge.WithLabelValues(GetAuthCtx(ctx)).Dec()
545
	}()
546 547

	log.Info("accepted WS connection", "auth", GetAuthCtx(ctx), "req_id", GetReqID(ctx))
Matthew Slipper's avatar
Matthew Slipper committed
548 549
}

550
func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context.Context {
551 552
	vars := mux.Vars(r)
	authorization := vars["authorization"]
553 554 555 556 557 558 559 560
	xff := r.Header.Get("X-Forwarded-For")
	if xff == "" {
		ipPort := strings.Split(r.RemoteAddr, ":")
		if len(ipPort) == 2 {
			xff = ipPort[0]
		}
	}
	ctx := context.WithValue(r.Context(), ContextKeyXForwardedFor, xff) // nolint:staticcheck
561

562
	if len(s.authenticatedPaths) == 0 {
563 564 565
		// handle the edge case where auth is disabled
		// but someone sends in an auth key anyway
		if authorization != "" {
566
			log.Info("blocked authenticated request against unauthenticated proxy")
567
			httpResponseCodesTotal.WithLabelValues("404").Inc()
568 569 570
			w.WriteHeader(404)
			return nil
		}
571 572 573 574 575 576
	} else {
		if authorization == "" || s.authenticatedPaths[authorization] == "" {
			log.Info("blocked unauthorized request", "authorization", authorization)
			httpResponseCodesTotal.WithLabelValues("401").Inc()
			w.WriteHeader(401)
			return nil
577
		}
578

579
		ctx = context.WithValue(ctx, ContextKeyAuth, s.authenticatedPaths[authorization]) // nolint:staticcheck
580 581
	}

582 583
	return context.WithValue(
		ctx,
584
		ContextKeyReqID, // nolint:staticcheck
585 586
		randStr(10),
	)
587 588
}

589
func (s *Server) isUnlimitedOrigin(origin string) bool {
590 591 592 593 594 595 596
	for _, pat := range s.limExemptOrigins {
		if pat.MatchString(origin) {
			return true
		}
	}

	return false
597 598 599
}

func (s *Server) isUnlimitedUserAgent(origin string) bool {
600 601 602 603 604 605
	for _, pat := range s.limExemptUserAgents {
		if pat.MatchString(origin) {
			return true
		}
	}
	return false
606 607
}

608 609 610 611
func (s *Server) isGlobalLimit(method string) bool {
	return s.globallyLimitedMethods[method]
}

612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646
func (s *Server) rateLimitSender(ctx context.Context, req *RPCReq) error {
	var params []string
	if err := json.Unmarshal(req.Params, &params); err != nil {
		log.Debug("error unmarshaling raw transaction params", "err", err, "req_Id", GetReqID(ctx))
		return ErrParseErr
	}

	if len(params) != 1 {
		log.Debug("raw transaction request has invalid number of params", "req_id", GetReqID(ctx))
		// The error below is identical to the one Geth responds with.
		return ErrInvalidParams("missing value for required argument 0")
	}

	var data hexutil.Bytes
	if err := data.UnmarshalText([]byte(params[0])); err != nil {
		log.Debug("error decoding raw tx data", "err", err, "req_id", GetReqID(ctx))
		// Geth returns the raw error from UnmarshalText.
		return ErrInvalidParams(err.Error())
	}

	// Inflates a types.Transaction object from the transaction's raw bytes.
	tx := new(types.Transaction)
	if err := tx.UnmarshalBinary(data); err != nil {
		log.Debug("could not unmarshal transaction", "err", err, "req_id", GetReqID(ctx))
		return ErrInvalidParams(err.Error())
	}

	// Convert the transaction into a Message object so that we can get the
	// sender. This method performs an ecrecover, which can be expensive.
	msg, err := tx.AsMessage(types.LatestSignerForChainID(tx.ChainId()), nil)
	if err != nil {
		log.Debug("could not get message from transaction", "err", err, "req_id", GetReqID(ctx))
		return ErrInvalidParams(err.Error())
	}

647
	ok, err := s.senderLim.Take(ctx, fmt.Sprintf("%s:%d", msg.From().Hex(), tx.Nonce()))
648 649 650 651 652 653 654 655 656 657 658 659
	if err != nil {
		log.Error("error taking from sender limiter", "err", err, "req_id", GetReqID(ctx))
		return ErrInternal
	}
	if !ok {
		log.Debug("sender rate limit exceeded", "sender", msg.From(), "req_id", GetReqID(ctx))
		return ErrOverSenderRateLimit
	}

	return nil
}

660 661 662 663 664 665 666 667
// setCacheHeader sets the cacheStatusHdr response header to "HIT" when any
// part of the response came from the cache and to "MISS" otherwise.
func setCacheHeader(w http.ResponseWriter, cached bool) {
	status := "MISS"
	if cached {
		status = "HIT"
	}
	w.Header().Set(cacheStatusHdr, status)
}

668
func writeRPCError(ctx context.Context, w http.ResponseWriter, id json.RawMessage, err error) {
669
	var res *RPCRes
670
	if r, ok := err.(*RPCErr); ok {
671
		res = NewRPCErrorRes(id, r)
672
	} else {
673
		res = NewRPCErrorRes(id, ErrInternal)
Matthew Slipper's avatar
Matthew Slipper committed
674
	}
675
	writeRPCRes(ctx, w, res)
676 677
}

678
func writeRPCRes(ctx context.Context, w http.ResponseWriter, res *RPCRes) {
679 680 681 682
	statusCode := 200
	if res.IsError() && res.Error.HTTPErrorCode != 0 {
		statusCode = res.Error.HTTPErrorCode
	}
683

Matthew Slipper's avatar
Matthew Slipper committed
684
	w.Header().Set("content-type", "application/json")
685
	w.WriteHeader(statusCode)
686 687
	ww := &recordLenWriter{Writer: w}
	enc := json.NewEncoder(ww)
688 689
	if err := enc.Encode(res); err != nil {
		log.Error("error writing rpc response", "err", err)
690 691
		RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
		return
Matthew Slipper's avatar
Matthew Slipper committed
692
	}
693
	httpResponseCodesTotal.WithLabelValues(strconv.Itoa(statusCode)).Inc()
694
	RecordResponsePayloadSize(ctx, ww.Len)
Matthew Slipper's avatar
Matthew Slipper committed
695 696
}

697
func writeBatchRPCRes(ctx context.Context, w http.ResponseWriter, res []*RPCRes) {
Matthew Slipper's avatar
Matthew Slipper committed
698
	w.Header().Set("content-type", "application/json")
699 700 701 702 703 704 705 706 707 708 709
	w.WriteHeader(200)
	ww := &recordLenWriter{Writer: w}
	enc := json.NewEncoder(ww)
	if err := enc.Encode(res); err != nil {
		log.Error("error writing batch rpc response", "err", err)
		RecordRPCError(ctx, BackendProxyd, MethodUnknown, err)
		return
	}
	RecordResponsePayloadSize(ctx, ww.Len)
}

Matthew Slipper's avatar
Matthew Slipper committed
710 711
func instrumentedHdlr(h http.Handler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
712
		respTimer := prometheus.NewTimer(httpRequestDurationSumm)
Matthew Slipper's avatar
Matthew Slipper committed
713
		h.ServeHTTP(w, r)
714
		respTimer.ObserveDuration()
Matthew Slipper's avatar
Matthew Slipper committed
715 716
	}
}
717 718 719 720 721 722 723 724 725

func GetAuthCtx(ctx context.Context) string {
	authUser, ok := ctx.Value(ContextKeyAuth).(string)
	if !ok {
		return "none"
	}

	return authUser
}
726 727 728 729 730 731 732 733

func GetReqID(ctx context.Context) string {
	reqId, ok := ctx.Value(ContextKeyReqID).(string)
	if !ok {
		return ""
	}
	return reqId
}
734 735 736 737 738 739 740 741

func GetXForwardedFor(ctx context.Context) string {
	xff, ok := ctx.Value(ContextKeyXForwardedFor).(string)
	if !ok {
		return ""
	}
	return xff
}
742 743 744 745 746 747 748 749 750 751 752

// recordLenWriter wraps an io.Writer and accumulates in Len the number of
// bytes the wrapped writer reports as written.
type recordLenWriter struct {
	io.Writer
	Len int
}

// Write forwards p to the wrapped writer and adds the reported count to Len,
// even when the underlying write also returns an error.
func (w *recordLenWriter) Write(p []byte) (int, error) {
	written, err := w.Writer.Write(p)
	w.Len += written
	return written, err
}
inphi's avatar
inphi committed
753 754 755 756 757 758 759 760 761 762

type NoopRPCCache struct{}

func (n *NoopRPCCache) GetRPC(context.Context, *RPCReq) (*RPCRes, error) {
	return nil, nil
}

func (n *NoopRPCCache) PutRPC(context.Context, *RPCReq, *RPCRes) error {
	return nil
}
763

764 765 766 767 768 769 770
func truncate(str string, maxLen int) string {
	if maxLen == 0 {
		maxLen = maxRequestBodyLogLen
	}

	if len(str) > maxLen {
		return str[:maxLen] + "..."
771 772 773 774
	} else {
		return str
	}
}
775 776 777 778 779 780 781 782 783 784 785 786 787

// batchElem pairs a parsed request with its index in the incoming request
// list, so responses can be written back in request order.
//
// Field order matters: handleBatchRPC builds it with a positional literal.
type batchElem struct {
	Req   *RPCReq
	Index int
}

// createBatchRequest returns the requests carried by elems, in the same order.
func createBatchRequest(elems []batchElem) []*RPCReq {
	reqs := make([]*RPCReq, 0, len(elems))
	for _, elem := range elems {
		reqs = append(reqs, elem.Req)
	}
	return reqs
}