diff --git a/db/auth.go b/db/auth.go index 579faad..6bfb729 100644 --- a/db/auth.go +++ b/db/auth.go @@ -101,6 +101,7 @@ type Auth interface { InsertSession(session *Session) error GetSession(sessionId string) (*Session, error) DeleteSession(sessionId string) error + DeleteOtherSessions(userId uuid.UUID, sessionId string) error DeleteOldSessions(userId uuid.UUID) error } @@ -416,9 +417,23 @@ func (db AuthSqlite) GetSession(sessionId string) (*Session, error) { return NewSession(sessionId, userId, createdAt, expiresAt), nil } +func (db AuthSqlite) DeleteOtherSessions(userId uuid.UUID, sessionId string) error { + _, err := db.db.Exec(` + DELETE FROM session + WHERE session_id != ? + AND user_id = ?`, sessionId, userId) + if err != nil { + log.Error("Could not delete other active sessions: %v", err) + return types.ErrInternal + } + return nil +} + func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error { - // Delete old inactive sessions - _, err := db.db.Exec("DELETE FROM session WHERE created_at < datetime('now','-8 hours') AND user_id = ?", userId) + _, err := db.db.Exec(` + DELETE FROM session + WHERE expires_at < datetime('now') + AND user_id = ?`, userId) if err != nil { log.Error("Could not delete old sessions: %v", err) return types.ErrInternal diff --git a/handler/auth.go b/handler/auth.go index e027e1c..c0acc90 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -93,12 +93,10 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc { }) if err != nil { - if err == service.ErrInvaidCredentials { - utils.TriggerToast(w, r, "error", "Invalid email or password") - http.Error(w, "Invalid email or password", http.StatusUnauthorized) + if err == service.ErrInvalidCredentials { + utils.TriggerToast(w, r, "error", "Invalid email or password", http.StatusUnauthorized) } else { - log.Error("Error signing in: %v", err) - http.Error(w, "An error occurred", http.StatusInternalServerError) + utils.TriggerToast(w, r, "error", "An error occurred", http.StatusInternalServerError) } return } @@ -198,16 +196,16 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc { if err != nil { if errors.Is(err, types.ErrInternal) { - utils.TriggerToast(w, r, "error", "An error occurred") + utils.TriggerToast(w, r, "error", "An error occurred", http.StatusInternalServerError) return } else if errors.Is(err, service.ErrInvalidEmail) { - utils.TriggerToast(w, r, "error", "The email provided is invalid") + utils.TriggerToast(w, r, "error", "The email provided is invalid", http.StatusBadRequest) return } // If the "service.ErrAccountExists", then just continue } - utils.TriggerToast(w, r, "success", "A link to activate your account has been emailed to the address provided.") + utils.TriggerToast(w, r, "success", "A link to activate your account has been emailed to the address provided.", http.StatusOK) } } @@ -261,15 +259,13 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { password := r.FormValue("password") - _, err := handler.service.SignIn(user.Email, password) + err := handler.service.DeleteAccount(user, password) if err != nil { - utils.TriggerToast(w, r, "error", "Password not correct") - return - } - - err = handler.service.DeleteAccount(user) - if err != nil { - utils.TriggerToast(w, r, "error", "Internal Server Error") + if err == service.ErrInvalidCredentials { + utils.TriggerToast(w, r, "error", "Password not correct", http.StatusUnauthorized) + } else { + utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) + } return } @@ -297,8 +293,8 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc { func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := middleware.GetUser(r) - if user == nil { + session := middleware.GetSession(r) + if session.User == nil { utils.DoRedirect(w, r, "/auth/signin") return } @@ -306,13 +302,13 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { currPass := r.FormValue("current-password") newPass := r.FormValue("new-password") - err := handler.service.ChangePassword(user, currPass, newPass) + err := handler.service.ChangePassword(session, currPass, newPass) if err != nil { - utils.TriggerToast(w, r, "error", "Password not correct") + utils.TriggerToast(w, r, "error", "Password not correct", http.StatusUnauthorized) return } - utils.TriggerToast(w, r, "success", "Password changed") + utils.TriggerToast(w, r, "success", "Password changed", http.StatusOK) } } @@ -335,15 +331,15 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { email := r.FormValue("email") if email == "" { - utils.TriggerToast(w, r, "error", "Please enter an email") + utils.TriggerToast(w, r, "error", "Please enter an email", http.StatusBadRequest) return } err := handler.service.SendForgotPasswordMail(email) if err != nil { - utils.TriggerToast(w, r, "error", "Internal Server Error") + utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) } else { - utils.TriggerToast(w, r, "info", "If the email exists, an email has been sent") + utils.TriggerToast(w, r, "info", "If the email exists, an email has been sent", http.StatusOK) } } } @@ -354,13 +350,13 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { pageUrl, err := url.Parse(r.Header.Get("HX-Current-URL")) if err != nil { log.Error("Could not get current URL: %v", err) - utils.TriggerToast(w, r, "error", "Internal Server Error") + utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) return } token := pageUrl.Query().Get("token") if token == "" { - utils.TriggerToast(w, r, "error", "No token") + utils.TriggerToast(w, r, "error", "No token", http.StatusBadRequest) return } @@ -368,9 +364,9 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { err = handler.service.ForgotPassword(token, newPass) if err != nil { - utils.TriggerToast(w, r, "error", err.Error()) + utils.TriggerToast(w, r, "error", err.Error(), http.StatusInternalServerError) } else { - utils.TriggerToast(w, r, "success", "Password changed") + utils.TriggerToast(w, r, "success", "Password changed", http.StatusOK) } } } diff --git a/handler/workout.go b/handler/workout.go index ebb563f..169deff 100644 --- a/handler/workout.go +++ b/handler/workout.go @@ -2,7 +2,6 @@ package handler import ( "me-fit/handler/middleware" - "me-fit/log" "me-fit/service" "me-fit/template/workout" "me-fit/utils" @@ -67,7 +66,7 @@ func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc { wo := service.NewWorkoutDto("", dateStr, typeStr, setsStr, repsStr) wo, err := handler.service.AddWorkout(session.User, wo) if err != nil { - utils.TriggerToast(w, r, "error", "Invalid input values") + utils.TriggerToast(w, r, "error", "Invalid input values", http.StatusBadRequest) http.Error(w, "Invalid input values", http.StatusBadRequest) return } @@ -111,25 +110,19 @@ func (handler WorkoutImpl) handleDeleteWorkout() http.HandlerFunc { rowId := r.PathValue("id") if rowId == "" { - http.Error(w, "Missing required fields", http.StatusBadRequest) - log.Warn("Missing required fields for workout delete") - utils.TriggerToast(w, r, "error", "Missing ID field") + utils.TriggerToast(w, r, "error", "Missing ID field", http.StatusBadRequest) return } rowIdInt, err := strconv.Atoi(rowId) if err != nil { - http.Error(w, "Invalid ID", http.StatusBadRequest) - log.Warn("Invalid ID for workout delete") - utils.TriggerToast(w, r, "error", "Invalid ID") + utils.TriggerToast(w, r, "error", "Invalid ID", http.StatusBadRequest) return } err = handler.service.DeleteWorkout(session.User, rowIdInt) if err != nil { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - log.Error("Could not delete workout: %v", err.Error()) - utils.TriggerToast(w, r, "error", "Internal Server Error") + utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) return } } diff --git a/main_test.go b/main_test.go index 5447ed0..a2527e6 100644 --- a/main_test.go +++ b/main_test.go @@ -163,7 +163,6 @@ func TestIntegrationAuth(t *testing.T) { t.Parallel() db, basePath, ctx := setupIntegrationTest(t) - // TODO: take decision, if tests should be fully end to end (e.g. always send a signup request) or halway end to end (e.g. insert user into db) userId := uuid.New() userIdOther := uuid.New() @@ -225,10 +224,73 @@ func TestIntegrationAuth(t *testing.T) { sessionIds = append(sessionIds, sessionId) } - t.Fatalf("sessionIds: %v", sessionIds) assert.Equal(t, 2, len(sessionIds)) - assert.Equal(t, "session-id", sessionIds[0]) assert.Equal(t, "other", sessionIds[0]) + assert.Equal(t, "session-id", sessionIds[1]) + }) + t.Run("should forget password and invalidate other sessions from user", func(t *testing.T) { + t.Parallel() + + db, basePath, ctx := setupIntegrationTest(t) + userId := uuid.New() + + pass := service.GetHashPassword("password", []byte("salt")) + _, err := db.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) + + sessionId := "session-id" + assert.Nil(t, err) + _, err = db.Exec(` + INSERT INTO session (session_id, user_id, created_at, expires_at) + VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) + assert.Nil(t, err) + _, err = db.Exec(` + INSERT INTO session (session_id, user_id, created_at, expires_at) + VALUES ("second", ?, datetime(), datetime("now", "+1 day"))`, userId) + assert.Nil(t, err) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/change-password", nil) + assert.Nil(t, err) + req.Header.Set("Cookie", "id="+sessionId) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + + html, err := html.Parse(resp.Body) + assert.Nil(t, err) + + csrfToken := findCsrfToken(html) + assert.NotEqual(t, "", csrfToken) + + formData := url.Values{ + "current-password": {"password"}, + "new-password": {"MyNewSecurePassword1!"}, + "csrf-token": {csrfToken}, + } + + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+sessionId) + req.Header.Set("HX-Request", "true") + resp, err = httpClient.Do(req) + assert.Nil(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var sessionIds []string + sessions, err := db.Query("SELECT session_id FROM session ORDER BY session_id") + assert.Nil(t, err) + for sessions.Next() { + var sessionId string + err = sessions.Scan(&sessionId) + assert.Nil(t, err) + sessionIds = append(sessionIds, sessionId) + } + + assert.Equal(t, 2, len(sessionIds)) + assert.Equal(t, "other", sessionIds[0]) + assert.Equal(t, "session-id", sessionIds[1]) }) } diff --git a/service/auth.go b/service/auth.go index d5e171a..9aafb78 100644 --- a/service/auth.go +++ b/service/auth.go @@ -18,11 +18,11 @@ import ( ) var ( - ErrInvaidCredentials = errors.New("invalid email or password") - ErrInvalidPassword = errors.New("password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character") - ErrInvalidEmail = errors.New("invalid email") - ErrAccountExists = errors.New("account already exists") - ErrSessionIdInvalid = errors.New("session ID is invalid") + ErrInvalidCredentials = errors.New("invalid email or password") + ErrInvalidPassword = errors.New("password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character") + ErrInvalidEmail = errors.New("invalid email") + ErrAccountExists = errors.New("account already exists") + ErrSessionIdInvalid = errors.New("session ID is invalid") ) type User struct { @@ -65,9 +65,9 @@ type Auth interface { SignInAnonymous() (*Session, error) SignOut(sessionId string) error - DeleteAccount(user *User) error + DeleteAccount(user *User, currPass string) error - ChangePassword(user *User, currPass, newPass string) error + ChangePassword(session *Session, currPass, newPass string) error SendForgotPasswordMail(email string) error ForgotPassword(token string, newPass string) error @@ -99,7 +99,7 @@ func (service AuthImpl) SignIn(email string, password string) (*Session, error) user, err := service.db.GetUserByEmail(email) if err != nil { if errors.Is(err, db.ErrNotFound) { - return nil, ErrInvaidCredentials + return nil, ErrInvalidCredentials } else { return nil, types.ErrInternal } @@ -108,7 +108,7 @@ func (service AuthImpl) SignIn(email string, password string) (*Session, error) hash := GetHashPassword(password, user.Salt) if subtle.ConstantTimeCompare(hash, user.Password) == 0 { - return nil, ErrInvaidCredentials + return nil, ErrInvalidCredentials } session, err := service.createSession(user.Id) @@ -299,9 +299,19 @@ func (service AuthImpl) SignOut(sessionId string) error { return service.db.DeleteSession(sessionId) } -func (service AuthImpl) DeleteAccount(user *User) error { +func (service AuthImpl) DeleteAccount(user *User, currPass string) error { - err := service.db.DeleteUser(user.Id) + userDb, err := service.db.GetUser(user.Id) + if err != nil { + return types.ErrInternal + } + + currHash := GetHashPassword(currPass, userDb.Salt) + if subtle.ConstantTimeCompare(currHash, userDb.Password) == 0 { + return ErrInvalidCredentials + } + + err = service.db.DeleteUser(user.Id) if err != nil { return err } @@ -311,7 +321,7 @@ func (service AuthImpl) DeleteAccount(user *User) error { return nil } -func (service AuthImpl) ChangePassword(user *User, currPass, newPass string) error { +func (service AuthImpl) ChangePassword(session *Session, currPass, newPass string) error { if !isPasswordValid(newPass) { return ErrInvalidPassword @@ -321,18 +331,18 @@ func (service AuthImpl) ChangePassword(user *User, currPass, newPass string) err return ErrInvalidPassword } - _, err := service.SignIn(user.Email, currPass) + userDb, err := service.db.GetUser(session.User.Id) if err != nil { return err } - userDb, err := service.db.GetUser(user.Id) - if err != nil { - return err + currHash := GetHashPassword(currPass, userDb.Salt) + + if subtle.ConstantTimeCompare(currHash, userDb.Password) == 0 { + return ErrInvalidCredentials } newHash := GetHashPassword(newPass, userDb.Salt) - userDb.Password = newHash err = service.db.UpdateUser(userDb) @@ -340,11 +350,15 @@ func (service AuthImpl) ChangePassword(user *User, currPass, newPass string) err return err } + err = service.db.DeleteOtherSessions(session.User.Id, session.Id) + if err != nil { + return err + } + return nil } func (service AuthImpl) SendForgotPasswordMail(email string) error { - tokenStr, err := service.random.String(32) if err != nil { return err diff --git a/service/auth_test.go b/service/auth_test.go index 96ac697..513dcc8 100644 --- a/service/auth_test.go +++ b/service/auth_test.go @@ -80,7 +80,7 @@ func TestSignIn(t *testing.T) { _, err := underTest.SignIn("test@test.de", "wrong password") - assert.Equal(t, ErrInvaidCredentials, err) + assert.Equal(t, ErrInvalidCredentials, err) }) t.Run("should return ErrInvalidCretentials if user has not been found", func(t *testing.T) { t.Parallel() @@ -94,7 +94,7 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) _, err := underTest.SignIn("test", "test") - assert.Equal(t, ErrInvaidCredentials, err) + assert.Equal(t, ErrInvalidCredentials, err) }) t.Run("should forward ErrInternal on any other error", func(t *testing.T) { t.Parallel() diff --git a/utils/http.go b/utils/http.go index d370bc2..f2f692d 100644 --- a/utils/http.go +++ b/utils/http.go @@ -8,9 +8,10 @@ import ( "time" ) -func TriggerToast(w http.ResponseWriter, r *http.Request, class string, message string) { +func TriggerToast(w http.ResponseWriter, r *http.Request, class string, message string, statusCode int) { if isHtmx(r) { w.Header().Set("HX-Trigger", fmt.Sprintf(`{"toast": "%v|%v"}`, class, message)) + w.WriteHeader(statusCode) } else { log.Error("Trying to trigger toast in non-HTMX request") }