Commit e9d1e561 authored by Janoš Guljaš's avatar Janoš Guljaš Committed by GitHub

limit max http body size and detect it (#505)

parent 35509564
......@@ -51,6 +51,9 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
data, err := ioutil.ReadAll(r.Body)
if err != nil {
if jsonhttp.HandleBodyReadError(err, w) {
return
}
s.Logger.Debugf("chunk upload: read chunk data error: %v, addr %s", err, address)
s.Logger.Error("chunk upload: read chunk data error")
jsonhttp.InternalServerError(w, "cannot read chunk data")
......
......@@ -10,6 +10,7 @@ import (
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
......@@ -54,8 +55,11 @@ func (s *server) setupRouting() {
})
handle(router, "/chunks/{addr}", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.chunkGetHandler),
"POST": http.HandlerFunc(s.chunkUploadHandler),
"GET": http.HandlerFunc(s.chunkGetHandler),
"POST": web.ChainHandlers(
jsonhttp.NewMaxBodyBytesHandler(swarm.ChunkWithSpanSize),
web.FinalHandlerFunc(s.chunkUploadHandler),
),
})
handle(router, "/bzz/{address}/{path:.*}", jsonhttp.MethodHandler{
......
......@@ -19,3 +19,37 @@ func (h MethodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func NotFoundHandler(w http.ResponseWriter, _ *http.Request) {
NotFound(w, nil)
}
// NewMaxBodyBytesHandler is an http middleware constructor that limits the
// maximal number of bytes that can be read from the request body. When a body
// is read, the error can be handled with a helper function HandleBodyReadError
// in order to respond with Request Entity Too Large response.
// See TestNewMaxBodyBytesHandler as an example.
func NewMaxBodyBytesHandler(limit int64) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ContentLength > limit {
RequestEntityTooLarge(w, nil)
return
}
r.Body = http.MaxBytesReader(w, r.Body, limit)
h.ServeHTTP(w, r)
})
}
}
// HandleBodyReadError checks for particular errors and writes appropriate
// response accordingly. If no known error is found, no response is written and
// the function returns false.
func HandleBodyReadError(err error, w http.ResponseWriter) (responded bool) {
if err == nil {
return false
}
// http.MaxBytesReader returns an unexported error,
// this is the only way to detect it
if err.Error() == "http: request body too large" {
RequestEntityTooLarge(w, nil)
return true
}
return false
}
......@@ -114,3 +114,83 @@ func TestNotFoundHandler(t *testing.T) {
testContentType(t, w)
}
func TestNewMaxBodyBytesHandler(t *testing.T) {
var limit int64 = 10
h := jsonhttp.NewMaxBodyBytesHandler(limit)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := ioutil.ReadAll(r.Body)
if err != nil {
if jsonhttp.HandleBodyReadError(err, w) {
return
}
jsonhttp.InternalServerError(w, nil)
return
}
jsonhttp.OK(w, nil)
}))
for _, tc := range []struct {
name string
body string
withoutContentLength bool
wantCode int
}{
{
name: "empty",
wantCode: http.StatusOK,
},
{
name: "within limit without content length header",
body: "data",
withoutContentLength: true,
wantCode: http.StatusOK,
},
{
name: "within limit",
body: "data",
wantCode: http.StatusOK,
},
{
name: "over limit",
body: "long test data",
wantCode: http.StatusRequestEntityTooLarge,
},
{
name: "over limit without content length header",
body: "long test data",
withoutContentLength: true,
wantCode: http.StatusRequestEntityTooLarge,
},
} {
t.Run(tc.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.body))
if tc.withoutContentLength {
r.Header.Del("Content-Length")
r.ContentLength = 0
}
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
if w.Code != tc.wantCode {
t.Errorf("got http response code %d, want %d", w.Code, tc.wantCode)
}
var m *jsonhttp.StatusResponse
if err := json.Unmarshal(w.Body.Bytes(), &m); err != nil {
t.Errorf("json unmarshal response body: %s", err)
}
if m.Code != tc.wantCode {
t.Errorf("got message code %d, want %d", m.Code, tc.wantCode)
}
wantMessage := http.StatusText(tc.wantCode)
if m.Message != wantMessage {
t.Errorf("got message message %q, want %q", m.Message, wantMessage)
}
})
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment