package middleware

import (
	"fmt"
	"net/http"
	"strings"

	"me-fit/service"
)

const (
	// csrfTokenFormField is the form field carrying the CSRF token. The
	// injected hidden input and the request check must use the same name.
	csrfTokenFormField = "csrf-token"
	// csrfFormCloseTag marks where the hidden token field is injected.
	csrfFormCloseTag = "</form>"
	// csrfFieldFormat renders the hidden input; %s receives the token.
	csrfFieldFormat = `<input type="hidden" name="` + csrfTokenFormField + `" value="%s" />`
)

// csrfResponseWriter wraps an http.ResponseWriter and injects the session's
// CSRF token as a hidden field into every HTML form written through it.
type csrfResponseWriter struct {
	http.ResponseWriter
	auth    service.Auth
	session *service.Session
}

// newCsrfResponseWriter wraps w so that forms rendered for session receive
// the token obtained from auth.
func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *service.Session) *csrfResponseWriter {
	return &csrfResponseWriter{
		ResponseWriter: w,
		auth:           auth,
		session:        session,
	}
}

// Write inserts the hidden CSRF field before every closing form tag in data
// and forwards the result to the wrapped writer.
//
// Injection only inspects a single Write call, so a closing tag split across
// two calls is not detected. If the token cannot be obtained, data is
// forwarded unchanged.
//
// On success Write reports len(data) bytes written, as io.Writer requires,
// even though the forwarded payload may be longer than data.
func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
	page := string(data)
	if !strings.Contains(page, csrfFormCloseTag) {
		return rr.ResponseWriter.Write(data)
	}

	csrfToken, err := rr.auth.GetCsrfToken(rr.session)
	if err != nil {
		// Best effort: serve the page without the field. A submission lacking
		// the token is rejected by CrossSiteRequestForgery anyway.
		return rr.ResponseWriter.Write(data)
	}

	csrfField := fmt.Sprintf(csrfFieldFormat, csrfToken)
	page = strings.ReplaceAll(page, csrfFormCloseTag, csrfField+csrfFormCloseTag)
	if _, err := rr.ResponseWriter.Write([]byte(page)); err != nil {
		// The written count refers to the rewritten payload and cannot be
		// mapped back onto data, so report nothing as consumed.
		return 0, err
	}
	return len(data), nil
}

// CrossSiteRequestForgery returns middleware that protects state-changing
// requests with a per-session CSRF token.
//
// For POST, PUT, PATCH and DELETE requests, the "csrf-token" form value must
// be present and accepted by auth.IsCsrfTokenValid for the current session.
// Otherwise the request is rejected with 403, and so is a state-changing
// request that has no session at all.
//
// A request without a session is given one via auth.SignInAnonymous; a
// failure there yields 500. The session id is then (re)issued in the "id"
// cookie, and next is served through a writer that injects the token into
// rendered forms.
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			session := GetSession(r)

			switch r.Method {
			case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
				csrfToken := r.FormValue(csrfTokenFormField)
				// Check session first: without one there is no id to validate
				// against, and dereferencing it would panic.
				if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
					http.Error(w, "", http.StatusForbidden)
					return
				}
			}

			if session == nil {
				var err error
				session, err = auth.SignInAnonymous()
				if err != nil {
					http.Error(w, "", http.StatusInternalServerError)
					return
				}
			}

			cookie := http.Cookie{
				Name:     "id",
				Value:    session.Id,
				MaxAge:   60 * 60 * 8, // 8 hours
				Secure:   true,
				HttpOnly: true,
				SameSite: http.SameSiteStrictMode,
				Path:     "/",
			}
			http.SetCookie(w, &cookie)

			responseWriter := newCsrfResponseWriter(w, auth, session)
			next.ServeHTTP(responseWriter, r)
		})
	}
}