chore(auth): #331 add tests for sign out

This commit is contained in:
2024-12-22 23:40:09 +01:00
parent fb6cc0acda
commit 52cd85d904
7 changed files with 142 additions and 85 deletions

View File

@@ -40,7 +40,7 @@ func (handler AuthImpl) Handle(router *http.ServeMux) {
router.Handle("/auth/verify-email", handler.handleSignUpVerifyResponsePage()) router.Handle("/auth/verify-email", handler.handleSignUpVerifyResponsePage())
router.Handle("/api/auth/signup", handler.handleSignUp()) router.Handle("/api/auth/signup", handler.handleSignUp())
router.Handle("/api/auth/signout", handler.handleSignOut()) router.Handle("POST /api/auth/signout", handler.handleSignOut())
router.Handle("/auth/delete-account", handler.handleDeleteAccountPage()) router.Handle("/auth/delete-account", handler.handleDeleteAccountPage())
router.Handle("/api/auth/delete-account", handler.handleDeleteAccountComp()) router.Handle("/api/auth/delete-account", handler.handleDeleteAccountComp())

View File

@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"me-fit/log"
"me-fit/service" "me-fit/service"
"me-fit/types" "me-fit/types"
) )
@@ -25,14 +26,12 @@ func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *ty
func (rr *csrfResponseWriter) Write(data []byte) (int, error) { func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
dataStr := string(data) dataStr := string(data)
if strings.Contains(dataStr, "</form>") {
csrfToken, err := rr.auth.GetCsrfToken(rr.session) csrfToken, err := rr.auth.GetCsrfToken(rr.session)
if err == nil { if err == nil {
csrfField := fmt.Sprintf(`<input type="hidden" name="csrf-token" value="%s">`, csrfToken) csrfInput := fmt.Sprintf(`<input type="hidden" name="csrf-token" value="%s" />`, csrfToken)
dataStr = strings.ReplaceAll(dataStr, "</form>", csrfField+"</form>") dataStr = strings.ReplaceAll(dataStr, "</form>", csrfInput+"</form>")
dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken) dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken)
} }
}
return rr.ResponseWriter.Write([]byte(dataStr)) return rr.ResponseWriter.Write([]byte(dataStr))
} }
@@ -57,6 +56,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
csrfToken = r.Header.Get("csrf-token") csrfToken = r.Header.Get("csrf-token")
} }
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
log.Info("CSRF-Token not correct")
http.Error(w, "CSRF-Token not correct", http.StatusBadRequest) http.Error(w, "CSRF-Token not correct", http.StatusBadRequest)
return return
} }

View File

@@ -163,6 +163,70 @@ func TestIntegrationAuth(t *testing.T) {
assert.NotEqual(t, anonymousSession.Value, cookie.Value, "Session ID did not change") assert.NotEqual(t, anonymousSession.Value, cookie.Value, "Session ID did not change")
}) })
}) })
t.Run("SignOut", func(t *testing.T) {
t.Run("should fail if csrf token is not valid", func(t *testing.T) {
t.Parallel()
_, basePath, ctx := setupIntegrationTest(t)
req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/sign-out", nil)
assert.Nil(t, err)
req.Header.Set("csrf-token", "invalid-csrf-token")
resp, err := httpClient.Do(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
t.Run(`should delete current session and redirect to "/"`, func(t *testing.T) {
t.Parallel()
db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New()
sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(`
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
assert.Nil(t, err)
_, err = db.Exec(`
INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
assert.Nil(t, err)
req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/", nil)
assert.Nil(t, err)
req.Header.Set("Cookie", "id="+sessionId)
resp, err := httpClient.Do(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var csrfToken string
err = db.QueryRow("SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken)
assert.Nil(t, err)
req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signout", nil)
assert.Nil(t, err)
req.Header.Set("csrf-token", csrfToken)
req.Header.Set("Cookie", "id="+sessionId)
resp, err = httpClient.Do(req)
assert.Nil(t, err)
assert.Equal(t, http.StatusSeeOther, resp.StatusCode)
assert.Equal(t, "/", resp.Header.Get("Location"))
cookie := findCookie(resp, "id")
assert.NotNil(t, cookie)
assert.Equal(t, "", cookie.Value)
assert.Equal(t, -1, cookie.MaxAge)
var rows int
err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
assert.Nil(t, err)
assert.Equal(t, 0, rows)
})
})
t.Run("DeleteAccount", func(t *testing.T) { t.Run("DeleteAccount", func(t *testing.T) {
t.Run(`should redirect to "/" if not signed in`, func(t *testing.T) { t.Run(`should redirect to "/" if not signed in`, func(t *testing.T) {
t.Parallel() t.Parallel()

View File

@@ -443,12 +443,14 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
return "", types.ErrInternal return "", types.ErrInternal
} }
token := types.NewToken(uuid.Nil, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) token := types.NewToken(session.UserId, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(8*time.Hour))
err = service.db.InsertToken(token) err = service.db.InsertToken(token)
if err != nil { if err != nil {
return "", types.ErrInternal return "", types.ErrInternal
} }
log.Info("CSRF-Token created: %v", tokenStr)
return tokenStr, nil return tokenStr, nil
} }

View File

@@ -4,24 +4,16 @@ templ UserComp(user string) {
<div id="user-info" class="flex gap-5 items-center"> <div id="user-info" class="flex gap-5 items-center">
if user != "" { if user != "" {
<div class="group inline-block relative"> <div class="group inline-block relative">
<button <button class="font-semibold py-2 px-4 inline-flex items-center">
class="font-semibold py-2 px-4 inline-flex items-center"
>
<span class="mr-1">{ user }</span> <span class="mr-1">{ user }</span>
<svg <svg class="fill-current h-4 w-4" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 20 20">
class="fill-current h-4 w-4" <path d="M9.293 12.95l.707.707L15.657 8l-1.414-1.414L10 10.828 5.757 6.586 4.343 8z"></path>
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
>
<path
d="M9.293 12.95l.707.707L15.657 8l-1.414-1.414L10 10.828 5.757 6.586 4.343 8z"
></path>
</svg> </svg>
</button> </button>
<div class="absolute hidden group-hover:block w-full"> <div class="absolute hidden group-hover:block w-full">
<ul class="menu bg-base-300 rounded-box w-fit float-right mr-4 p-3"> <ul class="menu bg-base-300 rounded-box w-fit float-right mr-4 p-3">
<li class="mb-1"> <li class="mb-1">
<a hx-get="/api/auth/signout" hx-target="#user-info">Sign Out</a> <a hx-post="/api/auth/signout" hx-target="#user-info">Sign Out</a>
</li> </li>
<li class="mb-1"> <li class="mb-1">
<a href="/auth/change-password">Change Password</a> <a href="/auth/change-password">Change Password</a>

View File

@@ -3,6 +3,7 @@ package template
templ Layout(slot templ.Component, user templ.Component, environment string) { templ Layout(slot templ.Component, user templ.Component, environment string) {
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="utf-8" /> <meta charset="utf-8" />
<title>ME-FIT</title> <title>ME-FIT</title>
@@ -12,18 +13,16 @@ templ Layout(slot templ.Component, user templ.Component, environment string) {
if environment == "prod" { if environment == "prod" {
<script defer src="https://umami.me-fit.eu/script.js" data-website-id="3c8efb09-44e4-4372-8a1e-c3bc675cd89a"></script> <script defer src="https://umami.me-fit.eu/script.js" data-website-id="3c8efb09-44e4-4372-8a1e-c3bc675cd89a"></script>
} }
<meta <meta name="htmx-config" content='{
name="htmx-config"
content='{
"includeIndicatorStyles": false, "includeIndicatorStyles": false,
"selfRequestsOnly": true, "selfRequestsOnly": true,
"allowScriptTags": false "allowScriptTags": false
}' }' />
/>
<script src="/static/js/htmx.min.js"></script> <script src="/static/js/htmx.min.js"></script>
<script src="/static/js/toast.js"></script> <script src="/static/js/toast.js"></script>
</head> </head>
<body>
<body hx-headers='{"csrf-token": "CSRF_TOKEN"}'>
<div class="h-screen flex flex-col"> <div class="h-screen flex flex-col">
<div class="flex justify-end items-center gap-2 py-1 px-2 h-12 md:gap-10 md:px-10 md:py-2 shadow"> <div class="flex justify-end items-center gap-2 py-1 px-2 h-12 md:gap-10 md:px-10 md:py-2 shadow">
<a href="/" class="flex-1 flex gap-2"> <a href="/" class="flex-1 flex gap-2">
@@ -44,5 +43,6 @@ templ Layout(slot templ.Component, user templ.Component, environment string) {
</div> </div>
</div> </div>
</body> </body>
</html> </html>
} }

View File

@@ -60,8 +60,7 @@ if includePlaceholder {
<th>{ w.Reps }</th> <th>{ w.Reps }</th>
<th> <th>
<div class="tooltip" data-tip="Delete Entry"> <div class="tooltip" data-tip="Delete Entry">
<button hx-headers='{"csrf-token": "CSRF_TOKEN"}' hx-delete={ "api/workout/" + w.Id } hx-target="closest tr" <button hx-delete={ "api/workout/" + w.Id } hx-target="closest tr" type="submit">
type="submit">
Delete Delete
</button> </button>
</div> </div>