package service import ( "bytes" "crypto/rand" "database/sql" "encoding/base64" "log" "net/http" "net/mail" "strings" "time" "me-fit/template" "me-fit/template/auth" "github.com/a-h/templ" "github.com/google/uuid" "golang.org/x/crypto/argon2" ) type User struct { user_uuid uuid.UUID email string } func HandleSignInPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user_comp := UserInfoComp(verifySessionAndReturnUser(db, r)) signIn := auth.SignInOrUp(true) template.Layout(signIn, user_comp).Render(r.Context(), w) } } func HandleSignUpPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user_comp := UserInfoComp(verifySessionAndReturnUser(db, r)) signIn := auth.SignInOrUp(false) template.Layout(signIn, user_comp).Render(r.Context(), w) } } func UserInfoComp(user *User) templ.Component { if user != nil { return auth.UserComp(user.email) } else { return auth.UserComp("") } } func HandleSignUp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") _, err := mail.ParseAddress(email) if err != nil { http.Error(w, "Invalid email", http.StatusBadRequest) return } if len(password) < 8 || !strings.ContainsAny(password, "0123456789") || !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") || !strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") { http.Error(w, "Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character", http.StatusBadRequest) return } user_uuid, err := uuid.NewRandom() if err != nil { log.Printf("Could not generate UUID: %v", err) auth.Error("Internal Server Error").Render(r.Context(), w) return } salt := make([]byte, 16) rand.Read(salt) hash := getHashPassword(password, salt) _, err = db.Exec("INSERT INTO user (user_uuid, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, FALSE, FALSE, ?, ?, datetime())", user_uuid, email, hash, salt) if err != nil { if strings.Contains(err.Error(), "email") { auth.Error("Bad Request").Render(r.Context(), w) return } auth.Error("Internal Server Error").Render(r.Context(), w) log.Printf("Could not insert user: %v", err) return } result := tryCreateSessionAndSetCookie(w, db, user_uuid) if !result { return } w.Header().Add("HX-Redirect", "/") w.WriteHeader(http.StatusOK) } } func HandleSignIn(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") var result bool = true start := time.Now() var user_uuid uuid.UUID var saved_hash []byte var salt []byte err := db.QueryRow("SELECT user_uuid, password, salt FROM user WHERE email = ?", email).Scan(&user_uuid, &saved_hash, &salt) if err != nil { result = false } if result { new_hash := getHashPassword(password, salt) if !bytes.Equal(new_hash, saved_hash) { result = false } } if result { result := tryCreateSessionAndSetCookie(w, db, user_uuid) if !result { return } } duration := time.Since(start) time_to_wait := 100 - duration.Milliseconds() // It is important to sleep for a while to prevent timing attacks // If the email is correct, the server will calculate the hash, which will take some time // This way an attacker could guess emails when comparing the response time // Because of that, we cant use WriteHeader in the middle of the function. We have to wait until the end time.Sleep(time.Duration(time_to_wait) * time.Millisecond) if result { w.Header().Add("HX-Redirect", "/") w.WriteHeader(http.StatusOK) } else { auth.Error("Invalid email or password").Render(r.Context(), w) } } } func HandleSignOutComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { id := getSessionID(r) _, err := db.Exec("DELETE FROM session WHERE session_id = ?", id) if err != nil { log.Printf("Could not delete session: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } c := http.Cookie{ Name: "id", Value: "", MaxAge: -1, Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, Path: "/", } http.SetCookie(w, &c) auth.UserComp("").Render(r.Context(), w) } } // var ( // metricsAuthSignUp = promauto.NewCounterVec( // prometheus.CounterOpts{ // Name: "mefit_api_auth_signup_total", // Help: "The total number of auth signup api requests processed", // }, // []string{"result"}, // ) // // metricsError = promauto.NewCounterVec( // prometheus.CounterOpts{ // Name: "mefit_api_error_total", // Help: "The total number of errors", // }, // []string{"result"}, // ) // // // metricsAuthSignIn = promauto.NewCounterVec( // // prometheus.CounterOpts{ // // Name: "mefit_api_auth_signin_total", // // }, // // []string{"result"}, // // ) // ) func tryCreateSessionAndSetCookie(w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { var session_id_bytes []byte = make([]byte, 32) _, err := rand.Reader.Read(session_id_bytes) if err != nil { log.Printf("Could not generate session ID: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return false } session_id := base64.StdEncoding.EncodeToString(session_id_bytes) _, err = db.Exec("INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime())", session_id, user_uuid) if err != nil { log.Printf("Could not insert session: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return false } cookie := http.Cookie{ Name: "id", Value: session_id, MaxAge: 60 * 60 * 8, // 8 hours Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, Path: "/", } http.SetCookie(w, &cookie) return true } func getSessionID(r *http.Request) string { for _, c := range r.Cookies() { if c.Name == "id" { return c.Value } } return "" } func verifySessionAndReturnUser(db *sql.DB, r *http.Request) *User { session_id := getSessionID(r) if session_id == "" { return nil } var user User var created_at time.Time err := db.QueryRow(` SELECT u.user_uuid, u.email, s.created_at FROM session s INNER JOIN user u ON s.user_uuid = u.user_uuid WHERE session_id = ?`, session_id).Scan(&user.user_uuid, &user.email, &created_at) if err != nil { log.Printf("Could not verify session: %v", err) return nil } if created_at.Add(time.Duration(8 * time.Hour)).Before(time.Now()) { return nil } return &user } func getHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) }