package service import ( "crypto/rand" "crypto/subtle" "database/sql" "encoding/base64" "log/slog" "net/http" "net/mail" "strings" "time" "me-fit/template" "me-fit/template/auth" tempMail "me-fit/template/mail" "me-fit/types" "me-fit/utils" "github.com/a-h/templ" "github.com/google/uuid" "golang.org/x/crypto/argon2" ) func HandleSignInPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) if user == nil || !user.SessionValid { userComp := UserInfoComp(nil) signIn := auth.SignInOrUpComp(true) template.Layout(signIn, userComp).Render(r.Context(), w) return } else if !user.EmailVerified { utils.DoRedirect(w, r, "/auth/verify") return } else { utils.DoRedirect(w, r, "/") return } } } func HandleSignUpPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) if user == nil || !user.SessionValid { userComp := UserInfoComp(nil) signUpComp := auth.SignInOrUpComp(false) template.Layout(signUpComp, userComp).Render(r.Context(), w) return } else if !user.EmailVerified { utils.DoRedirect(w, r, "/auth/verify") return } else { utils.DoRedirect(w, r, "/") return } } } func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) if user == nil || !user.SessionValid { utils.DoRedirect(w, r, "/auth/signin") return } if user.EmailVerified { utils.DoRedirect(w, r, "/") return } else { userComp := UserInfoComp(user) signIn := auth.VerifyComp() template.Layout(signIn, userComp).Render(r.Context(), w) } } } func HandleSignUpVerifyResponsePage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") if code == "" { utils.DoRedirect(w, r, "/auth/verify") return } userId, err := uuid.Parse(code) if err != nil { utils.DoRedirect(w, r, "/auth/verify") return } _, err = db.Exec("UPDATE user SET email_verified = true, email_verified_at = datetime() WHERE user_uuid = ?", userId) utils.DoRedirect(w, r, "/") } } func UserInfoComp(user *types.User) templ.Component { if user != nil { return auth.UserComp(user.Email) } else { return auth.UserComp("") } } func HandleSignUpComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") _, err := mail.ParseAddress(email) if err != nil { http.Error(w, "Invalid email", http.StatusBadRequest) return } if len(password) < 8 || !strings.ContainsAny(password, "0123456789") || !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") || !strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") { http.Error(w, "Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character", http.StatusBadRequest) return } userId, err := uuid.NewRandom() if err != nil { slog.Error("Could not generate UUID: %v", err) auth.Error("Internal Server Error").Render(r.Context(), w) return } salt := make([]byte, 16) _, err = rand.Read(salt) if err != nil { slog.Error("Could not generate salt: %v", err) auth.Error("Internal Server Error").Render(r.Context(), w) return } hash := getHashPassword(password, salt) _, err = db.Exec("INSERT INTO user (user_uuid, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, FALSE, FALSE, ?, ?, datetime())", userId, email, hash, salt) if err != nil { // This does leak information about the email being in use, though not specifically stated // It needs to be refacoteres to "If the email is not already in use, an email has been send to your address", or something // The happy path, currently a redirect, needs to send the same message! // Then it is also important to have the same compute time in both paths // Otherwise an attacker could guess emails when comparing the response time if strings.Contains(err.Error(), "email") { auth.Error("Bad Request").Render(r.Context(), w) return } auth.Error("Internal Server Error").Render(r.Context(), w) slog.Error("Could not insert user: %v", err) return } result := tryCreateSessionAndSetCookie(r, w, db, userId) if !result { return } // Send verification email as a goroutine go sendVerificationEmail(db, r, userId.String(), email) utils.DoRedirect(w, r, "/auth/verify") } } func HandleSignInComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") var result bool = true start := time.Now() var ( userId uuid.UUID savedHash []byte salt []byte emailVerified bool ) err := db.QueryRow("SELECT user_uuid, password, salt, email_verified FROM user WHERE email = ?", email).Scan(&userId, &savedHash, &salt, &emailVerified) if err != nil { result = false } if result { new_hash := getHashPassword(password, salt) if subtle.ConstantTimeCompare(new_hash, savedHash) == 0 { result = false } } if result { result := tryCreateSessionAndSetCookie(r, w, db, userId) if !result { return } } duration := time.Since(start) timeToWait := 100 - duration.Milliseconds() // It is important to sleep for a while to prevent timing attacks // If the email is correct, the server will calculate the hash, which will take some time // This way an attacker could guess emails when comparing the response time // Because of that, we cant use WriteHeader in the middle of the function. We have to wait until the end // Unfortunatly this makes the code harder to read time.Sleep(time.Duration(timeToWait) * time.Millisecond) if result { if !emailVerified { utils.DoRedirect(w, r, "/auth/verify") } else { utils.DoRedirect(w, r, "/") } } else { auth.Error("Invalid email or password").Render(r.Context(), w) } } } func HandleSignOutComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) if user != nil { _, err := db.Exec("DELETE FROM session WHERE session_id = ?", user.SessionId) if err != nil { slog.Error("Could not delete session: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } } c := http.Cookie{ Name: "id", Value: "", MaxAge: -1, Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, Path: "/", } http.SetCookie(w, &c) w.Header().Add("HX-Redirect", "/") } } func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) if user == nil || !user.SessionValid || user.EmailVerified { utils.DoRedirect(w, r, "/auth/signin") return } sendVerificationEmail(db, r, user.Id.String(), user.Email) w.Write([]byte("
Verification email sent
")) } } func sendVerificationEmail(db *sql.DB, r *http.Request, userId string, email string) { registerComp := tempMail.Register(userId) var writer strings.Builder registerComp.Render(r.Context(), &writer) utils.SendMail(email, "Welcome to ME-FIT", writer.String()) } func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { var session_id_bytes []byte = make([]byte, 32) _, err := rand.Reader.Read(session_id_bytes) if err != nil { slog.Error("Could not generate session ID: %v", err) auth.Error("Internal Server Error").Render(r.Context(), w) return false } session_id := base64.StdEncoding.EncodeToString(session_id_bytes) // Delete old inactive sessions _, err = db.Exec("DELETE FROM session WHERE created_at < datetime('now','-8 hours') AND user_uuid = ?", user_uuid) if err != nil { slog.Error("Could not delete old sessions: " + err.Error()) } _, err = db.Exec("INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime())", session_id, user_uuid) if err != nil { slog.Error("Could not insert session: " + err.Error()) auth.Error("Internal Server Error").Render(r.Context(), w) return false } cookie := http.Cookie{ Name: "id", Value: session_id, MaxAge: 60 * 60 * 8, // 8 hours Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, Path: "/", } http.SetCookie(w, &cookie) return true } func getHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) }