diff --git a/handler/default.go b/handler/default.go index c643ac6..84babf6 100644 --- a/handler/default.go +++ b/handler/default.go @@ -38,6 +38,7 @@ func GetHandler(d *sql.DB, serverSettings *types.ServerSettings) http.Handler { return middleware.Wrapper( router, middleware.Log, + middleware.SecFetchFilter, middleware.ContentSecurityPolicy, middleware.Cors(serverSettings), middleware.Corp, diff --git a/middleware/sec_fetch_filter.go b/middleware/sec_fetch_filter.go new file mode 100644 index 0000000..32b1f80 --- /dev/null +++ b/middleware/sec_fetch_filter.go @@ -0,0 +1,29 @@ +package middleware + +import "net/http" + +func SecFetchFilter(next http.Handler) http.Handler { + + // A map is slower than a slice, but it's easier to check if a value exists + allowedSites := map[string]interface{}{ + "same-origin": nil, + "none": nil, + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + secFetchSite := r.Header.Get("Sec-Fetch-Site") + + if secFetchSite == "" { + next.ServeHTTP(w, r) + return + } + + _, exists := allowedSites[r.Header.Get("Sec-Fetch-Site")] + if !exists { + next.ServeHTTP(w, r) + return + } + + w.WriteHeader(http.StatusForbidden) + }) +}