package middleware import ( "net/http" "strings" "spend-sparrow/internal/log" "spend-sparrow/internal/service" "spend-sparrow/internal/types" "spend-sparrow/internal/utils" ) type csrfResponseWriter struct { http.ResponseWriter auth service.Auth session *types.Session } func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter { return &csrfResponseWriter{ ResponseWriter: w, auth: auth, session: session, } } func (rr *csrfResponseWriter) Write(data []byte) (int, error) { dataStr := string(data) csrfToken, err := rr.auth.GetCsrfToken(rr.session) if err == nil { dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken) } return rr.ResponseWriter.Write([]byte(dataStr)) } func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session := GetSession(r) if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodDelete || r.Method == http.MethodPatch { csrfToken := r.Header.Get("Csrf-Token") if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { log.Info("CSRF-Token \"%s\" not correct", csrfToken) if r.Header.Get("Hx-Request") == "true" { utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest) } else { http.Error(w, "CSRF-Token not correct", http.StatusBadRequest) } return } } responseWriter := newCsrfResponseWriter(w, auth, session) next.ServeHTTP(responseWriter, r) }) } }