diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go new file mode 100644 index 0000000..cc695a8 --- /dev/null +++ b/handler/middleware/cross_site_request_forgery.go @@ -0,0 +1,21 @@ +package middleware + +import "net/http" + +func CrossSiteRequestForgery() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" { + // Check the CSRF token + csrfToken := r.Header.Get("X-CSRF-Token") + sessionToken := r.Header.Get("X-Session-Token") + if csrfToken != sessionToken { + http.Error(w, "CSRF token mismatch", http.StatusForbidden) + return + } + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/handler/middleware/user.go b/handler/middleware/user.go new file mode 100644 index 0000000..edd720d --- /dev/null +++ b/handler/middleware/user.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "me-fit/service" + + "net/http" +) + +func UserAuth(service *service.AuthService) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check the user is logged in + sessionToken := r.Header.Get("X-Session-Token") + if sessionToken == "" { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) + } +}