From bd721a1e48cf133c4f3969081fc8f2f31d252b77 Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Wed, 25 Sep 2024 22:45:04 +0200 Subject: [PATCH] chore(auth): start refactoring for testable code #181 --- db/auth.go | 78 +++++++++++++++++++++++++++++++++++++ db/auth_test.go | 69 +++++++++++++++++++++++++++++++++ handler/auth.go | 36 ++++++++++++++++- main.go | 2 +- service/auth.go | 100 +++++++++++++++++++++--------------------------- types/types.go | 10 ++++- utils/db.go | 4 +- 7 files changed, 237 insertions(+), 62 deletions(-) create mode 100644 db/auth.go create mode 100644 db/auth_test.go diff --git a/db/auth.go b/db/auth.go new file mode 100644 index 0000000..793f8a6 --- /dev/null +++ b/db/auth.go @@ -0,0 +1,78 @@ +package db + +import ( + "database/sql" + "errors" + "me-fit/types" + "me-fit/utils" + "time" + + "github.com/google/uuid" +) + +var ( + UserNotFound = errors.New("User not found") +) + +type User struct { + Id uuid.UUID + Email string + EmailVerified bool + EmailVerifiedAt time.Time + IsAdmin bool + Password []byte + Salt []byte + CreateAt time.Time +} + +func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User { + return &User{ + Id: id, + Email: email, + EmailVerified: emailVerified, + EmailVerifiedAt: emailVerifiedAt, + IsAdmin: isAdmin, + Password: password, + Salt: salt, + CreateAt: createAt, + } +} + +type Auth interface { + GetUser(email string) (*User, error) +} + +type AuthSqlite struct { + db *sql.DB +} + +func NewAuthSqlite(db *sql.DB) *AuthSqlite { + return &AuthSqlite{db: db} +} + +func (a AuthSqlite) GetUser(email string) (*User, error) { + var ( + userId uuid.UUID + emailVerified bool + emailVerifiedAt time.Time + isAdmin bool + password []byte + salt []byte + createdAt time.Time + ) + + err := a.db.QueryRow(` + SELECT user_uuid, email_verified, email_verified_at, password, salt, created_at + FROM user + WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, UserNotFound + } else { + utils.LogError("SQL error GetUser", err) + return nil, types.InternalServerError + } + } + + return NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil +} diff --git a/db/auth_test.go b/db/auth_test.go new file mode 100644 index 0000000..3dcaa45 --- /dev/null +++ b/db/auth_test.go @@ -0,0 +1,69 @@ +package db + +import ( + "database/sql" + "me-fit/utils" + "reflect" + "testing" + "time" + + "github.com/google/uuid" +) + +func setupDb(t *testing.T) *sql.DB { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening database: %v", err) + } + + utils.MustRunMigrations(db, "../") + + return db +} + +func TestGetUser(t *testing.T) { + t.Parallel() + + t.Run("should return UserNotFound", func(t *testing.T) { + t.Parallel() + db := setupDb(t) + defer db.Close() + + underTest := AuthSqlite{db: db} + + _, err := underTest.GetUser("someNonExistentEmail") + if err != UserNotFound { + t.Errorf("Expected UserNotFound, got %v", err) + } + }) + + t.Run("should find user in database", func(t *testing.T) { + t.Parallel() + db := setupDb(t) + defer db.Close() + + underTest := AuthSqlite{db: db} + + verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) + createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) + user := NewUser(uuid.New(), "some@email.de", true, verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + + _, err := db.Exec(` + INSERT INTO user (user_uuid, email, email_verified, email_verified_at, is_admin, password, salt, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + `, user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt) + if err != nil { + t.Fatalf("Error inserting user: %v", err) + } + + actual, err := underTest.GetUser(user.Email) + if err != nil { + t.Fatalf("Error getting user: %v", err) + } + + if !reflect.DeepEqual(user, actual) { + t.Errorf("Expected %v, got %v", user, actual) + } + }) + +} diff --git a/handler/auth.go b/handler/auth.go index 9bb57cb..5bb74f3 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -2,12 +2,22 @@ package handler import ( "me-fit/service" + "me-fit/utils" "database/sql" "net/http" ) +type AuthHandler struct { + db *sql.DB + service *service.AuthService +} + func handleAuth(db *sql.DB, router *http.ServeMux) { + a := AuthHandler{ + db: db, + service: service.NewAuthService(db), + } // Don't use auth middleware for these routes, as it makes redirecting very difficult, if the mail is not yet verified router.Handle("/auth/signin", service.HandleSignInPage(db)) router.Handle("/auth/signup", service.HandleSignUpPage(db)) @@ -17,7 +27,7 @@ func handleAuth(db *sql.DB, router *http.ServeMux) { router.Handle("/auth/change-password", service.HandleChangePasswordPage(db)) router.Handle("/auth/reset-password", service.HandleResetPasswordPage(db)) router.Handle("/api/auth/signup", service.HandleSignUpComp(db)) - router.Handle("/api/auth/signin", service.HandleSignInComp(db)) + router.Handle("/api/auth/signin", a.handleSignIn()) router.Handle("/api/auth/signout", service.HandleSignOutComp(db)) router.Handle("/api/auth/delete-account", service.HandleDeleteAccountComp(db)) router.Handle("/api/auth/verify-resend", service.HandleVerifyResendComp(db)) @@ -25,3 +35,27 @@ func handleAuth(db *sql.DB, router *http.ServeMux) { router.Handle("/api/auth/reset-password", service.HandleResetPasswordComp(db)) router.Handle("/api/auth/reset-password-actual", service.HandleActualResetPasswordComp(db)) } + +func (a AuthHandler) handleSignIn() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var email = r.FormValue("email") + var password = r.FormValue("password") + + user := a.service.SignIn(email, password) + + if user != nil { + result := service.TryCreateSessionAndSetCookie(r, w, a.db, user.Id) + if !result { + return + } + + if !user.EmailVerified { + utils.DoRedirect(w, r, "/auth/verify") + } else { + utils.DoRedirect(w, r, "/") + } + } else { + http.Error(w, "Invalid email or password", http.StatusUnauthorized) + } + } +} diff --git a/main.go b/main.go index d67222a..fe9f036 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,7 @@ func main() { } defer db.Close() - utils.MustRunMigrations(db) + utils.MustRunMigrations(db, "") startPrometheus() diff --git a/service/auth.go b/service/auth.go index 024ec09..c03b4f8 100644 --- a/service/auth.go +++ b/service/auth.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "me-fit/db" "me-fit/template" "me-fit/template/auth" tempMail "me-fit/template/mail" @@ -24,6 +25,46 @@ import ( "golang.org/x/crypto/argon2" ) +type AuthService struct { + db db.Auth +} + +func NewAuthService(d *sql.DB) *AuthService { + return &AuthService{ + db: db.NewAuthSqlite(d), + } +} + +func (a AuthService) SignIn(email string, password string) *db.User { + + var result bool = true + start := time.Now() + + user, err := a.db.GetUser(email) + if err != nil { + result = false + } + + if result { + new_hash := getHashPassword(password, user.Salt) + + if subtle.ConstantTimeCompare(new_hash, user.Password) == 0 { + result = false + } + } + + duration := time.Since(start) + timeToWait := 100 - duration.Milliseconds() + // It is important to sleep for a while to prevent timing attacks + // If the email is correct, the server will calculate the hash, which will take some time + // This way an attacker could guess emails when comparing the response time + // Because of that, we cant use WriteHeader in the middle of the function. We have to wait until the end + // Unfortunatly this makes the code harder to read + time.Sleep(time.Duration(timeToWait) * time.Millisecond) + + return user +} + func HandleSignInPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) @@ -245,7 +286,7 @@ func HandleSignUpComp(db *sql.DB) http.HandlerFunc { return } - result := tryCreateSessionAndSetCookie(r, w, db, userId) + result := TryCreateSessionAndSetCookie(r, w, db, userId) if !result { return } @@ -257,61 +298,6 @@ func HandleSignUpComp(db *sql.DB) http.HandlerFunc { } } -func HandleSignInComp(db *sql.DB) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - var email = r.FormValue("email") - var password = r.FormValue("password") - - var result bool = true - start := time.Now() - - var ( - userId uuid.UUID - savedHash []byte - salt []byte - emailVerified bool - ) - err := db.QueryRow("SELECT user_uuid, password, salt, email_verified FROM user WHERE email = ?", email).Scan(&userId, &savedHash, &salt, &emailVerified) - if err != nil { - result = false - } - - if result { - new_hash := getHashPassword(password, salt) - - if subtle.ConstantTimeCompare(new_hash, savedHash) == 0 { - result = false - } - } - - if result { - result := tryCreateSessionAndSetCookie(r, w, db, userId) - if !result { - return - } - } - - duration := time.Since(start) - timeToWait := 100 - duration.Milliseconds() - // It is important to sleep for a while to prevent timing attacks - // If the email is correct, the server will calculate the hash, which will take some time - // This way an attacker could guess emails when comparing the response time - // Because of that, we cant use WriteHeader in the middle of the function. We have to wait until the end - // Unfortunatly this makes the code harder to read - time.Sleep(time.Duration(timeToWait) * time.Millisecond) - - if result { - if !emailVerified { - utils.DoRedirect(w, r, "/auth/verify") - } else { - utils.DoRedirect(w, r, "/") - } - } else { - auth.Error("Invalid email or password").Render(r.Context(), w) - } - } -} - func HandleSignOutComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) @@ -617,7 +603,7 @@ func sendVerificationEmail(db *sql.DB, userId string, email string) { utils.SendMail(email, "Welcome to ME-FIT", w.String()) } -func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { +func TryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { sessionId, err := utils.RandomToken() if err != nil { utils.LogError("Could not generate session ID", err) diff --git a/types/types.go b/types/types.go index 2c5324f..13933b3 100644 --- a/types/types.go +++ b/types/types.go @@ -1,6 +1,14 @@ package types -import "github.com/google/uuid" +import ( + "errors" + + "github.com/google/uuid" +) + +var ( + InternalServerError = errors.New("Internal server error") +) type User struct { Id uuid.UUID diff --git a/utils/db.go b/utils/db.go index 90f782e..7ff44e4 100644 --- a/utils/db.go +++ b/utils/db.go @@ -9,14 +9,14 @@ import ( _ "github.com/golang-migrate/migrate/v4/source/file" ) -func MustRunMigrations(db *sql.DB) { +func MustRunMigrations(db *sql.DB, pathPrefix string) { driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { log.Fatal(err) } m, err := migrate.NewWithDatabaseInstance( - "file://./migration/", + "file://"+pathPrefix+"migration/", "", driver) if err != nil {