Commit fe7a64f8 authored by acud's avatar acud Committed by GitHub

api: add pss CORS check (#988)

Co-authored-by: default avatarJanoš Guljaš <janos@users.noreply.github.com>
Co-authored-by: default avatarsig <dan@1up.digital>
parent 87de8da0
......@@ -13,6 +13,7 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"
"github.com/ethersphere/bee/pkg/logging"
m "github.com/ethersphere/bee/pkg/metrics"
......@@ -194,9 +195,53 @@ func (s *server) newTracingHandler(spanName string) func(h http.Handler) http.Ha
}
}
// checkOrigin returns true if the origin is not set or is equal to the request host.
func (s *server) checkOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
hosts := append(s.CORSAllowedOrigins, scheme+"://"+r.Host)
for _, v := range hosts {
if equalASCIIFold(origin[0], v) || v == "*" {
return true
}
}
return false
}
func lookaheadBufferSize(size int64) int {
if size <= largeBufferFilesizeThreshold {
return smallFileBufferSize
}
return largeFileBufferSize
}
// equalASCIIFold returns true if s is equal to t with ASCII case folding as
// defined in RFC 4790.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
......@@ -22,10 +22,6 @@ import (
)
var (
upgrader = websocket.Upgrader{
ReadBufferSize: swarm.ChunkSize,
WriteBufferSize: swarm.ChunkSize,
}
writeDeadline = 4 * time.Second // write deadline. should be smaller than the shutdown timeout on api close
targetMaxLength = 2 // max target length in bytes, in order to prevent grieving by excess computation
)
......@@ -86,6 +82,13 @@ func (s *server) pssPostHandler(w http.ResponseWriter, r *http.Request) {
}
func (s *server) pssWsHandler(w http.ResponseWriter, r *http.Request) {
upgrader := websocket.Upgrader{
ReadBufferSize: swarm.ChunkSize,
WriteBufferSize: swarm.ChunkSize,
CheckOrigin: s.checkOrigin,
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
s.Logger.Debugf("pss ws: upgrade: %v", err)
......
......@@ -170,7 +170,7 @@ func (s *server) setupRouting() {
s.pageviewMetricsHandler,
func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if o := r.Header.Get("Origin"); o != "" && (s.CORSAllowedOrigins == nil || containsOrigin(o, s.CORSAllowedOrigins)) {
if o := r.Header.Get("Origin"); o != "" && (len(s.CORSAllowedOrigins) == 0 || s.checkOrigin(r)) {
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Allow-Origin", o)
w.Header().Set("Access-Control-Allow-Headers", "Origin, Accept, Authorization, Content-Type, X-Requested-With, Access-Control-Request-Headers, Access-Control-Request-Method")
......@@ -185,15 +185,6 @@ func (s *server) setupRouting() {
)
}
func containsOrigin(s string, l []string) (ok bool) {
for _, e := range l {
if e == s || e == "*" {
return true
}
}
return false
}
func (s *server) gatewayModeForbidEndpointHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if s.GatewayMode {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment