Commit c4338dad authored by significance's avatar significance Committed by GitHub

Apply CORS Options for Debug Api (#1288)

parent 0dd6ddb8
......@@ -10,6 +10,7 @@ package debugapi
import (
"crypto/ecdsa"
"net/http"
"unicode/utf8"
"github.com/ethereum/go-ethereum/common"
"github.com/ethersphere/bee/pkg/accounting"
......@@ -50,10 +51,59 @@ type server struct {
Chequebook chequebook.Service
Swap swap.ApiInterface
metricsRegistry *prometheus.Registry
Options
http.Handler
}
func New(overlay swarm.Address, publicKey, pssPublicKey ecdsa.PublicKey, ethereumAddress common.Address, p2p p2p.DebugService, pingpong pingpong.Interface, topologyDriver topology.Driver, storer storage.Storer, logger logging.Logger, tracer *tracing.Tracer, tags *tags.Tags, accounting accounting.Interface, settlement settlement.Interface, chequebookEnabled bool, swap swap.ApiInterface, chequebook chequebook.Service) Service {
type Options struct {
CORSAllowedOrigins []string
}
// checkOrigin returns true if the origin is not set or is equal to the request host.
func (s *server) checkOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
hosts := append(s.CORSAllowedOrigins, scheme+"://"+r.Host)
for _, v := range hosts {
if equalASCIIFold(origin[0], v) || v == "*" {
return true
}
}
return false
}
// equalASCIIFold returns true if s is equal to t with ASCII case folding as
// defined in RFC 4790.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
func New(overlay swarm.Address, publicKey, pssPublicKey ecdsa.PublicKey, ethereumAddress common.Address, p2p p2p.DebugService, pingpong pingpong.Interface, topologyDriver topology.Driver, storer storage.Storer, logger logging.Logger, tracer *tracing.Tracer, tags *tags.Tags, accounting accounting.Interface, settlement settlement.Interface, chequebookEnabled bool, swap swap.ApiInterface, chequebook chequebook.Service, o Options) Service {
s := &server{
Overlay: overlay,
PublicKey: publicKey,
......
......@@ -57,7 +57,7 @@ func newTestServer(t *testing.T, o testServerOptions) *testServer {
settlement := swapmock.New(o.SettlementOpts...)
chequebook := chequebookmock.NewChequebook(o.ChequebookOpts...)
swapserv := swapmock.NewApiInterface(o.SwapOpts...)
s := debugapi.New(o.Overlay, o.PublicKey, o.PSSPublicKey, o.EthereumAddress, o.P2P, o.Pingpong, topologyDriver, o.Storer, logging.New(ioutil.Discard, 0), nil, o.Tags, acc, settlement, true, swapserv, chequebook)
s := debugapi.New(o.Overlay, o.PublicKey, o.PSSPublicKey, o.EthereumAddress, o.P2P, o.Pingpong, topologyDriver, o.Storer, logging.New(ioutil.Discard, 0), nil, o.Tags, acc, settlement, true, swapserv, chequebook, debugapi.Options{})
ts := httptest.NewServer(s)
t.Cleanup(ts.Close)
......
......@@ -152,6 +152,18 @@ func (s *server) setupRouting() {
baseRouter.Handle("/", web.ChainHandlers(
httpaccess.NewHTTPAccessLogHandler(s.Logger, logrus.InfoLevel, s.Tracer, "debug api access"),
handlers.CompressHandler,
func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if o := r.Header.Get("Origin"); o != "" && (len(s.CORSAllowedOrigins) == 0 || s.checkOrigin(r)) {
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Origin", o)
w.Header().Set("Access-Control-Allow-Headers", "Origin, Accept, Authorization, Content-Type, X-Requested-With, Access-Control-Request-Headers, Access-Control-Request-Method")
w.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, OPTIONS, POST, PUT, DELETE")
w.Header().Set("Access-Control-Max-Age", "3600")
}
h.ServeHTTP(w, r)
})
},
// todo: add recovery handler
web.NoCacheHeadersHandler,
web.FinalHandler(router),
......
......@@ -465,7 +465,9 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
if o.DebugAPIAddr != "" {
// Debug API server
debugAPIService := debugapi.New(swarmAddress, publicKey, pssPrivateKey.PublicKey, overlayEthAddress, p2ps, pingPong, kad, storer, logger, tracer, tagService, acc, settlement, o.SwapEnable, swapService, chequebookService)
debugAPIService := debugapi.New(swarmAddress, publicKey, pssPrivateKey.PublicKey, overlayEthAddress, p2ps, pingPong, kad, storer, logger, tracer, tagService, acc, settlement, o.SwapEnable, swapService, chequebookService, debugapi.Options{
CORSAllowedOrigins: o.CORSAllowedOrigins,
})
// register metrics from components
debugAPIService.MustRegisterMetrics(p2ps.Metrics()...)
debugAPIService.MustRegisterMetrics(pingPong.Metrics()...)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment