diff --git a/handler/default.go b/handler/default.go index fb2f75a..2c7ad6e 100644 --- a/handler/default.go +++ b/handler/default.go @@ -35,5 +35,9 @@ func GetHandler(d *sql.DB, serverSettings *types.ServerSettings) http.Handler { authHandler.handle(router) - return middleware.Logging(middleware.ContentSecurityPolicy(middleware.EnableCors(serverSettings, router))) + return middleware.Wrapper( + router, + middleware.Log, + middleware.ContentSecurityPolicy, + middleware.Cors(serverSettings)) } diff --git a/middleware/content_security_policiy.go b/middleware/content_security_policiy.go index 11aebc5..68f55a6 100644 --- a/middleware/content_security_policiy.go +++ b/middleware/content_security_policiy.go @@ -5,6 +5,8 @@ import "net/http" func ContentSecurityPolicy(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // While this value can be overridden, it can't be moved to after the next.ServeHTTP call, + // because if the response writer get's closed, the headers can't be set anymore w.Header().Set("Content-Security-Policy", "default-src 'self' https://umami.me-fit.eu") next.ServeHTTP(w, r) diff --git a/middleware/cors.go b/middleware/cors.go index 86bb71b..28b3c6a 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -6,17 +6,18 @@ import ( "net/http" ) -func EnableCors(serverSettings *types.ServerSettings, next http.Handler) http.Handler { +func Cors(serverSettings *types.ServerSettings) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", serverSettings.BaseUrl) + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE") - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", serverSettings.BaseUrl) - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE") + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } - - next.ServeHTTP(w, r) - }) + next.ServeHTTP(w, r) + }) + } } diff --git a/middleware/logger.go b/middleware/logger.go index 6b55d99..656be5e 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -30,7 +30,7 @@ func (w *WrappedWriter) WriteHeader(code int) { w.StatusCode = code } -func Logging(next http.Handler) http.Handler { +func Log(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/middleware/wrapper.go b/middleware/wrapper.go new file mode 100644 index 0000000..80dcb20 --- /dev/null +++ b/middleware/wrapper.go @@ -0,0 +1,13 @@ +package middleware + +import "net/http" + +func Wrapper(next http.Handler, handlers ...func(http.Handler) http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lastHandler := next + for i := len(handlers) - 1; i >= 0; i-- { + lastHandler = handlers[i](lastHandler) + } + lastHandler.ServeHTTP(w, r) + }) +}