diff --git a/handler/auth.go b/handler/auth.go index 8a8577b..2a049fc 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -48,9 +48,9 @@ func (handler AuthImpl) Handle(router *http.ServeMux) { router.Handle("GET /auth/change-password", handler.handleChangePasswordPage()) router.Handle("POST /api/auth/change-password", handler.handleChangePasswordComp()) - router.Handle("/auth/forgot-password", handler.handleForgotPasswordPage()) - router.Handle("/api/auth/forgot-password", handler.handleForgotPasswordComp()) - router.Handle("/api/auth/forgot-password-actual", handler.handleForgotPasswordResponseComp()) + router.Handle("GET /auth/forgot-password", handler.handleForgotPasswordPage()) + router.Handle("POST /api/auth/forgot-password", handler.handleForgotPasswordComp()) + router.Handle("POST /api/auth/forgot-password-actual", handler.handleForgotPasswordResponseComp()) } var ( @@ -355,14 +355,13 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { if err != nil { utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) } else { - utils.TriggerToast(w, r, "info", "If the email exists, an email has been sent", http.StatusOK) + utils.TriggerToast(w, r, "info", "If the address exists, an email has been sent.", http.StatusOK) } } } func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - pageUrl, err := url.Parse(r.Header.Get("HX-Current-URL")) if err != nil { log.Error("Could not get current URL: %v", err) diff --git a/handler/index_and_404.go b/handler/index_and_404.go index 0c81c6b..8490cf6 100644 --- a/handler/index_and_404.go +++ b/handler/index_and_404.go @@ -36,13 +36,15 @@ func (handler IndexImpl) handleIndexAnd404() http.HandlerFunc { var comp templ.Component + var status int if r.URL.Path != "/" { comp = template.NotFound() - w.WriteHeader(http.StatusNotFound) + status = http.StatusNotFound } else { comp = template.Index() + status = http.StatusOK } - handler.render.RenderLayout(r, w, comp, user) + handler.render.RenderLayoutWithStatus(r, w, comp, user, status) } } diff --git a/main_test.go b/main_test.go index bcef8a7..b2b61bb 100644 --- a/main_test.go +++ b/main_test.go @@ -1191,7 +1191,153 @@ func TestIntegrationAuth(t *testing.T) { }) }) - t.Run("ForgotPassword", func(t *testing.T) { + t.Run("ForgotPasswordMail", func(t *testing.T) { + t.Run(`should redirect to "/" if signed in`, func(t *testing.T) { + t.Parallel() + + d, basePath, ctx := setupIntegrationTest(t) + userId := uuid.New() + + pass := service.GetHashPassword("password", []byte("salt")) + _, err := d.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) + assert.Nil(t, err) + + sessionId := "session-id" + _, err = d.Exec(` + INSERT INTO session (session_id, user_id, created_at, expires_at) + VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId) + assert.Nil(t, err) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) + assert.Nil(t, err) + req.Header.Set("Cookie", "id="+sessionId) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/", resp.Header.Get("Location")) + }) + t.Run(`should fail if csrf token is invalid`, func(t *testing.T) { + t.Parallel() + + d, basePath, ctx := setupIntegrationTest(t) + userId := uuid.New() + + pass := service.GetHashPassword("password", []byte("salt")) + _, err := d.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) + assert.Nil(t, err) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) + assert.Nil(t, err) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + anonymousSessionId := findCookie(resp, "id").Value + assert.NotEqual(t, "", anonymousSessionId) + + formData := url.Values{ + "email": {"mail@mail.de"}, + "csrf-token": {"invalid-csrf-token"}, + } + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("HX-Request", "true") + resp, err = httpClient.Do(req) + assert.Nil(t, err) + + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + + var rows int + err = d.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) + assert.Nil(t, err) + assert.Equal(t, 0, rows) + }) + t.Run(`should fail but respond with uniform message`, func(t *testing.T) { + t.Parallel() + + _, basePath, ctx := setupIntegrationTest(t) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) + assert.Nil(t, err) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + anonymousSessionId := findCookie(resp, "id").Value + assert.NotEqual(t, "", anonymousSessionId) + body, err := html.Parse(resp.Body) + assert.Nil(t, err) + anonymousCsrfToken := findCsrfToken(body) + + formData := url.Values{ + "email": {"non-existent@mail.de"}, + "csrf-token": {anonymousCsrfToken}, + } + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("HX-Request", "true") + req.Header.Set("Cookie", "id="+anonymousSessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + msg := "If the address exists, an email has been sent." + assert.Contains(t, resp.Header.Get("HX-Trigger"), msg) + }) + t.Run(`should generate token and respond with uniform message`, func(t *testing.T) { + t.Parallel() + + db, basePath, ctx := setupIntegrationTest(t) + + userId := uuid.New() + pass := service.GetHashPassword("password", []byte("salt")) + _, err := db.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) + assert.Nil(t, err) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) + assert.Nil(t, err) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + anonymousSessionId := findCookie(resp, "id").Value + assert.NotEqual(t, "", anonymousSessionId) + body, err := html.Parse(resp.Body) + assert.Nil(t, err) + anonymousCsrfToken := findCsrfToken(body) + + formData := url.Values{ + "email": {"mail@mail.de"}, + "csrf-token": {anonymousCsrfToken}, + } + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("HX-Request", "true") + req.Header.Set("Cookie", "id="+anonymousSessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + msg := "If the address exists, an email has been sent." + assert.Contains(t, resp.Header.Get("HX-Trigger"), msg) + + var rows int + err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) + assert.Nil(t, err) + assert.Equal(t, 1, rows) + }) + }) + + t.Run("ForgotPasswordResponse", func(t *testing.T) { t.Run("should change password and invalidate ALL sessions", func(t *testing.T) { t.Parallel()