Commit f63c396f authored by Janoš Guljaš's avatar Janoš Guljaš Committed by GitHub

fix unconfigured cors origin (#1421)

parent 7988481f
......@@ -29,17 +29,18 @@ import (
)
type testServerOptions struct {
Storer storage.Storer
Resolver resolver.Interface
Pss pss.Interface
Traversal traversal.Service
WsPath string
Tags *tags.Tags
GatewayMode bool
WsPingPeriod time.Duration
Logger logging.Logger
PreventRedirect bool
Feeds feeds.Factory
Storer storage.Storer
Resolver resolver.Interface
Pss pss.Interface
Traversal traversal.Service
WsPath string
Tags *tags.Tags
GatewayMode bool
WsPingPeriod time.Duration
Logger logging.Logger
PreventRedirect bool
Feeds feeds.Factory
CORSAllowedOrigins []string
}
func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket.Conn, string) {
......@@ -53,8 +54,9 @@ func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket.
o.WsPingPeriod = 60 * time.Second
}
s := api.New(o.Tags, o.Storer, o.Resolver, o.Pss, o.Traversal, o.Feeds, o.Logger, nil, api.Options{
GatewayMode: o.GatewayMode,
WsPingPeriod: o.WsPingPeriod,
CORSAllowedOrigins: o.CORSAllowedOrigins,
GatewayMode: o.GatewayMode,
WsPingPeriod: o.WsPingPeriod,
})
ts := httptest.NewServer(s)
t.Cleanup(ts.Close)
......
// Copyright 2021 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package api_test
import (
"net/http"
"testing"
)
func TestCORSHeaders(t *testing.T) {
for _, tc := range []struct {
name string
origin string
allowedOrigins []string
wantCORS bool
}{
{
name: "none",
},
{
name: "no origin",
allowedOrigins: []string{"https://gateway.ethswarm.org"},
wantCORS: false,
},
{
name: "single explicit",
origin: "https://gateway.ethswarm.org",
allowedOrigins: []string{"https://gateway.ethswarm.org"},
wantCORS: true,
},
{
name: "single explicit blocked",
origin: "http://a-hacker.me",
allowedOrigins: []string{"https://gateway.ethswarm.org"},
wantCORS: false,
},
{
name: "multiple explicit",
origin: "https://staging.gateway.ethswarm.org",
allowedOrigins: []string{"https://gateway.ethswarm.org", "https://staging.gateway.ethswarm.org"},
wantCORS: true,
},
{
name: "multiple explicit blocked",
origin: "http://a-hacker.me",
allowedOrigins: []string{"https://gateway.ethswarm.org", "https://staging.gateway.ethswarm.org"},
wantCORS: false,
},
{
name: "wildcard",
origin: "http://localhost:1234",
allowedOrigins: []string{"*"},
wantCORS: true,
},
{
name: "wildcard",
origin: "https://gateway.ethswarm.org",
allowedOrigins: []string{"*"},
wantCORS: true,
},
{
name: "with origin only",
origin: "https://gateway.ethswarm.org",
allowedOrigins: nil,
wantCORS: false,
},
{
name: "with origin only not nil",
origin: "https://gateway.ethswarm.org",
allowedOrigins: []string{},
wantCORS: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
client, _, _ := newTestServer(t, testServerOptions{
CORSAllowedOrigins: tc.allowedOrigins,
})
req, err := http.NewRequest(http.MethodGet, "/", nil)
if err != nil {
t.Fatal(err)
}
if tc.origin != "" {
req.Header.Set("Origin", tc.origin)
}
r, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
got := r.Header.Get("Access-Control-Allow-Origin")
if tc.wantCORS {
if got != tc.origin {
t.Errorf("got Access-Control-Allow-Origin %q, want %q", got, tc.origin)
}
} else {
if got != "" {
t.Errorf("got Access-Control-Allow-Origin %q, want none", got)
}
}
})
}
}
......@@ -196,7 +196,7 @@ func (s *server) setupRouting() {
s.pageviewMetricsHandler,
func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if o := r.Header.Get("Origin"); o != "" && (len(s.CORSAllowedOrigins) == 0 || s.checkOrigin(r)) {
if o := r.Header.Get("Origin"); o != "" && s.checkOrigin(r) {
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Origin", o)
w.Header().Set("Access-Control-Allow-Headers", "Origin, Accept, Authorization, Content-Type, X-Requested-With, Access-Control-Request-Headers, Access-Control-Request-Method")
......
......@@ -12,7 +12,7 @@ import (
// corsHandler sets CORS headers to HTTP response if allowed origins are configured.
func (s *Service) corsHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if o := r.Header.Get("Origin"); o != "" && (len(s.corsAllowedOrigins) == 0 || checkOrigin(r, s.corsAllowedOrigins)) {
if o := r.Header.Get("Origin"); o != "" && checkOrigin(r, s.corsAllowedOrigins) {
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Origin", o)
w.Header().Set("Access-Control-Allow-Headers", "Origin, Accept, Authorization, Content-Type, X-Requested-With, Access-Control-Request-Headers, Access-Control-Request-Method")
......
......@@ -60,6 +60,18 @@ func TestCORSHeaders(t *testing.T) {
allowedOrigins: []string{"*"},
wantCORS: true,
},
{
name: "with origin only",
origin: "https://gateway.ethswarm.org",
allowedOrigins: nil,
wantCORS: false,
},
{
name: "with origin only not nil",
origin: "https://gateway.ethswarm.org",
allowedOrigins: []string{},
wantCORS: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
testServer := newTestServer(t, testServerOptions{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment