package middleware

import (
	"bytes"
	"log/slog"
	"net/http"

	"spend-sparrow/internal/service"
	"spend-sparrow/internal/utils"
)

// csrfTokenPlaceholder is the marker that rendered response bodies contain
// wherever the per-session CSRF token has to be inserted. Treat it as
// read-only.
var csrfTokenPlaceholder = []byte("CSRF_TOKEN")

// csrfResponseWriter wraps an http.ResponseWriter and replaces every
// occurrence of csrfTokenPlaceholder in the written body with csrfToken.
//
// Replacement is done per Write call: a placeholder that is split across two
// Write calls is not replaced. If a handler sets Content-Length itself, the
// rewritten body can no longer match it whenever the token length differs
// from the placeholder length.
type csrfResponseWriter struct {
	http.ResponseWriter
	csrfToken string
}

// newCsrfResponseWriter wraps w so that body writes get csrfToken substituted
// for the placeholder. An empty csrfToken disables substitution.
func newCsrfResponseWriter(w http.ResponseWriter, csrfToken string) *csrfResponseWriter {
	return &csrfResponseWriter{
		ResponseWriter: w,
		csrfToken:      csrfToken,
	}
}

// Write writes data to the wrapped ResponseWriter, replacing every
// placeholder with the CSRF token first.
//
// Per the io.Writer contract it reports len(data) on success, i.e. the number
// of bytes consumed from the caller, not the (possibly different) number of
// bytes sent downstream after substitution.
func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
	// Fast path: nothing to substitute, so forward the caller's slice as-is
	// and avoid any copy.
	if rr.csrfToken == "" || !bytes.Contains(data, csrfTokenPlaceholder) {
		return rr.ResponseWriter.Write(data)
	}

	replaced := bytes.ReplaceAll(data, csrfTokenPlaceholder, []byte(rr.csrfToken))
	if _, err := rr.ResponseWriter.Write(replaced); err != nil {
		// A partial write of the rewritten buffer cannot be mapped back to a
		// count of original bytes, so report that none were consumed.
		return 0, err
	}
	return len(data), nil
}

// Unwrap returns the wrapped ResponseWriter so that http.ResponseController
// can reach optional capabilities (flushing, hijacking, deadlines) of the
// underlying writer.
func (rr *csrfResponseWriter) Unwrap() http.ResponseWriter {
	return rr.ResponseWriter
}

// CrossSiteRequestForgery returns middleware that enforces CSRF protection
// and injects the session's CSRF token into responses.
//
// For POST, PUT, DELETE and PATCH requests, the "Csrf-Token" header must be
// non-empty, a session must be present, and auth.IsCsrfTokenValid must accept
// the token for that session; otherwise the request is rejected with
// 400 Bad Request and next is not called. For every request that passes, a
// token is obtained via auth.GetCsrfToken and the response writer is wrapped
// so that "CSRF_TOKEN" placeholders in the body are replaced with it.
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			session := GetSession(r)
			ctx := r.Context()

			if isStateChangingMethod(r.Method) {
				csrfToken := r.Header.Get("Csrf-Token")
				if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(ctx, csrfToken, session.Id) {
					slog.Info("CSRF-Token not correct", "token", csrfToken)
					respondCsrfError(w, r, "CSRF-Token not correct")
					return
				}
			}

			// session may be nil here for safe methods; GetCsrfToken is
			// called regardless, as in the original flow.
			token, err := auth.GetCsrfToken(ctx, session)
			if err != nil {
				slog.Error("Could not generate CSRF Token", "err", err)
				// NOTE(review): this is a server-side failure; 500 may be more
				// appropriate than 400. Kept as 400 for client compatibility.
				respondCsrfError(w, r, "Could not generate CSRF Token")
				return
			}

			next.ServeHTTP(newCsrfResponseWriter(w, token), r)
		})
	}
}

// isStateChangingMethod reports whether method is one of the HTTP methods
// that require a valid CSRF token.
func isStateChangingMethod(method string) bool {
	switch method {
	case http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch:
		return true
	default:
		return false
	}
}

// respondCsrfError rejects the request with 400 Bad Request and msg. It uses
// an error toast for htmx requests (header "Hx-Request: true") and a plain
// http.Error response otherwise.
func respondCsrfError(w http.ResponseWriter, r *http.Request, msg string) {
	if r.Header.Get("Hx-Request") == "true" {
		utils.TriggerToastWithStatus(w, r, "error", msg, http.StatusBadRequest)
		return
	}
	http.Error(w, msg, http.StatusBadRequest)
}