// mock_backend_test.go: mock HTTP backends and canned JSON-RPC handlers for
// the integration tests.
package integration_tests

import (
	"bytes"
	"context"
	"encoding/json"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"sync"

	"github.com/ethereum-optimism/optimism/go/proxyd"
)

type (
	// RecordedRequest is a snapshot of a single request received by a
	// MockBackend.
	RecordedRequest struct {
		Method  string      // HTTP method of the request.
		Headers http.Header // Copy of the request headers.
		Body    []byte      // Full request body.
	}

	// MockBackend is an httptest server that records every request it
	// receives and delegates the response to a replaceable http.Handler.
	MockBackend struct {
		handler  http.Handler       // Delegate that writes the responses.
		server   *httptest.Server   // Test server that receives the requests.
		mtx      sync.RWMutex       // Guards handler and requests.
		requests []*RecordedRequest // Requests recorded so far, oldest first.
	}
)

// SingleResponseHandler returns a handler that answers every request,
// regardless of its method, path or body, with the given status code and
// response as the body.
func SingleResponseHandler(code int, response string) http.HandlerFunc {
	return http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(code)
		// The write error is deliberately ignored; this mock has no channel
		// through which to report it.
		_, _ = rw.Write([]byte(response))
	})
}

// RPCResponseHandler is an http.Handler that answers JSON-RPC requests with
// canned results keyed by RPC method name. Results can be changed while the
// handler is serving via SetResponse.
type RPCResponseHandler struct {
	mtx          sync.RWMutex           // Guards rpcResponses.
	rpcResponses map[string]interface{} // Canned result per JSON-RPC method.
}

// NewRPCResponseHandler returns a handler serving rpcResponses. The map is
// used as-is rather than copied, so SetResponse writes into the caller's map.
// A nil map is allowed; SetResponse allocates one on first use.
func NewRPCResponseHandler(rpcResponses map[string]interface{}) *RPCResponseHandler {
	return &RPCResponseHandler{
		rpcResponses: rpcResponses,
	}
}

// SetResponse sets the canned result for method. response must be a string
// or nil; any other type panics with "invalid response type" and leaves the
// handler unchanged. Writes are guarded by h.mtx.
func (h *RPCResponseHandler) SetResponse(method string, response interface{}) {
	switch response.(type) {
	case string, nil:
		// Accepted result types.
	default:
		panic("invalid response type")
	}

	h.mtx.Lock()
	defer h.mtx.Unlock()
	// Allocate lazily so a handler built from a nil map (or the zero value)
	// does not panic with "assignment to entry in nil map".
	if h.rpcResponses == nil {
		h.rpcResponses = make(map[string]interface{})
	}
	h.rpcResponses[method] = response
}

func (h *RPCResponseHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		panic(err)
	}
	req, err := proxyd.ParseRPCReq(body)
	if err != nil {
		panic(err)
	}
	h.mtx.RLock()
	res := h.rpcResponses[req.Method]
	h.mtx.RUnlock()
	if res == "" {
		w.WriteHeader(400)
		return
	}
77

78 79 80 81 82 83 84 85
	out := &proxyd.RPCRes{
		JSONRPC: proxyd.JSONRPCVersion,
		Result:  res,
		ID:      req.ID,
	}
	enc := json.NewEncoder(w)
	if err := enc.Encode(out); err != nil {
		panic(err)
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
	}
}

// NewMockBackend starts an httptest server whose requests are recorded and
// then delegated to handler. Call Close to shut the server down.
func NewMockBackend(handler http.Handler) *MockBackend {
	backend := new(MockBackend)
	backend.handler = handler
	backend.server = httptest.NewServer(http.HandlerFunc(backend.wrappedHandler))
	return backend
}

// URL returns the base URL of the underlying test server.
func (m *MockBackend) URL() string {
	srv := m.server
	return srv.URL
}

// Close shuts down the underlying test server.
func (m *MockBackend) Close() {
	srv := m.server
	srv.Close()
}

// SetHandler replaces the handler that later requests are delegated to.
func (m *MockBackend) SetHandler(handler http.Handler) {
	m.mtx.Lock()
	defer m.mtx.Unlock()
	m.handler = handler
}

// Reset discards every request recorded so far.
func (m *MockBackend) Reset() {
	m.mtx.Lock()
	defer m.mtx.Unlock()
	m.requests = nil
}

// Requests returns a snapshot of the requests recorded so far, in recording
// order. The slice is always freshly allocated (non-nil, even when empty),
// but its elements point at the same RecordedRequest values the backend
// holds.
func (m *MockBackend) Requests() []*RecordedRequest {
	m.mtx.RLock()
	snapshot := make([]*RecordedRequest, len(m.requests))
	copy(snapshot, m.requests)
	m.mtx.RUnlock()
	return snapshot
}

// wrappedHandler is the test server's entry point. It records the incoming
// request (method, a copy of the headers, and the full body). It then
// forwards a clone, whose body replays those bytes, to the current handler.
// mtx is held for the whole call, so requests to one MockBackend are recorded
// and handled one at a time.
func (m *MockBackend) wrappedHandler(w http.ResponseWriter, r *http.Request) {
	m.mtx.Lock()
	// Unlock via defer. net/http recovers panics raised by a handler, so
	// without this a panic in ReadAll or in m.handler (e.g. a handler that
	// panics on malformed input, or a nil handler) would leave mtx locked.
	// That would deadlock every later request and every call to Requests,
	// Reset and SetHandler.
	defer m.mtx.Unlock()

	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		panic(err)
	}
	// r.Body has been consumed, so the delegate gets a clone whose Body
	// replays the bytes read above.
	// NOTE(review): context.Background() detaches the clone from the client's
	// cancellation; r.Context() may have been intended — confirm before changing.
	clone := r.Clone(context.Background())
	clone.Body = ioutil.NopCloser(bytes.NewReader(body))
	m.requests = append(m.requests, &RecordedRequest{
		Method:  r.Method,
		Headers: r.Header.Clone(),
		Body:    body,
	})
	m.handler.ServeHTTP(w, clone)
}