From e53a0e5cf73678796f4e09edd4bce398c92db4bf Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Fri, 20 Sep 2024 22:19:02 +0200 Subject: [PATCH] chore(auth): add service structs and test error handling --- handler/auth.go | 48 +++++++------- handler/default.go | 11 +++- service/auth.go | 154 ++++++++++++++++++++++++++++----------------- utils/ctypto.go | 16 ----- 4 files changed, 130 insertions(+), 99 deletions(-) delete mode 100644 utils/ctypto.go diff --git a/handler/auth.go b/handler/auth.go index ed9a9ef..16b5af0 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -15,32 +15,37 @@ import ( "net/url" ) -func authUi(db *sql.DB) http.Handler { +type AuthHandler struct { + db *sql.DB + service *service.AuthService +} + +func (a *AuthHandler) authUi() http.Handler { router := http.NewServeMux() - router.Handle("/auth/signin", handleSignInPage(db)) - router.Handle("/auth/signup", handleSignUpPage(db)) - router.Handle("/auth/verify", handleSignUpVerifyPage(db)) // Hint for the user to verify their email - router.Handle("/auth/delete-account", handleDeleteAccountPage(db)) - router.Handle("/auth/verify-email", handleSignUpVerifyResponsePage(db)) // The link contained in the email - router.Handle("/auth/change-password", handleChangePasswordPage(db)) - router.Handle("/auth/reset-password", handleResetPasswordPage(db)) - router.Handle("/", handleNotFound(db)) + router.Handle("/auth/signin", handleSignInPage(a.db)) + router.Handle("/auth/signup", handleSignUpPage(a.db)) + router.Handle("/auth/verify", handleSignUpVerifyPage(a.db)) // Hint for the user to verify their email + router.Handle("/auth/delete-account", handleDeleteAccountPage(a.db)) + router.Handle("/auth/verify-email", handleSignUpVerifyResponsePage(a.db)) // The link contained in the email + router.Handle("/auth/change-password", handleChangePasswordPage(a.db)) + router.Handle("/auth/reset-password", handleResetPasswordPage(a.db)) + router.Handle("/", handleNotFound(a.db)) return router } -func authApi(db *sql.DB) http.Handler { +func (a *AuthHandler) authApi() http.Handler { router := http.NewServeMux() - router.Handle("/api/auth/signup", handleSignUp(db)) - router.Handle("/api/auth/signin", handleSignIn(db)) - router.Handle("/api/auth/signout", handleSignOut(db)) - router.Handle("/api/auth/delete-account", handleDeleteAccount(db)) - router.Handle("/api/auth/verify-resend", handleVerifyResend(db)) - router.Handle("/api/auth/change-password", handleChangePassword(db)) - router.Handle("/api/auth/reset-password", handleResetPassword(db)) - router.Handle("/api/auth/reset-password-actual", handleActualResetPassword(db)) + router.Handle("/api/auth/signup", handleSignUp(a.db)) + router.Handle("/api/auth/signin", a.handleSignIn()) + router.Handle("/api/auth/signout", handleSignOut(a.db)) + router.Handle("/api/auth/delete-account", handleDeleteAccount(a.db)) + router.Handle("/api/auth/verify-resend", handleVerifyResend(a.db)) + router.Handle("/api/auth/change-password", handleChangePassword(a.db)) + router.Handle("/api/auth/reset-password", handleResetPassword(a.db)) + router.Handle("/api/auth/reset-password-actual", handleActualResetPassword(a.db)) return router } @@ -215,14 +220,15 @@ func createSessionCookie(sessionId types.SessionId) *http.Cookie { return &cookie } -func handleSignIn(db *sql.DB) http.HandlerFunc { +func (a *AuthHandler) handleSignIn() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") - sessionId, err := service.SignIn(db, email, password) + sessionId, err := a.service.SignIn(email, password) if err != nil { - utils.TriggerToast(w, r, "error", "Invalid username or password") + utils.TriggerToast(w, r, "error", err.Error()) + return } http.SetCookie(w, createSessionCookie(*sessionId)) diff --git a/handler/default.go b/handler/default.go index 1470263..083582c 100644 --- a/handler/default.go +++ b/handler/default.go @@ -13,13 +13,18 @@ import ( func GetHandler(db *sql.DB) http.Handler { router := http.NewServeMux() - router.HandleFunc("/$", handleIndex(db)) + router.HandleFunc("/{$}", handleIndex(db)) router.HandleFunc("/", handleNotFound(db)) router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/")))) - router.Handle("/auth/", authUi(db)) - router.Handle("/api/auth/", authApi(db)) + authHandler := AuthHandler{ + db: db, + service: service.NewAuthService(db), + } + + router.Handle("/auth/", authHandler.authUi()) + router.Handle("/api/auth/", authHandler.authApi()) router.Handle("/workout", authMiddleware(db, workoutUi(db))) router.Handle("/api/workout", authMiddleware(db, workoutApi(db))) diff --git a/service/auth.go b/service/auth.go index 27312f3..d8258a5 100644 --- a/service/auth.go +++ b/service/auth.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/subtle" "database/sql" + "encoding/base64" "errors" "log/slog" "net/http" @@ -23,60 +24,28 @@ import ( ) // TESTED + +var ( + ErrInternalServer = errors.New("Internal Server Error") + ErrInvalidUsernameOrPassword = errors.New("Invalid username or password") +) + // NOT TESTED -func SignUp(db *sql.DB, email string, password string) (*types.SessionId, error) { - _, err := mail.ParseAddress(email) - if err != nil { - return nil, errors.New("Invalid email") - } - - err = checkPassword(password) - if err != nil { - return nil, err - } - - userId, err := uuid.NewRandom() - if err != nil { - return nil, errors.Join(errors.New("Could not generate UUID for new user"), err) - } - - salt := make([]byte, 16) - _, err = rand.Read(salt) - if err != nil { - return nil, errors.Join(errors.New("Could not generate salt for new user"), err) - } - - hash := getHashPassword(password, salt) - - _, err = db.Exec("INSERT INTO user (user_uuid, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, FALSE, FALSE, ?, ?, datetime())", userId, email, hash, salt) - if err != nil { - // This does leak information about the email being in use, though not specifically stated - // It needs to be refacoteres to "If the email is not already in use, an email has been send to your address", or something - // The happy path, currently a redirect, needs to send the same message! - // Then it is also important to have the same compute time in both paths - // Otherwise an attacker could guess emails when comparing the response time - if strings.Contains(err.Error(), "email") { - return nil, errors.New("Bad Request") - } - - return nil, errors.Join(errors.New("Could not insert user"), err) - } - - sessionId, err := createSession(db, userId) - if err != nil { - return nil, err - } - - // Send verification email as a goroutine - go SendVerificationEmail(db, userId.String(), email) - return sessionId, nil +type AuthService struct { + db *sql.DB } -func SignIn(db *sql.DB, email string, password string) (*types.SessionId, error) { +func NewAuthService(db *sql.DB) *AuthService { + return &AuthService{ + db: db, + } +} + +func (a *AuthService) SignIn(email string, password string) (*types.SessionId, error) { start := time.Now() - sessionId, err := internalSignIn(db, email, password) + sessionId, err := a.internalSignIn(email, password) duration := time.Since(start) timeToWait := 100 - duration.Milliseconds() @@ -90,25 +59,29 @@ func SignIn(db *sql.DB, email string, password string) (*types.SessionId, error) return sessionId, err } -func internalSignIn(db *sql.DB, email string, password string) (*types.SessionId, error) { +func (a *AuthService) internalSignIn(email string, password string) (*types.SessionId, error) { var ( userId uuid.UUID savedHash []byte salt []byte ) - err := db.QueryRow("SELECT user_uuid, password, salt FROM user WHERE email = ?", email).Scan(&userId, &savedHash, &salt) + err := a.db.QueryRow("SELECT user_uuid, password, salt FROM user WHERE email = ?", email).Scan(&userId, &savedHash, &salt) if err != nil { - return nil, err + if err == sql.ErrNoRows { + return nil, ErrInvalidUsernameOrPassword + } + utils.LogError("Could not query user on sign in", err) + return nil, ErrInternalServer } new_hash := getHashPassword(password, salt) if subtle.ConstantTimeCompare(new_hash, savedHash) == 0 { - return nil, errors.New("Invalid Request") + return nil, errors.Join(ErrInvalidUsernameOrPassword) } - sessionId, err := createSession(db, userId) + sessionId, err := createSession(a.db, userId) if err != nil { return nil, err } @@ -144,6 +117,56 @@ func GetUserFromSessionId(db *sql.DB, sessionId types.SessionId) *types.User { return types.NewUser(userId, email, sessionId, emailVerified) } } +func SignUp(db *sql.DB, email string, password string) (*types.SessionId, error) { + _, err := mail.ParseAddress(email) + if err != nil { + return nil, errors.Join(errors.New("Invalid email")) + } + + err = checkPassword(password) + if err != nil { + return nil, err + } + + userId, err := uuid.NewRandom() + if err != nil { + utils.LogError("Could not generate UUID for new user", err) + return nil, ErrInternalServer + } + + salt := make([]byte, 16) + _, err = rand.Read(salt) + if err != nil { + utils.LogError("Could not generate salt for new user", err) + return nil, ErrInternalServer + } + + hash := getHashPassword(password, salt) + + _, err = db.Exec("INSERT INTO user (user_uuid, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, FALSE, FALSE, ?, ?, datetime())", userId, email, hash, salt) + if err != nil { + // This does leak information about the email being in use, though not specifically stated + // It needs to be refacoteres to "If the email is not already in use, an email has been send to your address", or something + // The happy path, currently a redirect, needs to send the same message! + // Then it is also important to have the same compute time in both paths + // Otherwise an attacker could guess emails when comparing the response time + if strings.Contains(err.Error(), "email") { + return nil, errors.New("Bad Request") + } + + utils.LogError("Could not insert user", err) + return nil, ErrInternalServer + } + + sessionId, err := createSession(db, userId) + if err != nil { + return nil, err + } + + // Send verification email as a goroutine + go SendVerificationEmail(db, userId.String(), email) + return sessionId, nil +} func ValidateEmail(db *sql.DB, token string) error { result, err := db.Exec(` @@ -320,7 +343,7 @@ func ResetPassword(db *sql.DB, email string) (string, error) { return "", errors.New("Please enter an email") } - token, err := utils.RandomToken() + token, err := randomToken() if err != nil { return "", errors.Join(errors.New("Could not generate token"), err) } @@ -357,7 +380,7 @@ func SendVerificationEmail(db *sql.DB, userId string, email string) { } if token == "" { - token, err := utils.RandomToken() + token, err := randomToken() if err != nil { utils.LogError("Could not generate token", err) return @@ -380,20 +403,22 @@ func SendVerificationEmail(db *sql.DB, userId string, email string) { } func createSession(db *sql.DB, user_uuid uuid.UUID) (*types.SessionId, error) { - sessionId, err := utils.RandomToken() + sessionId, err := randomToken() if err != nil { - return nil, errors.Join(errors.New("Could not generate session ID"), err) + return nil, err } // Delete old inactive sessions _, err = db.Exec("DELETE FROM session WHERE created_at < datetime('now','-8 hours') AND user_uuid = ?", user_uuid) if err != nil { utils.LogError("Could not delete old sessions", err) + return nil, ErrInternalServer } _, err = db.Exec("INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime())", sessionId, user_uuid) if err != nil { - return nil, errors.Join(errors.New("Could not insert session"), err) + utils.LogError("Could not insert new session", err) + return nil, ErrInternalServer } sessionIdType := types.SessionId(sessionId) @@ -401,7 +426,7 @@ func createSession(db *sql.DB, user_uuid uuid.UUID) (*types.SessionId, error) { } func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { - sessionId, err := utils.RandomToken() + sessionId, err := randomToken() if err != nil { utils.LogError("Could not generate session ID", err) auth.Error("Internal Server Error").Render(r.Context(), w) @@ -467,3 +492,14 @@ func GetUserFromRequest(db *sql.DB, r *http.Request) *types.User { sessionId := getSessionID(r) return GetUserFromSessionId(db, sessionId) } + +func randomToken() (string, error) { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + utils.LogError("Could not generate random token", err) + return "", ErrInternalServer + } + + return base64.StdEncoding.EncodeToString(b), nil +} diff --git a/utils/ctypto.go b/utils/ctypto.go deleted file mode 100644 index 4f3cffe..0000000 --- a/utils/ctypto.go +++ /dev/null @@ -1,16 +0,0 @@ -package utils - -import ( - "crypto/rand" - "encoding/base64" -) - -func RandomToken() (string, error) { - b := make([]byte, 32) - _, err := rand.Read(b) - if err != nil { - return "", err - } - - return base64.StdEncoding.EncodeToString(b), nil -}