diff --git a/handler.go b/handler.go index 20964d7..4a59910 100644 --- a/handler.go +++ b/handler.go @@ -33,11 +33,13 @@ func getHandler(db *sql.DB) http.Handler { router.Handle("/auth/verify", service.HandleSignUpVerifyPage(db)) // Hint for the user to verify their email router.Handle("/auth/delete-account", service.HandleDeleteAccountPage(db)) router.Handle("/auth/verify-email", service.HandleSignUpVerifyResponsePage(db)) // The link contained in the email + router.Handle("/auth/change-password", service.HandleChangePasswordPage(db)) router.Handle("/api/auth/signup", service.HandleSignUpComp(db)) router.Handle("/api/auth/signin", service.HandleSignInComp(db)) router.Handle("/api/auth/signout", service.HandleSignOutComp(db)) router.Handle("/api/auth/delete-account", service.HandleDeleteAccountComp(db)) router.Handle("/api/auth/verify-resend", service.HandleVerifyResendComp(db)) + router.Handle("/api/auth/change-password", service.HandleChangePasswordComp(db)) return middleware.Logging(middleware.EnableCors(router)) } diff --git a/service/auth.go b/service/auth.go index 338e77c..622c9b2 100644 --- a/service/auth.go +++ b/service/auth.go @@ -6,6 +6,7 @@ import ( "crypto/subtle" "database/sql" "encoding/base64" + "errors" "net/http" "net/mail" "strings" @@ -87,7 +88,7 @@ func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc { func HandleDeleteAccountPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - // An enverified email should be able to delete their account + // An unverified email should be able to delete their account user := utils.GetUserFromSession(db, r) if user == nil { utils.DoRedirect(w, r, "/auth/signin") @@ -96,7 +97,7 @@ func HandleDeleteAccountPage(db *sql.DB) http.HandlerFunc { comp := auth.DeleteAccountComp() err := template.Layout(comp, userComp).Render(r.Context(), w) if err != nil { - utils.LogError("Failed to render verify page", err) + utils.LogError("Failed to render delete account page", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) } } @@ -145,6 +146,24 @@ func HandleSignUpVerifyResponsePage(db *sql.DB) http.HandlerFunc { } } +func HandleChangePasswordPage(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + + user := utils.GetUserFromSession(db, r) + if user == nil { + utils.DoRedirect(w, r, "/auth/signin") + } else { + userComp := UserInfoComp(user) + comp := auth.ChangePasswordComp() + err := template.Layout(comp, userComp).Render(r.Context(), w) + if err != nil { + utils.LogError("Failed to render change password page", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + } + } +} + func UserInfoComp(user *types.User) templ.Component { if user != nil { @@ -165,12 +184,9 @@ func HandleSignUpComp(db *sql.DB) http.HandlerFunc { return } - if len(password) < 8 || - !strings.ContainsAny(password, "0123456789") || - !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || - !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") || - !strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") { - http.Error(w, "Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character", http.StatusBadRequest) + err = checkPassword(password) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -360,6 +376,59 @@ func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc { } } +func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := utils.GetUserFromSession(db, r) + if user == nil { + utils.DoRedirect(w, r, "/auth/signin") + return + } + + currPass := r.FormValue("current-password") + newPass := r.FormValue("new-password") + + err := checkPassword(newPass) + if err != nil { + utils.TriggerToast(w, r, "error", err.Error()) + return + } + + if currPass == newPass { + utils.TriggerToast(w, r, "error", "Please use a new password") + return + } + + var ( + storedHash []byte + salt []byte + ) + + err = db.QueryRow("SELECT password, salt FROM user WHERE user_uuid = ?", user.Id).Scan(&storedHash, &salt) + if err != nil { + utils.LogError("Could not get password", err) + utils.TriggerToast(w, r, "error", "Internal Server Error") + return + } + + currHash := getHashPassword(currPass, salt) + if subtle.ConstantTimeCompare(currHash, storedHash) == 0 { + utils.TriggerToast(w, r, "error", "Current Password is not correct") + return + } + + newHash := getHashPassword(newPass, salt) + + _, err = db.Exec("UPDATE user SET password = ? WHERE user_uuid = ?", newHash, user.Id) + if err != nil { + utils.LogError("Could not update password", err) + utils.TriggerToast(w, r, "error", "Internal Server Error") + return + } + + utils.TriggerToast(w, r, "success", "Password changed") + } +} + func sendVerificationEmail(db *sql.DB, userId string, email string) { var token string @@ -434,3 +503,16 @@ func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sq func getHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) } + +func checkPassword(password string) error { + + if len(password) < 8 || + !strings.ContainsAny(password, "0123456789") || + !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || + !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") || + !strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") { + return errors.New("Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character") + } else { + return nil + } +} diff --git a/template/auth/change_password.templ b/template/auth/change_password.templ new file mode 100644 index 0000000..2d56221 --- /dev/null +++ b/template/auth/change_password.templ @@ -0,0 +1,22 @@ +package auth + +templ ChangePasswordComp() { +
+} diff --git a/template/auth/user.templ b/template/auth/user.templ index ed40d10..315b5a7 100644 --- a/template/auth/user.templ +++ b/template/auth/user.templ @@ -23,6 +23,9 @@ templ UserComp(user string) {