package middleware

import (
	"fmt"
	"net/http"
	"strings"

	"me-fit/service"
)

// csrfResponseWriter wraps an http.ResponseWriter and injects the CSRF token
// of the current session into response bodies as they are written.
type csrfResponseWriter struct {
	http.ResponseWriter
	auth    service.Auth
	session *service.Session
}

// newCsrfResponseWriter returns a csrfResponseWriter that forwards to w and
// obtains tokens for session from auth.
func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *service.Session) *csrfResponseWriter {
	return &csrfResponseWriter{
		ResponseWriter: w,
		auth:           auth,
		session:        session,
	}
}

// Write rewrites data before forwarding it to the wrapped writer. If the
// marker string occurs in data and a CSRF token can be obtained for the
// session, a field built from the token is inserted in front of every marker
// occurrence and every "CSRF_TOKEN" placeholder is replaced with the token.
// Otherwise data is forwarded unchanged.
//
// Each call is rewritten independently, so a marker or placeholder split
// across two Write calls is not recognised.
//
// The returned count refers to data, not to the (possibly longer) rewritten
// bytes: io.Writer requires 0 <= n <= len(data), and callers such as io.Copy
// report an invalid write when n exceeds it.
func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
	// NOTE(review): the marker and field-template literals below are empty in
	// this copy of the file. With an empty marker strings.Contains always
	// reports true and strings.ReplaceAll inserts the replacement before every
	// rune. Confirm the intended values (presumably a closing form tag and a
	// hidden input named "csrf-token" with a %s placeholder).
	dataStr := string(data)
	if strings.Contains(dataStr, "") {
		csrfToken, err := rr.auth.GetCsrfToken(rr.session)
		if err == nil {
			csrfField := fmt.Sprintf(``, csrfToken)
			dataStr = strings.ReplaceAll(dataStr, "", csrfField+"")
			dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken)
		}
	}

	if _, err := rr.ResponseWriter.Write([]byte(dataStr)); err != nil {
		// The rewritten bytes do not map one-to-one onto data, so a partial
		// write cannot be expressed as a count of data bytes.
		return 0, err
	}
	return len(data), nil
}

// CrossSiteRequestForgery returns middleware that rejects POST, PUT, DELETE
// and PATCH requests with 403 Forbidden unless they carry a CSRF token that
// auth.IsCsrfTokenValid accepts for the request's session. The token is read
// from the "csrf-token" form value, falling back to the "csrf-token" header.
//
// Requests that pass the check but have no session are signed in anonymously
// and receive a session cookie. If that sign-in fails the request is answered
// with 500 Internal Server Error. The next handler's output is written through
// csrfResponseWriter so the session's token can be injected into the body.
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			session := GetSession(r)

			switch r.Method {
			case http.MethodPost, http.MethodPut, http.MethodDelete, http.MethodPatch:
				csrfToken := r.FormValue("csrf-token")
				if csrfToken == "" {
					csrfToken = r.Header.Get("csrf-token")
				}
				// Without a session there is nothing to validate the token
				// against; rejecting here also avoids dereferencing nil.
				if csrfToken == "" || session == nil || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
					http.Error(w, "", http.StatusForbidden)
					return
				}
			}

			// Always make sure a session exists (signing in anonymously if
			// needed), so there is no way to forget creating a CSRF token.
			if session == nil {
				var err error
				// Plain assignment so the outer session is updated, not shadowed.
				session, err = auth.SignInAnonymous()
				if err != nil || session == nil {
					http.Error(w, "", http.StatusInternalServerError)
					return
				}
				cookie := CreateSessionCookie(session.Id)
				http.SetCookie(w, &cookie)
			}

			next.ServeHTTP(newCsrfResponseWriter(w, auth, session), r)
		})
	}
}