From 61fe5e64bba008bf4c862b21b22fa803552ae2a2 Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Fri, 6 Dec 2024 22:42:23 +0100 Subject: [PATCH] feat(security): #286 first try on csrf --- .../middleware/cross_site_request_forgery.go | 43 ++++++++++++++++--- log/default.go | 8 +++- main.go | 3 +- service/auth.go | 6 +-- types/{server_settings.go => settings.go} | 4 +- 5 files changed, 52 insertions(+), 12 deletions(-) rename types/{server_settings.go => settings.go} (94%) diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index cfb45e7..2edbaa9 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -1,15 +1,48 @@ package middleware import ( + "fmt" "me-fit/service" + "strings" + "net/http" ) -func CrossSiteRequestForgery(auth *service.Auth) func(http.Handler) http.Handler { +type csrfResponseWriter struct { + http.ResponseWriter + auth service.Auth + session *service.Session +} + +func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *service.Session) *csrfResponseWriter { + return &csrfResponseWriter{ + ResponseWriter: w, + auth: auth, + session: session, + } +} + + +TODO: Create session for CSRF token + +func (rr *csrfResponseWriter) Write(data []byte) (int, error) { + dataStr := string(data) + if strings.Contains(dataStr, "") { + csrfToken, err := rr.auth.GetCsrfToken(rr.session) + if err == nil { + csrfField := fmt.Sprintf(``, csrfToken) + dataStr = strings.ReplaceAll(dataStr, "", csrfField+"") + } + } + + return rr.ResponseWriter.Write([]byte(dataStr)) +} + +func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // session := r.Context().Value(SessionKey) + session := GetSession(r) if r.Method == http.MethodPost || r.Method == http.MethodPut || @@ -17,14 +50,14 @@ func CrossSiteRequestForgery(auth *service.Auth) func(http.Handler) http.Handler r.Method == http.MethodPatch { csrfToken := r.FormValue("csrf-token") - if csrfToken == "" { + if csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { http.Error(w, "", http.StatusForbidden) return } - } - next.ServeHTTP(w, r) + responseWriter := newCsrfResponseWriter(w, auth, session) + next.ServeHTTP(responseWriter, r) }) } } diff --git a/log/default.go b/log/default.go index 7d53fce..1ca6f3e 100644 --- a/log/default.go +++ b/log/default.go @@ -45,6 +45,12 @@ func Info(message string, args ...interface{}) { func format(message string, args []interface{}) string { var w strings.Builder - fmt.Fprintf(&w, message, args) + + if len(args) > 0 { + fmt.Fprintf(&w, message, args...) + } else { + w.WriteString(message) + } + return w.String() } diff --git a/main.go b/main.go index 0ad5eb0..56e6b79 100644 --- a/main.go +++ b/main.go @@ -77,7 +77,7 @@ func run(ctx context.Context, database *sql.DB, env func(string) string) { } func startServer(s *http.Server) { - log.Info("Starting server on %v", s.Addr) + log.Info("Starting server on %q", s.Addr) if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Error("error listening and serving: %v", err) } @@ -130,6 +130,7 @@ func createHandler(d *sql.DB, serverSettings *types.Settings) http.Handler { middleware.Log, middleware.ContentSecurityPolicy, middleware.Cors(serverSettings), + middleware.CrossSiteRequestForgery(authService), middleware.Corp, middleware.Coop, ) diff --git a/service/auth.go b/service/auth.go index 5f71614..8c18a08 100644 --- a/service/auth.go +++ b/service/auth.go @@ -71,7 +71,7 @@ type Auth interface { SendForgotPasswordMail(email string) error ForgotPassword(token string, newPass string) error - IsCsrfTokenValid(tokenStr string, userId uuid.UUID) bool + IsCsrfTokenValid(tokenStr string, sessionId string) bool GetCsrfToken(session *Session) (string, error) } @@ -394,14 +394,14 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { return nil } -func (service AuthImpl) IsCsrfTokenValid(tokenStr string, userId uuid.UUID) bool { +func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool { token, err := service.db.GetToken(tokenStr) if err != nil { return false } if token.Type != db.TokenTypeCsrf || - token.UserId != userId || + token.SessionId != sessionId || token.ExpiresAt.Before(service.clock.Now()) { return false diff --git a/types/server_settings.go b/types/settings.go similarity index 94% rename from types/server_settings.go rename to types/settings.go index e8d8606..b8fead1 100644 --- a/types/server_settings.go +++ b/types/settings.go @@ -77,8 +77,8 @@ func NewSettingsFromEnv(env func(string) string) *Settings { log.Fatal("SMTP and Prometheus must be enabled in production") } - log.Info("BASE_URL is %v", settings.BaseUrl) - log.Info("ENVIRONMENT is %v", settings.Environment) + log.Info("BASE_URL is %q", settings.BaseUrl) + log.Info("ENVIRONMENT is %q", settings.Environment) return settings }