backend.go 20.3 KB
Newer Older
Matthew Slipper's avatar
Matthew Slipper committed
1 2 3
package proxyd

import (
4
	"bytes"
5
	"context"
6
	"crypto/tls"
7 8 9 10 11 12 13
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"math/rand"
	"net/http"
14
	"sort"
15
	"strconv"
16
	"strings"
17
	"sync"
18
	"time"
19 20 21 22

	"github.com/ethereum/go-ethereum/log"
	"github.com/gorilla/websocket"
	"github.com/prometheus/client_golang/prometheus"
23
	"golang.org/x/sync/semaphore"
Matthew Slipper's avatar
Matthew Slipper committed
24 25 26
)

// JSON-RPC protocol constants used by the proxy.
const (
	// JSONRPCVersion is the JSON-RPC protocol version string.
	JSONRPCVersion = "2.0"
	// JSONRPCErrorInternal is the base error code from which the
	// proxy-specific RPCErr codes declared in this file are derived.
	JSONRPCErrorInternal = -32000
)

var (
32
	ErrParseErr = &RPCErr{
33 34 35
		Code:          -32700,
		Message:       "parse error",
		HTTPErrorCode: 400,
36 37
	}
	ErrInternal = &RPCErr{
38 39 40
		Code:          JSONRPCErrorInternal,
		Message:       "internal error",
		HTTPErrorCode: 500,
41 42
	}
	ErrMethodNotWhitelisted = &RPCErr{
43 44 45
		Code:          JSONRPCErrorInternal - 1,
		Message:       "rpc method is not whitelisted",
		HTTPErrorCode: 403,
46 47
	}
	ErrBackendOffline = &RPCErr{
48 49 50
		Code:          JSONRPCErrorInternal - 10,
		Message:       "backend offline",
		HTTPErrorCode: 503,
51 52
	}
	ErrNoBackends = &RPCErr{
53 54 55
		Code:          JSONRPCErrorInternal - 11,
		Message:       "no backends available for method",
		HTTPErrorCode: 503,
56 57
	}
	ErrBackendOverCapacity = &RPCErr{
58 59 60
		Code:          JSONRPCErrorInternal - 12,
		Message:       "backend is over capacity",
		HTTPErrorCode: 429,
61 62
	}
	ErrBackendBadResponse = &RPCErr{
63 64 65
		Code:          JSONRPCErrorInternal - 13,
		Message:       "backend returned an invalid response",
		HTTPErrorCode: 500,
66
	}
67 68 69 70
	ErrTooManyBatchRequests = &RPCErr{
		Code:    JSONRPCErrorInternal - 14,
		Message: "too many RPC calls in batch request",
	}
71 72 73 74 75
	ErrGatewayTimeout = &RPCErr{
		Code:          JSONRPCErrorInternal - 15,
		Message:       "gateway timeout",
		HTTPErrorCode: 504,
	}
76 77
	ErrOverRateLimit = &RPCErr{
		Code:          JSONRPCErrorInternal - 16,
78
		Message:       "over rate limit",
79 80
		HTTPErrorCode: 429,
	}
81 82

	ErrBackendUnexpectedJSONRPC = errors.New("backend returned an unexpected JSON-RPC response")
Matthew Slipper's avatar
Matthew Slipper committed
83 84
)

85 86 87 88 89 90 91 92
// ErrInvalidRequest builds an RPCErr describing an invalid JSON-RPC request.
//
// msg becomes the error message. The error uses the JSON-RPC 2.0
// "Invalid Request" code (-32600); -32601 is reserved by the specification
// for "Method not found". The HTTPErrorCode is 400.
func ErrInvalidRequest(msg string) *RPCErr {
	return &RPCErr{
		Code:          -32600,
		Message:       msg,
		HTTPErrorCode: 400,
	}
}

Matthew Slipper's avatar
Matthew Slipper committed
93
type Backend struct {
94 95 96 97 98
	Name                 string
	rpcURL               string
	wsURL                string
	authUsername         string
	authPassword         string
99
	rateLimiter          BackendRateLimiter
100
	client               *LimitedHTTPClient
101 102 103 104 105 106
	dialer               *websocket.Dialer
	maxRetries           int
	maxResponseSize      int64
	maxRPS               int
	maxWSConns           int
	outOfServiceInterval time.Duration
107
	stripTrailingXFF     bool
inphi's avatar
inphi committed
108
	proxydIP             string
Matthew Slipper's avatar
Matthew Slipper committed
109 110 111 112 113
}

// BackendOpt is a functional option that NewBackend applies to the Backend
// it constructs.
type BackendOpt func(b *Backend)

func WithBasicAuth(username, password string) BackendOpt {
114 115 116 117
	return func(b *Backend) {
		b.authUsername = username
		b.authPassword = password
	}
Matthew Slipper's avatar
Matthew Slipper committed
118 119 120
}

func WithTimeout(timeout time.Duration) BackendOpt {
121 122 123
	return func(b *Backend) {
		b.client.Timeout = timeout
	}
Matthew Slipper's avatar
Matthew Slipper committed
124 125 126
}

func WithMaxRetries(retries int) BackendOpt {
127 128 129
	return func(b *Backend) {
		b.maxRetries = retries
	}
Matthew Slipper's avatar
Matthew Slipper committed
130 131 132
}

func WithMaxResponseSize(size int64) BackendOpt {
133 134 135
	return func(b *Backend) {
		b.maxResponseSize = size
	}
Matthew Slipper's avatar
Matthew Slipper committed
136 137
}

138
func WithOutOfServiceDuration(interval time.Duration) BackendOpt {
139
	return func(b *Backend) {
140
		b.outOfServiceInterval = interval
141
	}
Matthew Slipper's avatar
Matthew Slipper committed
142 143
}

144 145 146 147 148 149 150 151 152 153 154 155
// WithMaxRPS returns an option that sets the backend's request rate limit.
func WithMaxRPS(maxRPS int) BackendOpt {
	setMaxRPS := func(backend *Backend) {
		backend.maxRPS = maxRPS
	}
	return setMaxRPS
}

// WithMaxWSConns returns an option that sets the backend's WebSocket
// connection limit.
func WithMaxWSConns(maxConns int) BackendOpt {
	setMaxWSConns := func(backend *Backend) {
		backend.maxWSConns = maxConns
	}
	return setMaxWSConns
}

156 157 158 159 160 161 162 163 164
// WithTLSConfig returns an option that installs tlsConfig on the backend's
// HTTP transport, creating an *http.Transport first when the client has
// none. The option panics if the client already has a transport that is
// not an *http.Transport.
func WithTLSConfig(tlsConfig *tls.Config) BackendOpt {
	return func(backend *Backend) {
		roundTripper := backend.client.Transport
		if roundTripper == nil {
			roundTripper = &http.Transport{}
			backend.client.Transport = roundTripper
		}
		roundTripper.(*http.Transport).TLSClientConfig = tlsConfig
	}
}

165 166 167 168 169 170
// WithStrippedTrailingXFF returns an option that enables stripping of the
// X-Forwarded-For header on forwarded requests.
func WithStrippedTrailingXFF() BackendOpt {
	enableStrip := func(backend *Backend) {
		backend.stripTrailingXFF = true
	}
	return enableStrip
}

inphi's avatar
inphi committed
171 172 173 174 175 176
// WithProxydIP returns an option that records the proxy's own IP address
// on the backend.
func WithProxydIP(ip string) BackendOpt {
	setIP := func(backend *Backend) {
		backend.proxydIP = ip
	}
	return setIP
}

177 178 179 180
func NewBackend(
	name string,
	rpcURL string,
	wsURL string,
181
	rateLimiter BackendRateLimiter,
182
	rpcSemaphore *semaphore.Weighted,
183 184
	opts ...BackendOpt,
) *Backend {
185
	backend := &Backend{
186 187 188
		Name:            name,
		rpcURL:          rpcURL,
		wsURL:           wsURL,
189
		rateLimiter:     rateLimiter,
190
		maxResponseSize: math.MaxInt64,
191 192 193 194
		client: &LimitedHTTPClient{
			Client:      http.Client{Timeout: 5 * time.Second},
			sem:         rpcSemaphore,
			backendName: name,
195
		},
196
		dialer: &websocket.Dialer{},
197 198 199 200 201 202
	}

	for _, opt := range opts {
		opt(backend)
	}

inphi's avatar
inphi committed
203 204 205 206
	if !backend.stripTrailingXFF && backend.proxydIP == "" {
		log.Warn("proxied requests' XFF header will not contain the proxyd ip address")
	}

207
	return backend
Matthew Slipper's avatar
Matthew Slipper committed
208 209
}

210
func (b *Backend) Forward(ctx context.Context, reqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
211
	if !b.Online() {
212
		RecordBatchRPCError(ctx, b.Name, reqs, ErrBackendOffline)
213 214
		return nil, ErrBackendOffline
	}
215
	if b.IsRateLimited() {
216
		RecordBatchRPCError(ctx, b.Name, reqs, ErrBackendOverCapacity)
217 218
		return nil, ErrBackendOverCapacity
	}
219 220 221 222 223

	var lastError error
	// <= to account for the first attempt not technically being
	// a retry
	for i := 0; i <= b.maxRetries; i++ {
224 225 226 227 228 229 230 231 232 233 234 235 236 237
		RecordBatchRPCForward(ctx, b.Name, reqs, RPCRequestSourceHTTP)
		metricLabelMethod := reqs[0].Method
		if isBatch {
			metricLabelMethod = "<batch>"
		}
		timer := prometheus.NewTimer(
			rpcBackendRequestDurationSumm.WithLabelValues(
				b.Name,
				metricLabelMethod,
				strconv.FormatBool(isBatch),
			),
		)

		res, err := b.doForward(ctx, reqs, isBatch)
238 239 240 241 242 243 244 245 246 247 248 249 250 251
		switch err {
		case nil: // do nothing
		// ErrBackendUnexpectedJSONRPC occurs because infura responds with a single JSON-RPC object
		// to a batch request whenever any Request Object in the batch would induce a partial error.
		// We don't label the the backend offline in this case. But the error is still returned to
		// callers so failover can occur if needed.
		case ErrBackendUnexpectedJSONRPC:
			log.Debug(
				"Reecived unexpected JSON-RPC response",
				"name", b.Name,
				"req_id", GetReqID(ctx),
				"err", err,
			)
		default:
252
			lastError = err
253 254 255 256 257 258
			log.Warn(
				"backend request failed, trying again",
				"name", b.Name,
				"req_id", GetReqID(ctx),
				"err", err,
			)
259 260
			timer.ObserveDuration()
			RecordBatchRPCError(ctx, b.Name, reqs, err)
261
			sleepContext(ctx, calcBackoff(i))
262 263
			continue
		}
264 265 266
		timer.ObserveDuration()

		MaybeRecordErrorsInRPCRes(ctx, b.Name, reqs, res)
267
		return res, err
268 269
	}

270
	b.setOffline()
271
	return nil, wrapErr(lastError, "permanent error forwarding request")
Matthew Slipper's avatar
Matthew Slipper committed
272 273
}

274
func (b *Backend) ProxyWS(clientConn *websocket.Conn, methodWhitelist *StringSet) (*WSProxier, error) {
275 276 277 278 279 280 281
	if !b.Online() {
		return nil, ErrBackendOffline
	}
	if b.IsWSSaturated() {
		return nil, ErrBackendOverCapacity
	}

282
	backendConn, _, err := b.dialer.Dial(b.wsURL, nil) // nolint:bodyclose
283 284
	if err != nil {
		b.setOffline()
285
		if err := b.rateLimiter.DecBackendWSConns(b.Name); err != nil {
286 287 288 289 290 291
			log.Error("error decrementing backend ws conns", "name", b.Name, "err", err)
		}
		return nil, wrapErr(err, "error dialing backend")
	}

	activeBackendWsConnsGauge.WithLabelValues(b.Name).Inc()
292
	return NewWSProxier(b, clientConn, backendConn, methodWhitelist), nil
293 294 295
}

func (b *Backend) Online() bool {
296
	online, err := b.rateLimiter.IsBackendOnline(b.Name)
297
	if err != nil {
298 299 300 301 302 303 304 305 306 307 308 309 310
		log.Warn(
			"error getting backend availability, assuming it is offline",
			"name", b.Name,
			"err", err,
		)
		return false
	}
	return online
}

func (b *Backend) IsRateLimited() bool {
	if b.maxRPS == 0 {
		return false
311 312
	}

313
	usedLimit, err := b.rateLimiter.IncBackendRPS(b.Name)
314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330
	if err != nil {
		log.Error(
			"error getting backend used rate limit, assuming limit is exhausted",
			"name", b.Name,
			"err", err,
		)
		return true
	}

	return b.maxRPS < usedLimit
}

func (b *Backend) IsWSSaturated() bool {
	if b.maxWSConns == 0 {
		return false
	}

331
	incremented, err := b.rateLimiter.IncBackendWSConns(b.Name, b.maxWSConns)
332 333 334 335 336 337 338 339 340 341 342 343 344
	if err != nil {
		log.Error(
			"error getting backend used ws conns, assuming limit is exhausted",
			"name", b.Name,
			"err", err,
		)
		return true
	}

	return !incremented
}

func (b *Backend) setOffline() {
345
	err := b.rateLimiter.SetBackendOffline(b.Name, b.outOfServiceInterval)
346 347 348 349 350 351 352 353 354
	if err != nil {
		log.Warn(
			"error setting backend offline",
			"name", b.Name,
			"err", err,
		)
	}
}

355
func (b *Backend) doForward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
356 357 358 359 360 361 362 363 364 365 366
	isSingleElementBatch := len(rpcReqs) == 1

	// Single element batches are unwrapped before being sent
	// since Alchemy handles single requests better than batches.

	var body []byte
	if isSingleElementBatch {
		body = mustMarshalJSON(rpcReqs[0])
	} else {
		body = mustMarshalJSON(rpcReqs)
	}
367

368
	httpReq, err := http.NewRequestWithContext(ctx, "POST", b.rpcURL, bytes.NewReader(body))
369 370 371 372 373 374 375 376
	if err != nil {
		return nil, wrapErr(err, "error creating backend request")
	}

	if b.authPassword != "" {
		httpReq.SetBasicAuth(b.authUsername, b.authPassword)
	}

377 378
	xForwardedFor := GetXForwardedFor(ctx)
	if b.stripTrailingXFF {
379
		xForwardedFor = stripXFF(xForwardedFor)
inphi's avatar
inphi committed
380 381
	} else if b.proxydIP != "" {
		xForwardedFor = fmt.Sprintf("%s, %s", xForwardedFor, b.proxydIP)
382 383
	}

384
	httpReq.Header.Set("content-type", "application/json")
385
	httpReq.Header.Set("X-Forwarded-For", xForwardedFor)
386

387
	httpRes, err := b.client.DoLimited(httpReq)
388 389 390 391
	if err != nil {
		return nil, wrapErr(err, "error in backend request")
	}

392 393 394 395
	metricLabelMethod := rpcReqs[0].Method
	if isBatch {
		metricLabelMethod = "<batch>"
	}
396 397 398
	rpcBackendHTTPResponseCodesTotal.WithLabelValues(
		GetAuthCtx(ctx),
		b.Name,
399
		metricLabelMethod,
400
		strconv.Itoa(httpRes.StatusCode),
401
		strconv.FormatBool(isBatch),
402 403
	).Inc()

404 405 406
	// Alchemy returns a 400 on bad JSONs, so handle that case
	if httpRes.StatusCode != 200 && httpRes.StatusCode != 400 {
		return nil, fmt.Errorf("response code %d", httpRes.StatusCode)
407 408
	}

409
	defer httpRes.Body.Close()
410
	resB, err := io.ReadAll(io.LimitReader(httpRes.Body, b.maxResponseSize))
411 412 413 414
	if err != nil {
		return nil, wrapErr(err, "error reading response body")
	}

415
	var res []*RPCRes
416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
	if isSingleElementBatch {
		var singleRes RPCRes
		if err := json.Unmarshal(resB, &singleRes); err != nil {
			return nil, ErrBackendBadResponse
		}
		res = []*RPCRes{
			&singleRes,
		}
	} else {
		if err := json.Unmarshal(resB, &res); err != nil {
			// Infura may return a single JSON-RPC response if, for example, the batch contains a request for an unsupported method
			if responseIsNotBatched(resB) {
				return nil, ErrBackendUnexpectedJSONRPC
			}
			return nil, ErrBackendBadResponse
431
		}
432 433 434
	}

	if len(rpcReqs) != len(res) {
435
		return nil, ErrBackendUnexpectedJSONRPC
436 437
	}

438 439 440
	// capture the HTTP status code in the response. this will only
	// ever be 400 given the status check on line 318 above.
	if httpRes.StatusCode != 200 {
441 442 443
		for _, res := range res {
			res.Error.HTTPErrorCode = httpRes.StatusCode
		}
444 445
	}

446
	sortBatchRPCResponse(rpcReqs, res)
447
	return res, nil
Matthew Slipper's avatar
Matthew Slipper committed
448 449
}

450 451 452 453 454
// responseIsNotBatched reports whether b decodes as a single RPCRes rather
// than a batch.
func responseIsNotBatched(b []byte) bool {
	var single RPCRes
	if decodeErr := json.Unmarshal(b, &single); decodeErr != nil {
		return false
	}
	return true
}

455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472
// sortBatchRPCResponse reorders res in place so that each response sits at
// the position its ID has in req. It panics when two requests share an ID.
func sortBatchRPCResponse(req []*RPCReq, res []*RPCRes) {
	order := make(map[string]int, len(req))
	for idx := range req {
		id := string(req[idx].ID)
		if _, seen := order[id]; seen {
			panic("bug! detected requests with duplicate IDs")
		}
		order[id] = idx
	}

	byRequestOrder := func(a, b int) bool {
		return order[string(res[a].ID)] < order[string(res[b].ID)]
	}
	sort.Slice(res, byRequestOrder)
}

Matthew Slipper's avatar
Matthew Slipper committed
473
type BackendGroup struct {
474
	Name     string
475
	Backends []*Backend
Matthew Slipper's avatar
Matthew Slipper committed
476 477
}

478 479 480 481 482
func (b *BackendGroup) Forward(ctx context.Context, rpcReqs []*RPCReq, isBatch bool) ([]*RPCRes, error) {
	if len(rpcReqs) == 0 {
		return nil, nil
	}

483 484 485
	rpcRequestsTotal.Inc()

	for _, back := range b.Backends {
486
		res, err := back.Forward(ctx, rpcReqs, isBatch)
487 488 489 490
		if errors.Is(err, ErrMethodNotWhitelisted) {
			return nil, err
		}
		if errors.Is(err, ErrBackendOffline) {
491 492 493 494 495 496
			log.Warn(
				"skipping offline backend",
				"name", back.Name,
				"auth", GetAuthCtx(ctx),
				"req_id", GetReqID(ctx),
			)
497 498
			continue
		}
499
		if errors.Is(err, ErrBackendOverCapacity) {
500 501 502 503 504 505
			log.Warn(
				"skipping over-capacity backend",
				"name", back.Name,
				"auth", GetAuthCtx(ctx),
				"req_id", GetReqID(ctx),
			)
506 507
			continue
		}
508
		if err != nil {
509 510
			log.Error(
				"error forwarding request to backend",
511
				"name", back.Name,
512 513 514 515
				"req_id", GetReqID(ctx),
				"auth", GetAuthCtx(ctx),
				"err", err,
			)
516 517
			continue
		}
518
		return res, nil
519 520
	}

521
	RecordUnserviceableRequest(ctx, RPCRequestSourceHTTP)
522 523 524
	return nil, ErrNoBackends
}

525
func (b *BackendGroup) ProxyWS(ctx context.Context, clientConn *websocket.Conn, methodWhitelist *StringSet) (*WSProxier, error) {
526
	for _, back := range b.Backends {
527
		proxier, err := back.ProxyWS(clientConn, methodWhitelist)
528
		if errors.Is(err, ErrBackendOffline) {
529 530 531 532 533 534
			log.Warn(
				"skipping offline backend",
				"name", back.Name,
				"req_id", GetReqID(ctx),
				"auth", GetAuthCtx(ctx),
			)
535 536 537
			continue
		}
		if errors.Is(err, ErrBackendOverCapacity) {
538 539 540 541 542 543
			log.Warn(
				"skipping over-capacity backend",
				"name", back.Name,
				"req_id", GetReqID(ctx),
				"auth", GetAuthCtx(ctx),
			)
544 545 546
			continue
		}
		if err != nil {
547 548 549 550 551 552 553
			log.Warn(
				"error dialing ws backend",
				"name", back.Name,
				"req_id", GetReqID(ctx),
				"auth", GetAuthCtx(ctx),
				"err", err,
			)
554 555 556
			continue
		}
		return proxier, nil
557 558
	}

559
	return nil, ErrNoBackends
Matthew Slipper's avatar
Matthew Slipper committed
560 561
}

562 563
// calcBackoff returns the delay to wait after failed attempt i (zero-based).
//
// The delay is 2^i seconds plus a random jitter of 0-249 milliseconds,
// capped at 3 seconds.
func calcBackoff(i int) time.Duration {
	jitter := float64(rand.Int63n(250))
	ms := math.Min(math.Pow(2, float64(i))*1000+jitter, 3000)
	return time.Duration(ms) * time.Millisecond
}

568
type WSProxier struct {
569 570 571 572
	backend         *Backend
	clientConn      *websocket.Conn
	backendConn     *websocket.Conn
	methodWhitelist *StringSet
573
	clientConnMu    sync.Mutex
Matthew Slipper's avatar
Matthew Slipper committed
574 575
}

576
func NewWSProxier(backend *Backend, clientConn, backendConn *websocket.Conn, methodWhitelist *StringSet) *WSProxier {
577
	return &WSProxier{
578 579 580 581
		backend:         backend,
		clientConn:      clientConn,
		backendConn:     backendConn,
		methodWhitelist: methodWhitelist,
582
	}
Matthew Slipper's avatar
Matthew Slipper committed
583 584
}

585
func (w *WSProxier) Proxy(ctx context.Context) error {
586
	errC := make(chan error, 2)
587 588
	go w.clientPump(ctx, errC)
	go w.backendPump(ctx, errC)
589 590 591 592 593
	err := <-errC
	w.close()
	return err
}

594
func (w *WSProxier) clientPump(ctx context.Context, errC chan error) {
595 596 597 598 599
	for {
		// Block until we get a message.
		msgType, msg, err := w.clientConn.ReadMessage()
		if err != nil {
			errC <- err
600
			if err := w.backendConn.WriteMessage(websocket.CloseMessage, formatWSError(err)); err != nil {
601 602
				log.Error("error writing backendConn message", "err", err)
			}
603 604 605
			return
		}

606
		RecordWSMessage(ctx, w.backend.Name, SourceClient)
607 608 609 610

		// Route control messages to the backend. These don't
		// count towards the total RPC requests count.
		if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
611
			err := w.backendConn.WriteMessage(msgType, msg)
612 613 614 615 616 617 618 619 620 621 622
			if err != nil {
				errC <- err
				return
			}
			continue
		}

		rpcRequestsTotal.Inc()

		// Don't bother sending invalid requests to the backend,
		// just handle them here.
623
		req, err := w.prepareClientMsg(msg)
624
		if err != nil {
625
			var id json.RawMessage
626
			method := MethodUnknown
627 628
			if req != nil {
				id = req.ID
629
				method = req.Method
630
			}
631 632 633 634 635 636
			log.Info(
				"error preparing client message",
				"auth", GetAuthCtx(ctx),
				"req_id", GetReqID(ctx),
				"err", err,
			)
637
			msg = mustMarshalJSON(NewRPCErrorRes(id, err))
638
			RecordRPCError(ctx, BackendProxyd, method, err)
639 640 641 642 643 644 645 646

			// Send error response to client
			err = w.writeClientConn(msgType, msg)
			if err != nil {
				errC <- err
				return
			}
			continue
647 648
		}

649 650 651 652 653 654 655 656 657 658 659 660
		// Send eth_accounts requests directly to the client
		if req.Method == "eth_accounts" {
			msg = mustMarshalJSON(NewRPCRes(req.ID, emptyArrayResponse))
			RecordRPCForward(ctx, BackendProxyd, "eth_accounts", RPCRequestSourceWS)
			err = w.writeClientConn(msgType, msg)
			if err != nil {
				errC <- err
				return
			}
			continue
		}

661 662 663 664 665 666 667 668 669
		RecordRPCForward(ctx, w.backend.Name, req.Method, RPCRequestSourceWS)
		log.Info(
			"forwarded WS message to backend",
			"method", req.Method,
			"auth", GetAuthCtx(ctx),
			"req_id", GetReqID(ctx),
		)

		err = w.backendConn.WriteMessage(msgType, msg)
670 671 672 673 674 675 676
		if err != nil {
			errC <- err
			return
		}
	}
}

677
func (w *WSProxier) backendPump(ctx context.Context, errC chan error) {
678 679 680 681 682
	for {
		// Block until we get a message.
		msgType, msg, err := w.backendConn.ReadMessage()
		if err != nil {
			errC <- err
683
			if err := w.writeClientConn(websocket.CloseMessage, formatWSError(err)); err != nil {
684 685
				log.Error("error writing clientConn message", "err", err)
			}
686 687 688
			return
		}

689
		RecordWSMessage(ctx, w.backend.Name, SourceBackend)
690 691 692

		// Route control messages directly to the client.
		if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
693
			err := w.writeClientConn(msgType, msg)
694 695 696 697 698 699 700 701 702
			if err != nil {
				errC <- err
				return
			}
			continue
		}

		res, err := w.parseBackendMsg(msg)
		if err != nil {
703
			var id json.RawMessage
704 705 706 707
			if res != nil {
				id = res.ID
			}
			msg = mustMarshalJSON(NewRPCErrorRes(id, err))
708
			log.Info("backend responded with error", "err", err)
709
		} else {
710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726
			if res.IsError() {
				log.Info(
					"backend responded with RPC error",
					"code", res.Error.Code,
					"msg", res.Error.Message,
					"source", "ws",
					"auth", GetAuthCtx(ctx),
					"req_id", GetReqID(ctx),
				)
				RecordRPCError(ctx, w.backend.Name, MethodUnknown, res.Error)
			} else {
				log.Info(
					"forwarded WS message to client",
					"auth", GetAuthCtx(ctx),
					"req_id", GetReqID(ctx),
				)
			}
727 728
		}

729
		err = w.writeClientConn(msgType, msg)
730 731 732 733 734 735 736 737 738 739
		if err != nil {
			errC <- err
			return
		}
	}
}

func (w *WSProxier) close() {
	w.clientConn.Close()
	w.backendConn.Close()
740
	if err := w.backend.rateLimiter.DecBackendWSConns(w.backend.Name); err != nil {
741 742 743 744 745
		log.Error("error decrementing backend ws conns", "name", w.backend.Name, "err", err)
	}
	activeBackendWsConnsGauge.WithLabelValues(w.backend.Name).Dec()
}

746
func (w *WSProxier) prepareClientMsg(msg []byte) (*RPCReq, error) {
747
	req, err := ParseRPCReq(msg)
748 749 750 751
	if err != nil {
		return nil, err
	}

752
	if !w.methodWhitelist.Has(req.Method) {
753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771
		return req, ErrMethodNotWhitelisted
	}

	if w.backend.IsRateLimited() {
		return req, ErrBackendOverCapacity
	}

	return req, nil
}

// parseBackendMsg decodes a backend message as an RPC response. A parse
// failure is logged and reported as ErrBackendBadResponse, alongside
// whatever value the parser returned.
func (w *WSProxier) parseBackendMsg(msg []byte) (*RPCRes, error) {
	parsed, parseErr := ParseRPCRes(bytes.NewReader(msg))
	if parseErr == nil {
		return parsed, nil
	}

	log.Warn("error parsing RPC response", "source", "ws", "err", parseErr)
	return parsed, ErrBackendBadResponse
}

772 773 774 775 776 777 778
// writeClientConn writes msg to the client connection while holding
// clientConnMu, so that writes from both pumps do not interleave.
func (w *WSProxier) writeClientConn(msgType int, msg []byte) error {
	w.clientConnMu.Lock()
	defer w.clientConnMu.Unlock()

	return w.clientConn.WriteMessage(msgType, msg)
}

779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794
// mustMarshalJSON encodes in as JSON and panics if encoding fails.
func mustMarshalJSON(in interface{}) []byte {
	encoded, marshalErr := json.Marshal(in)
	if marshalErr == nil {
		return encoded
	}
	panic(marshalErr)
}

func formatWSError(err error) []byte {
	m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
	if e, ok := err.(*websocket.CloseError); ok {
		if e.Code != websocket.CloseNoStatusReceived {
			m = websocket.FormatCloseMessage(e.Code, e.Text)
		}
	}
	return m
Matthew Slipper's avatar
Matthew Slipper committed
795
}
796 797 798 799 800 801 802

func sleepContext(ctx context.Context, duration time.Duration) {
	select {
	case <-ctx.Done():
	case <-time.After(duration):
	}
}
803 804 805 806 807 808 809 810 811 812 813 814 815 816 817

// LimitedHTTPClient is an http.Client whose requests made through DoLimited
// are bounded by a weighted semaphore.
type LimitedHTTPClient struct {
	http.Client
	// sem is acquired for the duration of each DoLimited call.
	sem         *semaphore.Weighted
	// backendName labels the metric incremented when sem cannot be acquired.
	backendName string
}

// DoLimited sends req while holding one unit of the client's semaphore.
// If the semaphore cannot be acquired under the request's context, the
// too-many-requests metric is incremented and a wrapped error is returned.
func (c *LimitedHTTPClient) DoLimited(req *http.Request) (*http.Response, error) {
	acquireErr := c.sem.Acquire(req.Context(), 1)
	if acquireErr != nil {
		tooManyRequestErrorsTotal.WithLabelValues(c.backendName).Inc()
		return nil, wrapErr(acquireErr, "too many requests")
	}
	defer c.sem.Release(1)

	return c.Do(req)
}
818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858

// RecordBatchRPCError records err once for every request in reqs, using
// each request's method.
func RecordBatchRPCError(ctx context.Context, backendName string, reqs []*RPCReq, err error) {
	for i := range reqs {
		RecordRPCError(ctx, backendName, reqs[i].Method, err)
	}
}

// MaybeRecordErrorsInRPCRes logs that a request was forwarded, records an
// RPC error for every error response in resBatch (using the method of the
// request at the same index), and logs the last such error, if any.
func MaybeRecordErrorsInRPCRes(ctx context.Context, backendName string, reqs []*RPCReq, resBatch []*RPCRes) {
	log.Info("forwarded RPC request",
		"backend", backendName,
		"auth", GetAuthCtx(ctx),
		"req_id", GetReqID(ctx),
		"batch_size", len(reqs),
	)

	var mostRecent *RPCErr
	for idx := range resBatch {
		entry := resBatch[idx]
		if !entry.IsError() {
			continue
		}
		mostRecent = entry.Error
		RecordRPCError(ctx, backendName, reqs[idx].Method, entry.Error)
	}

	if mostRecent == nil {
		return
	}

	log.Info(
		"backend responded with RPC error",
		"backend", backendName,
		"last_error_code", mostRecent.Code,
		"last_error_msg", mostRecent.Message,
		"req_id", GetReqID(ctx),
		"source", "rpc",
		"auth", GetAuthCtx(ctx),
	)
}

// RecordBatchRPCForward records a forward from source once for every
// request in reqs, using each request's method.
func RecordBatchRPCForward(ctx context.Context, backendName string, reqs []*RPCReq, source string) {
	for i := range reqs {
		RecordRPCForward(ctx, backendName, reqs[i].Method, source)
	}
}
859 860 861 862 863

// stripXFF returns the first entry of an X-Forwarded-For list with
// surrounding whitespace removed.
//
// Entries are separated by a comma; whitespace after the comma is optional,
// so both "a, b" and "a,b" yield "a". An empty input yields "".
func stripXFF(xff string) string {
	if i := strings.IndexByte(xff, ','); i >= 0 {
		xff = xff[:i]
	}
	return strings.TrimSpace(xff)
}