package middleware import ( "fmt" "net/http" "strings" "me-fit/log" "me-fit/service" "me-fit/types" ) type csrfResponseWriter struct { http.ResponseWriter auth service.Auth session *types.Session } func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter { return &csrfResponseWriter{ ResponseWriter: w, auth: auth, session: session, } } func (rr *csrfResponseWriter) Write(data []byte) (int, error) { dataStr := string(data) csrfToken, err := rr.auth.GetCsrfToken(rr.session) if err == nil { csrfInput := fmt.Sprintf(``, csrfToken) dataStr = strings.ReplaceAll(dataStr, "", csrfInput+"") dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken) } return rr.ResponseWriter.Write([]byte(dataStr)) } func (rr *csrfResponseWriter) WriteHeader(statusCode int) { rr.ResponseWriter.WriteHeader(statusCode) } func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session := GetSession(r) if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodDelete || r.Method == http.MethodPatch { csrfToken := r.FormValue("csrf-token") if csrfToken == "" { csrfToken = r.Header.Get("csrf-token") } if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { log.Info("CSRF-Token not correct") http.Error(w, "CSRF-Token not correct", http.StatusBadRequest) return } } responseWriter := newCsrfResponseWriter(w, auth, session) next.ServeHTTP(responseWriter, r) }) } }