diff --git a/db/auth.go b/db/auth.go index 6bfb729..20a3c4a 100644 --- a/db/auth.go +++ b/db/auth.go @@ -100,8 +100,8 @@ type Auth interface { InsertSession(session *Session) error GetSession(sessionId string) (*Session, error) + GetSessions(userId uuid.UUID) ([]*Session, error) DeleteSession(sessionId string) error - DeleteOtherSessions(userId uuid.UUID, sessionId string) error DeleteOldSessions(userId uuid.UUID) error } @@ -417,16 +417,33 @@ func (db AuthSqlite) GetSession(sessionId string) (*Session, error) { return NewSession(sessionId, userId, createdAt, expiresAt), nil } -func (db AuthSqlite) DeleteOtherSessions(userId uuid.UUID, sessionId string) error { - _, err := db.db.Exec(` - DELETE FROM session - WHERE session_id != ? - AND user_id = ?`, sessionId, userId) +func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*Session, error) { + + sessions, err := db.db.Query(` + SELECT session_id, created_at, expires_at + FROM session + WHERE user_id = ?`, userId) if err != nil { - log.Error("Could not delete other active sessions: %v", err) - return types.ErrInternal + log.Error("Could not get sessions: %v", err) + return nil, types.ErrInternal } - return nil + + var result []*Session + + for sessions.Next() { + var ( + sessionId string + createdAt time.Time + expiresAt time.Time + ) + + sessions.Scan(&sessionId, &createdAt, &expiresAt) + + session := NewSession(sessionId, userId, createdAt, expiresAt) + result = append(result, session) + } + + return result, nil } func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error { diff --git a/handler/auth.go b/handler/auth.go index c0acc90..dfc36f8 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -48,9 +48,9 @@ func (handler AuthImpl) Handle(router *http.ServeMux) { router.Handle("/auth/change-password", handler.handleChangePasswordPage()) router.Handle("/api/auth/change-password", handler.handleChangePasswordComp()) - router.Handle("/auth/reset-password", handler.handleResetPasswordPage()) - router.Handle("/api/auth/reset-password", handler.handleForgotPasswordComp()) - router.Handle("/api/auth/reset-password-actual", handler.handleForgotPasswordResponseComp()) + router.Handle("/auth/forgot-password", handler.handleForgotPasswordPage()) + router.Handle("/api/auth/forgot-password", handler.handleForgotPasswordComp()) + router.Handle("/api/auth/forgot-password-actual", handler.handleForgotPasswordResponseComp()) } var ( @@ -312,12 +312,12 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { } } -func (handler AuthImpl) handleResetPasswordPage() http.HandlerFunc { +func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := middleware.GetUser(r) - if user == nil { - utils.DoRedirect(w, r, "/auth/signin") + if user != nil { + utils.DoRedirect(w, r, "/") return } @@ -335,7 +335,11 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { return } - err := handler.service.SendForgotPasswordMail(email) + _, err := utils.WaitMinimumTime(securityWaitDuration, func() (interface{}, error) { + err := handler.service.SendForgotPasswordMail(email) + return nil, err + }) + if err != nil { utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) } else { @@ -355,11 +359,6 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { } token := pageUrl.Query().Get("token") - if token == "" { - utils.TriggerToast(w, r, "error", "No token", http.StatusBadRequest) - return - } - newPass := r.FormValue("new-password") err = handler.service.ForgotPassword(token, newPass) diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 9fbca5f..1186228 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -61,7 +61,9 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler } } - if session == nil && (strings.Contains(r.RequestURI, "/auth/signup") || strings.Contains(r.RequestURI, "/auth/signin")) { + // Always sign in anonymous + // This way, there is no way to forget creating a csrf token + if session == nil { session, _ = auth.SignInAnonymous() cookie := CreateSessionCookie(session.Id) diff --git a/main_test.go b/main_test.go index a2527e6..40fabf2 100644 --- a/main_test.go +++ b/main_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "me-fit/db" "me-fit/service" "me-fit/types" @@ -215,7 +216,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var sessionIds []string - sessions, err := db.Query("SELECT session_id FROM session ORDER BY session_id") + sessions, err := db.Query(`SELECT session_id FROM session WHERE NOT user_id = ? ORDER BY session_id`, uuid.Nil) assert.Nil(t, err) for sessions.Next() { var sessionId string @@ -228,69 +229,68 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, "other", sessionIds[0]) assert.Equal(t, "session-id", sessionIds[1]) }) - t.Run("should forget password and invalidate other sessions from user", func(t *testing.T) { + t.Run("should forget password and invalidate all user sessions", func(t *testing.T) { t.Parallel() - db, basePath, ctx := setupIntegrationTest(t) + d, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := d.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - sessionId := "session-id" assert.Nil(t, err) - _, err = db.Exec(` + _, err = d.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) - VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) - _, err = db.Exec(` - INSERT INTO session (session_id, user_id, created_at, expires_at) - VALUES ("second", ?, datetime(), datetime("now", "+1 day"))`, userId) + VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId) assert.Nil(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/change-password", nil) + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) assert.Nil(t, err) - req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) assert.Nil(t, err) + sessionId := findCookie(resp, "id").Value html, err := html.Parse(resp.Body) assert.Nil(t, err) - csrfToken := findCsrfToken(html) assert.NotEqual(t, "", csrfToken) formData := url.Values{ - "current-password": {"password"}, - "new-password": {"MyNewSecurePassword1!"}, - "csrf-token": {csrfToken}, + "email": {"mail@mail.de"}, + "csrf-token": {csrfToken}, } - - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) assert.Nil(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) req.Header.Set("HX-Request", "true") resp, err = httpClient.Do(req) assert.Nil(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - var sessionIds []string - sessions, err := db.Query("SELECT session_id FROM session ORDER BY session_id") + var token string + err = d.QueryRow("SELECT token FROM token WHERE type = ?", db.TokenTypePasswordReset).Scan(&token) assert.Nil(t, err) - for sessions.Next() { - var sessionId string - err = sessions.Scan(&sessionId) - assert.Nil(t, err) - sessionIds = append(sessionIds, sessionId) - } - assert.Equal(t, 2, len(sessionIds)) - assert.Equal(t, "other", sessionIds[0]) - assert.Equal(t, "session-id", sessionIds[1]) + formData = url.Values{ + "new-password": {"MyNewSecurePassword1!"}, + "csrf-token": {csrfToken}, + } + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+sessionId) + req.Header.Set("HX-Request", "true") + req.Header.Set("HX-Current-URL", basePath+"/auth/change-password?token="+url.QueryEscape(token)) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + sessions, err := d.Query("SELECT session_id FROM session WHERE user_id = ?", userId) + assert.Nil(t, err) + assert.False(t, sessions.Next()) }) } diff --git a/service/auth.go b/service/auth.go index 9aafb78..04592fc 100644 --- a/service/auth.go +++ b/service/auth.go @@ -23,6 +23,7 @@ var ( ErrInvalidEmail = errors.New("invalid email") ErrAccountExists = errors.New("account already exists") ErrSessionIdInvalid = errors.New("session ID is invalid") + ErrTokenInvalid = errors.New("token is invalid") ) type User struct { @@ -95,7 +96,6 @@ func NewAuthImpl(db db.Auth, random Random, clock Clock, mail Mail, serverSettin } func (service AuthImpl) SignIn(email string, password string) (*Session, error) { - log.Info("Sign in %s", email) user, err := service.db.GetUserByEmail(email) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -149,7 +149,6 @@ func (service AuthImpl) SignInSession(sessionId string) (*Session, error) { } func (service AuthImpl) SignInAnonymous() (*Session, error) { - log.Info("Sign in anonymous") sessionDb, err := service.createSession(uuid.Nil) if err != nil { return nil, types.ErrInternal @@ -350,9 +349,17 @@ func (service AuthImpl) ChangePassword(session *Session, currPass, newPass strin return err } - err = service.db.DeleteOtherSessions(session.User.Id, session.Id) + sessions, err := service.db.GetSessions(userDb.Id) if err != nil { - return err + return types.ErrInternal + } + for _, s := range sessions { + if s.Id != session.Id { + err = service.db.DeleteSession(s.Id) + if err != nil { + return types.ErrInternal + } + } } return nil @@ -399,7 +406,7 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { token, err := service.db.GetToken(tokenStr) if err != nil { - return err + return ErrTokenInvalid } err = service.db.DeleteToken(tokenStr) @@ -407,6 +414,11 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { return err } + if token.Type != db.TokenTypePasswordReset || + token.ExpiresAt.Before(service.clock.Now()) { + return ErrTokenInvalid + } + user, err := service.db.GetUser(token.UserId) if err != nil { log.Error("Could not get user from token: %v", err) @@ -421,6 +433,18 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { return err } + sessions, err := service.db.GetSessions(user.Id) + if err != nil { + return types.ErrInternal + } + + for _, session := range sessions { + err = service.db.DeleteSession(session.Id) + if err != nil { + return types.ErrInternal + } + } + return nil } diff --git a/template/auth/change_password.templ b/template/auth/change_password.templ index fbd1637..eeacbff 100644 --- a/template/auth/change_password.templ +++ b/template/auth/change_password.templ @@ -4,7 +4,7 @@ templ ChangePasswordComp(isPasswordReset bool) {