diff --git a/handler/auth.go b/handler/auth.go index 34dcf47..a448f04 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -40,7 +40,7 @@ func (handler AuthImpl) Handle(router *http.ServeMux) { router.Handle("/auth/verify-email", handler.handleSignUpVerifyResponsePage()) router.Handle("/api/auth/signup", handler.handleSignUp()) - router.Handle("/api/auth/signout", handler.handleSignOut()) + router.Handle("POST /api/auth/signout", handler.handleSignOut()) router.Handle("/auth/delete-account", handler.handleDeleteAccountPage()) router.Handle("/api/auth/delete-account", handler.handleDeleteAccountComp()) diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 28b9648..41d18bd 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -5,6 +5,7 @@ import ( "net/http" "strings" + "me-fit/log" "me-fit/service" "me-fit/types" ) @@ -25,13 +26,11 @@ func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *ty func (rr *csrfResponseWriter) Write(data []byte) (int, error) { dataStr := string(data) - if strings.Contains(dataStr, "") { - csrfToken, err := rr.auth.GetCsrfToken(rr.session) - if err == nil { - csrfField := fmt.Sprintf(``, csrfToken) - dataStr = strings.ReplaceAll(dataStr, "", csrfField+"") - dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken) - } + csrfToken, err := rr.auth.GetCsrfToken(rr.session) + if err == nil { + csrfInput := fmt.Sprintf(``, csrfToken) + dataStr = strings.ReplaceAll(dataStr, "", csrfInput+"") + dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken) } return rr.ResponseWriter.Write([]byte(dataStr)) @@ -57,6 +56,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler csrfToken = r.Header.Get("csrf-token") } if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { + log.Info("CSRF-Token not correct") http.Error(w, "CSRF-Token not correct", http.StatusBadRequest) return } diff --git a/main_test.go b/main_test.go index 3eb8787..765efc7 100644 --- a/main_test.go +++ b/main_test.go @@ -163,6 +163,70 @@ func TestIntegrationAuth(t *testing.T) { assert.NotEqual(t, anonymousSession.Value, cookie.Value, "Session ID did not change") }) }) + t.Run("SignOut", func(t *testing.T) { + t.Run("should fail if csrf token is not valid", func(t *testing.T) { + t.Parallel() + + _, basePath, ctx := setupIntegrationTest(t) + + req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/sign-out", nil) + assert.Nil(t, err) + req.Header.Set("csrf-token", "invalid-csrf-token") + resp, err := httpClient.Do(req) + assert.Nil(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + }) + t.Run(`should delete current session and redirect to "/"`, func(t *testing.T) { + t.Parallel() + + db, basePath, ctx := setupIntegrationTest(t) + + userId := uuid.New() + sessionId := "session-id" + + pass := service.GetHashPassword("password", []byte("salt")) + _, err := db.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) + assert.Nil(t, err) + _, err = db.Exec(` + INSERT INTO session (session_id, user_id, created_at, expires_at) + VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) + assert.Nil(t, err) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/", nil) + assert.Nil(t, err) + req.Header.Set("Cookie", "id="+sessionId) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var csrfToken string + err = db.QueryRow("SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken) + assert.Nil(t, err) + + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signout", nil) + assert.Nil(t, err) + req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Cookie", "id="+sessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/", resp.Header.Get("Location")) + + cookie := findCookie(resp, "id") + assert.NotNil(t, cookie) + assert.Equal(t, "", cookie.Value) + assert.Equal(t, -1, cookie.MaxAge) + + var rows int + err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) + assert.Nil(t, err) + assert.Equal(t, 0, rows) + }) + }) t.Run("DeleteAccount", func(t *testing.T) { t.Run(`should redirect to "/" if not signed in`, func(t *testing.T) { t.Parallel() diff --git a/service/auth.go b/service/auth.go index 7a7a687..f168e11 100644 --- a/service/auth.go +++ b/service/auth.go @@ -443,12 +443,14 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) { return "", types.ErrInternal } - token := types.NewToken(uuid.Nil, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) + token := types.NewToken(session.UserId, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(8*time.Hour)) err = service.db.InsertToken(token) if err != nil { return "", types.ErrInternal } + log.Info("CSRF-Token created: %v", tokenStr) + return tokenStr, nil } diff --git a/template/auth/user.templ b/template/auth/user.templ index 315b5a7..96b4b86 100644 --- a/template/auth/user.templ +++ b/template/auth/user.templ @@ -1,38 +1,30 @@ package auth templ UserComp(user string) { -
- if user != "" { -
- - -
- } else { - Sign Up - Sign In - } +
+ if user != "" { +
+ +
+ } else { + Sign Up + Sign In + } +
} diff --git a/template/layout.templ b/template/layout.templ index 311d41f..78fe1f7 100644 --- a/template/layout.templ +++ b/template/layout.templ @@ -1,48 +1,48 @@ package template templ Layout(slot templ.Component, user templ.Component, environment string) { - - - - - ME-FIT - - - - if environment == "prod" { - - } - - - - - -
-
- - ME-FIT logo - ME-FIT - - @user -
-
- if slot != nil { - @slot - } -
-
-
- -
- - + }' /> + + + + + +
+
+ + ME-FIT logo + ME-FIT + + @user +
+
+ if slot != nil { + @slot + } +
+
+
+ +
+ + + } diff --git a/template/workout/workout.templ b/template/workout/workout.templ index 65d6e4d..35e62da 100644 --- a/template/workout/workout.templ +++ b/template/workout/workout.templ @@ -60,8 +60,7 @@ if includePlaceholder { { w.Reps }
-