package service import ( "context" "crypto/rand" "crypto/subtle" "database/sql" "encoding/base64" "errors" "log/slog" "net/http" "net/mail" "strings" "time" "me-fit/db" "me-fit/template/auth" tempMail "me-fit/template/mail" "me-fit/types" "me-fit/utils" "github.com/a-h/templ" "github.com/google/uuid" "golang.org/x/crypto/argon2" ) // TESTED var ( ErrInternalServer = errors.New("Internal Server Error") ErrInvalidUsernameOrPassword = errors.New("Invalid username or password") ) // NOT TESTED type AuthService struct { db *sql.DB dbSer *db.AuthDb } func NewAuthService(db *sql.DB) *AuthService { return &AuthService{ db: db, } } func (a *AuthService) SignUp(email string, password string) (*types.SessionId, error) { _, err := mail.ParseAddress(email) if err != nil { return nil, errors.Join(errors.New("Invalid email")) } err = checkPassword(password) if err != nil { return nil, err } userId, err := uuid.NewRandom() if err != nil { utils.LogError("Could not generate UUID for new user", err) return nil, ErrInternalServer } salt := make([]byte, 16) _, err = rand.Read(salt) if err != nil { utils.LogError("Could not generate salt for new user", err) return nil, ErrInternalServer } hash := getHashPassword(password, salt) err := a.dbSer.InsertUser(a.db, userId, email, hash, salt) _, err = a.db.Exec("INSERT INTO user (user_uuid, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, FALSE, FALSE, ?, ?, datetime())", userId, email, hash, salt) if err != nil { // This does leak information about the email being in use, though not specifically stated // It needs to be refacoteres to "If the email is not already in use, an email has been send to your address", or something // The happy path, currently a redirect, needs to send the same message! // Then it is also important to have the same compute time in both paths // Otherwise an attacker could guess emails when comparing the response time if strings.Contains(err.Error(), "email") { return nil, errors.New("Bad Request") } utils.LogError("Could not insert user", err) return nil, ErrInternalServer } sessionId, err := createSession(a.db, userId) if err != nil { return nil, err } // Send verification email as a goroutine go SendVerificationEmail(a.db, userId.String(), email) return sessionId, nil } func (a *AuthService) SignIn(email string, password string) (*types.SessionId, error) { start := time.Now() sessionId, err := a.internalSignIn(email, password) duration := time.Since(start) timeToWait := 100 - duration.Milliseconds() // It is important to sleep for a while to prevent timing attacks // If the email is correct, the server will calculate the hash, which will take some time // This way an attacker could guess emails when comparing the response time // Because of that, we cant use WriteHeader in the middle of the function. We have to wait until the end // Unfortunatly this makes the code harder to read time.Sleep(time.Duration(timeToWait) * time.Millisecond) return sessionId, err } func (a *AuthService) internalSignIn(email string, password string) (*types.SessionId, error) { var ( userId uuid.UUID savedHash []byte salt []byte ) err := a.db.QueryRow("SELECT user_uuid, password, salt FROM user WHERE email = ?", email).Scan(&userId, &savedHash, &salt) if err != nil { if err == sql.ErrNoRows { return nil, ErrInvalidUsernameOrPassword } utils.LogError("Could not query user on sign in", err) return nil, ErrInternalServer } new_hash := getHashPassword(password, salt) if subtle.ConstantTimeCompare(new_hash, savedHash) == 0 { return nil, errors.Join(ErrInvalidUsernameOrPassword) } sessionId, err := createSession(a.db, userId) if err != nil { return nil, err } return sessionId, nil } func GetUserFromSessionId(db *sql.DB, sessionId types.SessionId) *types.User { if sessionId == "" { return nil } var ( createdAt time.Time userId uuid.UUID email string emailVerified bool ) err := db.QueryRow(` SELECT u.user_uuid, u.email, u.email_verified, s.created_at FROM session s INNER JOIN user u ON s.user_uuid = u.user_uuid WHERE session_id = ?`, sessionId).Scan(&userId, &email, &emailVerified, &createdAt) if err != nil { slog.Warn("Could not verify session: " + err.Error()) return nil } if createdAt.Add(time.Duration(8 * time.Hour)).Before(time.Now()) { return nil } else { return types.NewUser(userId, email, sessionId, emailVerified) } } func ValidateEmail(db *sql.DB, token string) error { result, err := db.Exec(` UPDATE user SET email_verified = true, email_verified_at = datetime() WHERE user_uuid = ( SELECT user_uuid FROM user_token WHERE type = "email_verify" AND token = ? );`, token) if err != nil { return errors.Join(errors.New("Could not update user on verify response"), err) } i, err := result.RowsAffected() if err != nil { return errors.Join(errors.New("Could not get rows affected on verify response"), err) } if i == 0 { return errors.New("Token is invalid") } return nil } func UserInfoComp(user *types.User) templ.Component { if user != nil { return auth.UserComp(user.Email) } else { return auth.UserComp("") } } func SignOut(db *sql.DB, user *types.User) error { if user == nil { return nil } _, err := db.Exec("DELETE FROM session WHERE session_id = ?", user.SessionId) if err != nil { return errors.Join(errors.New("Could not delete session"), err) } return nil } func DeleteAccount(db *sql.DB, user *types.User, password string) error { if password == "" { return errors.New("Please enter your password") } var ( storedHash []byte salt []byte ) err := db.QueryRow("SELECT password, salt FROM user WHERE user_uuid = ?", user.Id).Scan(&storedHash, &salt) if err != nil { return errors.Join(errors.New("Could not get password"), err) } currHash := getHashPassword(password, salt) if subtle.ConstantTimeCompare(currHash, storedHash) == 0 { return errors.New("Password is not correct") } _, err = db.Exec("DELETE FROM workout WHERE user_id = ?", user.Id) if err != nil { return errors.Join(errors.New("Could not delete workouts"), err) } _, err = db.Exec("DELETE FROM user_token WHERE user_uuid = ?", user.Id) if err != nil { return errors.Join(errors.New("Could not delete tokens"), err) } _, err = db.Exec("DELETE FROM session WHERE user_uuid = ?", user.Id) if err != nil { return errors.Join(errors.New("Could not delete sessions"), err) } _, err = db.Exec("DELETE FROM user WHERE user_uuid = ?", user.Id) if err != nil { return errors.Join(errors.New("Could not delete user"), err) } go utils.SendMail(user.Email, "Account deleted", "Your account has been deleted") return nil } func ChangePassword(db *sql.DB, user *types.User, currPass string, newPass string) error { err := checkPassword(newPass) if err != nil { return err } if currPass == newPass { return errors.New("New password can not be the same as the current password") } var ( storedHash []byte salt []byte ) err = db.QueryRow("SELECT password, salt FROM user WHERE user_uuid = ?", user.Id).Scan(&storedHash, &salt) if err != nil { return errors.Join(errors.New("Could not get password"), err) } currHash := getHashPassword(currPass, salt) if subtle.ConstantTimeCompare(currHash, storedHash) == 0 { return errors.New("Current password is not correct") } newHash := getHashPassword(newPass, salt) _, err = db.Exec("UPDATE user SET password = ? WHERE user_uuid = ?", newHash, user.Id) if err != nil { return errors.Join(errors.New("Could not update password"), err) } return nil } func ActualResetPassword(db *sql.DB, token string, newPass string) error { err := checkPassword(newPass) if err != nil { return err } var ( userId uuid.UUID salt []byte ) err = db.QueryRow(` SELECT u.user_uuid, salt FROM user_token t INNER JOIN user u ON t.user_uuid = u.user_uuid WHERE t.token = ? AND t.type = 'password_reset' AND t.expires_at > datetime() `, token).Scan(&userId, &salt) if err != nil { return errors.Join(errors.New("Could not get user from token"), err) } _, err = db.Exec("DELETE FROM user_token WHERE token = ? AND type = 'password_reset'", token) if err != nil { return errors.Join(errors.New("Could not delete token"), err) } passHash := getHashPassword(newPass, salt) _, err = db.Exec("UPDATE user SET password = ? WHERE user_uuid = ?", passHash, userId) if err != nil { return errors.Join(errors.New("Could not update password"), err) } return nil } func ResetPassword(db *sql.DB, email string) (string, error) { if email == "" { return "", errors.New("Please enter an email") } token, err := randomToken() if err != nil { return "", errors.Join(errors.New("Could not generate token"), err) } res, err := db.Exec(` INSERT INTO user_token (user_uuid, type, token, created_at, expires_at) SELECT user_uuid, 'password_reset', ?, datetime(), datetime('now', '+15 minute') FROM user WHERE email = ? `, token, email) if err != nil { return "", errors.Join(errors.New("Could not insert token"), err) } i, err := res.RowsAffected() if err != nil { return "", errors.Join(errors.New("Could not get rows affected"), err) } if i == 0 { return "", nil } else { return token, nil } } func SendVerificationEmail(db *sql.DB, userId string, email string) { var token string err := db.QueryRow("SELECT token FROM user_token WHERE user_uuid = ? AND type = 'email_verify'", userId).Scan(&token) if err != nil && err != sql.ErrNoRows { utils.LogError("Could not get token", err) return } if token == "" { token, err := randomToken() if err != nil { utils.LogError("Could not generate token", err) return } _, err = db.Exec("INSERT INTO user_token (user_uuid, type, token, created_at) VALUES (?, 'email_verify', ?, datetime())", userId, token) if err != nil { utils.LogError("Could not insert token", err) return } } var w strings.Builder err = tempMail.Register(token).Render(context.Background(), &w) if err != nil { utils.LogError("Could not render welcome email", err) return } utils.SendMail(email, "Welcome to ME-FIT", w.String()) } func createSession(db *sql.DB, user_uuid uuid.UUID) (*types.SessionId, error) { sessionId, err := randomToken() if err != nil { return nil, err } // Delete old inactive sessions _, err = db.Exec("DELETE FROM session WHERE created_at < datetime('now','-8 hours') AND user_uuid = ?", user_uuid) if err != nil { utils.LogError("Could not delete old sessions", err) return nil, ErrInternalServer } _, err = db.Exec("INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime())", sessionId, user_uuid) if err != nil { utils.LogError("Could not insert new session", err) return nil, ErrInternalServer } sessionIdType := types.SessionId(sessionId) return &sessionIdType, nil } func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { sessionId, err := randomToken() if err != nil { utils.LogError("Could not generate session ID", err) auth.Error("Internal Server Error").Render(r.Context(), w) return false } // Delete old inactive sessions _, err = db.Exec("DELETE FROM session WHERE created_at < datetime('now','-8 hours') AND user_uuid = ?", user_uuid) if err != nil { utils.LogError("Could not delete old sessions", err) } _, err = db.Exec("INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime())", sessionId, user_uuid) if err != nil { utils.LogError("Could not insert session", err) auth.Error("Internal Server Error").Render(r.Context(), w) return false } cookie := http.Cookie{ Name: "id", Value: sessionId, MaxAge: 60 * 60 * 8, // 8 hours Secure: true, HttpOnly: true, SameSite: http.SameSiteStrictMode, Path: "/", } http.SetCookie(w, &cookie) return true } func getHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) } func checkPassword(password string) error { if len(password) < 8 || !strings.ContainsAny(password, "0123456789") || !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") || !strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") { return errors.New("Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character") } else { return nil } } //TODO: delete func getSessionID(r *http.Request) types.SessionId { for _, c := range r.Cookies() { if c.Name == "id" { return types.SessionId(c.Value) } } return "" } func GetUserFromRequest(db *sql.DB, r *http.Request) *types.User { sessionId := getSessionID(r) return GetUserFromSessionId(db, sessionId) } func randomToken() (string, error) { b := make([]byte, 32) _, err := rand.Read(b) if err != nil { utils.LogError("Could not generate random token", err) return "", ErrInternalServer } return base64.StdEncoding.EncodeToString(b), nil }