package service import ( "crypto/rand" "crypto/subtle" "database/sql" "encoding/base64" "log/slog" "net/http" "net/mail" "strings" "time" "me-fit/template" "me-fit/template/auth" "github.com/a-h/templ" "github.com/google/uuid" "golang.org/x/crypto/argon2" ) type User struct { id uuid.UUID email string } func HandleSignInPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user_comp := UserInfoComp(verifySessionAndReturnUser(db, r)) signIn := auth.SignInOrUp(true) template.Layout(signIn, user_comp).Render(r.Context(), w) } } func HandleSignUpPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user_comp := UserInfoComp(verifySessionAndReturnUser(db, r)) signIn := auth.SignInOrUp(false) template.Layout(signIn, user_comp).Render(r.Context(), w) } } func UserInfoComp(user *User) templ.Component { if user != nil { return auth.UserComp(user.email) } else { return auth.UserComp("") } } func HandleSignUpComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") _, err := mail.ParseAddress(email) if err != nil { http.Error(w, "Invalid email", http.StatusBadRequest) return } if len(password) < 8 || !strings.ContainsAny(password, "0123456789") || !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") || !strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") { http.Error(w, "Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character", http.StatusBadRequest) return } userId, err := uuid.NewRandom() if err != nil { slog.Error("Could not generate UUID: %v", err) auth.Error("Internal Server Error").Render(r.Context(), w) return } salt := make([]byte, 16) _, err = rand.Read(salt) if err != nil { slog.Error("Could not generate salt: %v", err) auth.Error("Internal Server Error").Render(r.Context(), w) return } hash := getHashPassword(password, salt) _, err = db.Exec("INSERT INTO user (user_uuid, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, FALSE, FALSE, ?, ?, datetime())", userId, email, hash, salt) if err != nil { // This does leak information about the email being in use, though not specifically stated // It needs to be refacoteres to "If the email is not already in use, an email has been send to your address", or something // The happy path, currently a redirect, needs to send the same message! // Then it is also important to have the same compute time in both paths // Otherwise an attacker could guess emails when comparing the response time if strings.Contains(err.Error(), "email") { auth.Error("Bad Request").Render(r.Context(), w) return } auth.Error("Internal Server Error").Render(r.Context(), w) slog.Error("Could not insert user: %v", err) return } result := tryCreateSessionAndSetCookie(r, w, db, userId) if !result { return } w.Header().Add("HX-Redirect", "/") } } func HandleSignInComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") var result bool = true start := time.Now() var userId uuid.UUID var savedHash []byte var salt []byte err := db.QueryRow("SELECT user_uuid, password, salt FROM user WHERE email = ?", email).Scan(&userId, &savedHash, &salt) if err != nil { result = false } if result { new_hash := getHashPassword(password, salt) if subtle.ConstantTimeCompare(new_hash, savedHash) == 0 { result = false } } if result { result := tryCreateSessionAndSetCookie(r, w, db, userId) if !result { return } } duration := time.Since(start) timeToWait := 100 - duration.Milliseconds() // It is important to sleep for a while to prevent timing attacks // If the email is correct, the server will calculate the hash, which will take some time // This way an attacker could guess emails when comparing the response time // Because of that, we cant use WriteHeader in the middle of the function. We have to wait until the end // Unfortunatly this makes the code harder to read time.Sleep(time.Duration(timeToWait) * time.Millisecond) if result { w.Header().Add("HX-Redirect", "/") w.WriteHeader(http.StatusOK) } else { auth.Error("Invalid email or password").Render(r.Context(), w) } } } func HandleSignOutComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := getSessionID(r) _, err := db.Exec("DELETE FROM session WHERE session_id = ?", id) if err != nil { slog.Error("Could not delete session: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } c := http.Cookie{ Name: "id", Value: "", MaxAge: -1, Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, Path: "/", } http.SetCookie(w, &c) auth.UserComp("").Render(r.Context(), w) } } func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { var session_id_bytes []byte = make([]byte, 32) _, err := rand.Reader.Read(session_id_bytes) if err != nil { slog.Error("Could not generate session ID: %v", err) auth.Error("Internal Server Error").Render(r.Context(), w) return false } session_id := base64.StdEncoding.EncodeToString(session_id_bytes) _, err = db.Exec("INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime())", session_id, user_uuid) if err != nil { slog.Error("Could not insert session: %v", err) auth.Error("Internal Server Error").Render(r.Context(), w) return false } cookie := http.Cookie{ Name: "id", Value: session_id, MaxAge: 60 * 60 * 8, // 8 hours Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, Path: "/", } http.SetCookie(w, &cookie) return true } func getSessionID(r *http.Request) string { for _, c := range r.Cookies() { if c.Name == "id" { return c.Value } } return "" } func verifySessionAndReturnUser(db *sql.DB, r *http.Request) *User { sessionId := getSessionID(r) if sessionId == "" { return nil } var user User var createdAt time.Time err := db.QueryRow(` SELECT u.user_uuid, u.email, s.created_at FROM session s INNER JOIN user u ON s.user_uuid = u.user_uuid WHERE session_id = ?`, sessionId).Scan(&user.id, &user.email, &createdAt) if err != nil { slog.Error("Could not verify session: " + err.Error()) return nil } if createdAt.Add(time.Duration(8 * time.Hour)).Before(time.Now()) { return nil } return &user } func getHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) }