diff --git a/Dockerfile b/Dockerfile index 50a5057..7ce4ac2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ FROM golang:1.23.1@sha256:2fe82a3f3e006b4f2a316c6a21f62b66e1330ae211d039bb8d1128 WORKDIR /me-fit RUN go install github.com/a-h/templ/cmd/templ@latest COPY . ./ -RUN templ generate && go build -o /me-fit/me-fit . +RUN templ generate && go test ./... && go build -o /me-fit/me-fit . FROM node:22.9.0@sha256:cbe2d5f94110cea9817dd8c5809d05df49b4bd1aac5203f3594d88665ad37988 AS builder_node diff --git a/db/auth.go b/db/auth.go new file mode 100644 index 0000000..c557a7d --- /dev/null +++ b/db/auth.go @@ -0,0 +1,78 @@ +package db + +import ( + "database/sql" + "errors" + "me-fit/types" + "me-fit/utils" + "time" + + "github.com/google/uuid" +) + +var ( + ErrUserNotFound = errors.New("User not found") +) + +type User struct { + Id uuid.UUID + Email string + EmailVerified bool + EmailVerifiedAt time.Time + IsAdmin bool + Password []byte + Salt []byte + CreateAt time.Time +} + +func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User { + return &User{ + Id: id, + Email: email, + EmailVerified: emailVerified, + EmailVerifiedAt: emailVerifiedAt, + IsAdmin: isAdmin, + Password: password, + Salt: salt, + CreateAt: createAt, + } +} + +type DbAuth interface { + GetUser(email string) (*User, error) +} + +type DbAuthSqlite struct { + db *sql.DB +} + +func NewDbAuthSqlite(db *sql.DB) *DbAuthSqlite { + return &DbAuthSqlite{db: db} +} + +func (db DbAuthSqlite) GetUser(email string) (*User, error) { + var ( + userId uuid.UUID + emailVerified bool + emailVerifiedAt time.Time + isAdmin bool + password []byte + salt []byte + createdAt time.Time + ) + + err := db.db.QueryRow(` + SELECT user_uuid, email_verified, email_verified_at, password, salt, created_at + FROM user + WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, ErrUserNotFound + } else { + utils.LogError("SQL error GetUser", err) + return nil, types.ErrInternal + } + } + + return NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil +} diff --git a/db/auth_test.go b/db/auth_test.go new file mode 100644 index 0000000..bc2729e --- /dev/null +++ b/db/auth_test.go @@ -0,0 +1,69 @@ +package db + +import ( + "database/sql" + "me-fit/utils" + "reflect" + "testing" + "time" + + "github.com/google/uuid" +) + +func setupDb(t *testing.T) *sql.DB { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Error opening database: %v", err) + } + + utils.MustRunMigrations(db, "../") + + return db +} + +func TestGetUser(t *testing.T) { + t.Parallel() + + t.Run("should return UserNotFound", func(t *testing.T) { + t.Parallel() + db := setupDb(t) + defer db.Close() + + underTest := DbAuthSqlite{db: db} + + _, err := underTest.GetUser("someNonExistentEmail") + if err != ErrUserNotFound { + t.Errorf("Expected UserNotFound, got %v", err) + } + }) + + t.Run("should find user in database", func(t *testing.T) { + t.Parallel() + db := setupDb(t) + defer db.Close() + + underTest := DbAuthSqlite{db: db} + + verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) + createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) + user := NewUser(uuid.New(), "some@email.de", true, verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + + _, err := db.Exec(` + INSERT INTO user (user_uuid, email, email_verified, email_verified_at, is_admin, password, salt, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + `, user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt) + if err != nil { + t.Fatalf("Error inserting user: %v", err) + } + + actual, err := underTest.GetUser(user.Email) + if err != nil { + t.Fatalf("Error getting user: %v", err) + } + + if !reflect.DeepEqual(user, actual) { + t.Errorf("Expected %v, got %v", user, actual) + } + }) + +} diff --git a/handler/auth.go b/handler/auth.go index 9bb57cb..22630e1 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -2,26 +2,86 @@ package handler import ( "me-fit/service" + "me-fit/utils" + "time" "database/sql" "net/http" ) -func handleAuth(db *sql.DB, router *http.ServeMux) { - // Don't use auth middleware for these routes, as it makes redirecting very difficult, if the mail is not yet verified - router.Handle("/auth/signin", service.HandleSignInPage(db)) - router.Handle("/auth/signup", service.HandleSignUpPage(db)) - router.Handle("/auth/verify", service.HandleSignUpVerifyPage(db)) // Hint for the user to verify their email - router.Handle("/auth/delete-account", service.HandleDeleteAccountPage(db)) - router.Handle("/auth/verify-email", service.HandleSignUpVerifyResponsePage(db)) // The link contained in the email - router.Handle("/auth/change-password", service.HandleChangePasswordPage(db)) - router.Handle("/auth/reset-password", service.HandleResetPasswordPage(db)) - router.Handle("/api/auth/signup", service.HandleSignUpComp(db)) - router.Handle("/api/auth/signin", service.HandleSignInComp(db)) - router.Handle("/api/auth/signout", service.HandleSignOutComp(db)) - router.Handle("/api/auth/delete-account", service.HandleDeleteAccountComp(db)) - router.Handle("/api/auth/verify-resend", service.HandleVerifyResendComp(db)) - router.Handle("/api/auth/change-password", service.HandleChangePasswordComp(db)) - router.Handle("/api/auth/reset-password", service.HandleResetPasswordComp(db)) - router.Handle("/api/auth/reset-password-actual", service.HandleActualResetPasswordComp(db)) +type HandlerAuth interface { + handle(router *http.ServeMux) +} + +type HandlerAuthImpl struct { + db *sql.DB + service service.ServiceAuth +} + +func NewHandlerAuth(db *sql.DB, service service.ServiceAuth) HandlerAuth { + return HandlerAuthImpl{ + db: db, + service: service, + } +} + +func (handler HandlerAuthImpl) handle(router *http.ServeMux) { + // Don't use auth middleware for these routes, as it makes redirecting very difficult, if the mail is not yet verified + router.Handle("/auth/signin", service.HandleSignInPage(handler.db)) + router.Handle("/auth/signup", service.HandleSignUpPage(handler.db)) + router.Handle("/auth/verify", service.HandleSignUpVerifyPage(handler.db)) // Hint for the user to verify their email + router.Handle("/auth/delete-account", service.HandleDeleteAccountPage(handler.db)) + router.Handle("/auth/verify-email", service.HandleSignUpVerifyResponsePage(handler.db)) // The link contained in the email + router.Handle("/auth/change-password", service.HandleChangePasswordPage(handler.db)) + router.Handle("/auth/reset-password", service.HandleResetPasswordPage(handler.db)) + router.Handle("/api/auth/signup", service.HandleSignUpComp(handler.db)) + router.Handle("/api/auth/signin", handler.handleSignIn()) + router.Handle("/api/auth/signout", service.HandleSignOutComp(handler.db)) + router.Handle("/api/auth/delete-account", service.HandleDeleteAccountComp(handler.db)) + router.Handle("/api/auth/verify-resend", service.HandleVerifyResendComp(handler.db)) + router.Handle("/api/auth/change-password", service.HandleChangePasswordComp(handler.db)) + router.Handle("/api/auth/reset-password", service.HandleResetPasswordComp(handler.db)) + router.Handle("/api/auth/reset-password-actual", service.HandleActualResetPasswordComp(handler.db)) +} + +var ( + securityWaitDuration = 250 * time.Millisecond +) + +func (handler HandlerAuthImpl) handleSignIn() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*service.User, error) { + var email = r.FormValue("email") + var password = r.FormValue("password") + + user, err := handler.service.SignIn(email, password) + if err != nil { + return nil, err + } + + err = service.TryCreateSessionAndSetCookie(r, w, handler.db, user.Id) + if err != nil { + return nil, err + } + + return user, nil + }) + + if err != nil { + if err == service.ErrInvaidCredentials { + utils.TriggerToast(w, r, "error", "Invalid email or password") + http.Error(w, "Invalid email or password", http.StatusUnauthorized) + } else { + utils.LogError("Error signing in", err) + http.Error(w, "An error occurred", http.StatusInternalServerError) + } + return + } + + if user.EmailVerified { + utils.DoRedirect(w, r, "/") + } else { + utils.DoRedirect(w, r, "/auth/verify") + } + } } diff --git a/handler/default.go b/handler/default.go index e877d59..3810adc 100644 --- a/handler/default.go +++ b/handler/default.go @@ -1,6 +1,7 @@ package handler import ( + "me-fit/db" "me-fit/middleware" "me-fit/service" @@ -8,17 +9,19 @@ import ( "net/http" ) -func GetHandler(db *sql.DB) http.Handler { +func GetHandler(d *sql.DB) http.Handler { var router = http.NewServeMux() - router.HandleFunc("/", service.HandleIndexAnd404(db)) + router.HandleFunc("/", service.HandleIndexAnd404(d)) + + handlerAuth := NewHandlerAuth(d, service.NewServiceAuthImpl(db.NewDbAuthSqlite(d))) // Serve static files (CSS, JS and images) router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/")))) - handleWorkout(db, router) + handleWorkout(d, router) - handleAuth(db, router) + handlerAuth.handle(router) return middleware.Logging(middleware.EnableCors(router)) } diff --git a/main.go b/main.go index d67222a..fe9f036 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,7 @@ func main() { } defer db.Close() - utils.MustRunMigrations(db) + utils.MustRunMigrations(db, "") startPrometheus() diff --git a/service/auth.go b/service/auth.go index 024ec09..e8b2eaa 100644 --- a/service/auth.go +++ b/service/auth.go @@ -11,8 +11,8 @@ import ( "net/mail" "net/url" "strings" - "time" + "me-fit/db" "me-fit/template" "me-fit/template/auth" tempMail "me-fit/template/mail" @@ -24,6 +24,59 @@ import ( "golang.org/x/crypto/argon2" ) +var ( + ErrInvaidCredentials = errors.New("Invalid email or password") + ErrPasswordComplexity = errors.New("Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character") +) + +type User struct { + Id uuid.UUID + Email string + EmailVerified bool +} + +func NewUser(user *db.User) *User { + return &User{ + Id: user.Id, + Email: user.Email, + EmailVerified: user.EmailVerified, + } +} + +type ServiceAuth interface { + SignIn(email string, password string) (*User, error) +} + +type ServiceAuthImpl struct { + dbAuth db.DbAuth +} + +func NewServiceAuthImpl(dbAuth db.DbAuth) *ServiceAuthImpl { + return &ServiceAuthImpl{ + dbAuth: dbAuth, + } +} + +func (service ServiceAuthImpl) SignIn(email string, password string) (*User, error) { + + user, err := service.dbAuth.GetUser(email) + if err != nil { + if errors.Is(err, db.ErrUserNotFound) { + return nil, ErrInvaidCredentials + } else { + return nil, types.ErrInternal + } + } + + hash := getHashPassword(password, user.Salt) + + if subtle.ConstantTimeCompare(hash, user.Password) == 0 { + return nil, ErrInvaidCredentials + } + + return NewUser(user), nil +} + func HandleSignInPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) @@ -245,8 +298,8 @@ func HandleSignUpComp(db *sql.DB) http.HandlerFunc { return } - result := tryCreateSessionAndSetCookie(r, w, db, userId) - if !result { + err = TryCreateSessionAndSetCookie(r, w, db, userId) + if err != nil { return } @@ -257,61 +310,6 @@ func HandleSignUpComp(db *sql.DB) http.HandlerFunc { } } -func HandleSignInComp(db *sql.DB) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - var email = r.FormValue("email") - var password = r.FormValue("password") - - var result bool = true - start := time.Now() - - var ( - userId uuid.UUID - savedHash []byte - salt []byte - emailVerified bool - ) - err := db.QueryRow("SELECT user_uuid, password, salt, email_verified FROM user WHERE email = ?", email).Scan(&userId, &savedHash, &salt, &emailVerified) - if err != nil { - result = false - } - - if result { - new_hash := getHashPassword(password, salt) - - if subtle.ConstantTimeCompare(new_hash, savedHash) == 0 { - result = false - } - } - - if result { - result := tryCreateSessionAndSetCookie(r, w, db, userId) - if !result { - return - } - } - - duration := time.Since(start) - timeToWait := 100 - duration.Milliseconds() - // It is important to sleep for a while to prevent timing attacks - // If the email is correct, the server will calculate the hash, which will take some time - // This way an attacker could guess emails when comparing the response time - // Because of that, we cant use WriteHeader in the middle of the function. We have to wait until the end - // Unfortunatly this makes the code harder to read - time.Sleep(time.Duration(timeToWait) * time.Millisecond) - - if result { - if !emailVerified { - utils.DoRedirect(w, r, "/auth/verify") - } else { - utils.DoRedirect(w, r, "/") - } - } else { - auth.Error("Invalid email or password").Render(r.Context(), w) - } - } -} - func HandleSignOutComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := utils.GetUserFromSession(db, r) @@ -617,25 +615,24 @@ func sendVerificationEmail(db *sql.DB, userId string, email string) { utils.SendMail(email, "Welcome to ME-FIT", w.String()) } -func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { +func TryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) error { sessionId, err := utils.RandomToken() if err != nil { utils.LogError("Could not generate session ID", err) - auth.Error("Internal Server Error").Render(r.Context(), w) - return false + return types.ErrInternal } // Delete old inactive sessions _, err = db.Exec("DELETE FROM session WHERE created_at < datetime('now','-8 hours') AND user_uuid = ?", user_uuid) if err != nil { utils.LogError("Could not delete old sessions", err) + return types.ErrInternal } _, err = db.Exec("INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime())", sessionId, user_uuid) if err != nil { utils.LogError("Could not insert session", err) - auth.Error("Internal Server Error").Render(r.Context(), w) - return false + return types.ErrInternal } cookie := http.Cookie{ @@ -649,7 +646,7 @@ func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sq } http.SetCookie(w, &cookie) - return true + return nil } func getHashPassword(password string, salt []byte) []byte { @@ -663,7 +660,7 @@ func checkPassword(password string) error { !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") || !strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") { - return errors.New("Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character") + return ErrPasswordComplexity } else { return nil } diff --git a/service/auth_test.go b/service/auth_test.go new file mode 100644 index 0000000..295ada2 --- /dev/null +++ b/service/auth_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "me-fit/db" + "me-fit/types" + + "errors" + "testing" + "time" + + "github.com/google/uuid" +) + +type DbAuthStub struct { + user *db.User + err error +} + +func (d DbAuthStub) GetUser(email string) (*db.User, error) { + return d.user, d.err +} + +func TestSignIn(t *testing.T) { + t.Parallel() + t.Run("should return user if password is correct", func(t *testing.T) { + t.Parallel() + salt := []byte("salt") + stub := DbAuthStub{ + user: db.NewUser( + uuid.New(), + "test@test.de", + true, + time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), + false, + getHashPassword("password", salt), + salt, + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + ), + err: nil, + } + underTest := NewServiceAuthImpl(stub) + + actualUser, err := underTest.SignIn("test@test.de", "password") + if err != nil { + t.Errorf("Expected nil, got %v", err) + } + + expectedUser := User{ + Id: stub.user.Id, + Email: stub.user.Email, + EmailVerified: stub.user.EmailVerified, + } + if *actualUser != expectedUser { + t.Errorf("Expected %v, got %v", expectedUser, actualUser) + } + }) + + t.Run("should return ErrInvalidCretentials if password is not correct", func(t *testing.T) { + t.Parallel() + salt := []byte("salt") + stub := DbAuthStub{ + user: db.NewUser( + uuid.New(), + "test@test.de", + true, + time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), + false, + getHashPassword("password", salt), + salt, + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + ), + err: nil, + } + underTest := NewServiceAuthImpl(stub) + + _, err := underTest.SignIn("test@test.de", "wrong password") + if err != ErrInvaidCredentials { + t.Errorf("Expected %v, got %v", ErrInvaidCredentials, err) + } + }) + t.Run("should return ErrInvalidCretentials if user has not been found", func(t *testing.T) { + t.Parallel() + stub := DbAuthStub{ + user: nil, + err: db.ErrUserNotFound, + } + underTest := NewServiceAuthImpl(stub) + + _, err := underTest.SignIn("test", "test") + if err != ErrInvaidCredentials { + t.Errorf("Expected %v, got %v", ErrInvaidCredentials, err) + } + }) + t.Run("should forward ErrInternal on any other error", func(t *testing.T) { + t.Parallel() + stub := DbAuthStub{ + user: nil, + err: errors.New("Some error"), + } + underTest := NewServiceAuthImpl(stub) + + _, err := underTest.SignIn("test", "test") + if err != types.ErrInternal { + t.Errorf("Expected %v, got %v", types.ErrInternal, err) + } + }) +} diff --git a/types/types.go b/types/types.go index 2c5324f..d40f2dd 100644 --- a/types/types.go +++ b/types/types.go @@ -1,6 +1,14 @@ package types -import "github.com/google/uuid" +import ( + "errors" + + "github.com/google/uuid" +) + +var ( + ErrInternal = errors.New("Internal server error") +) type User struct { Id uuid.UUID diff --git a/utils/db.go b/utils/db.go index 90f782e..7ff44e4 100644 --- a/utils/db.go +++ b/utils/db.go @@ -9,14 +9,14 @@ import ( _ "github.com/golang-migrate/migrate/v4/source/file" ) -func MustRunMigrations(db *sql.DB) { +func MustRunMigrations(db *sql.DB, pathPrefix string) { driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { log.Fatal(err) } m, err := migrate.NewWithDatabaseInstance( - "file://./migration/", + "file://"+pathPrefix+"migration/", "", driver) if err != nil { diff --git a/utils/http.go b/utils/http.go index 406d0de..b6e29a4 100644 --- a/utils/http.go +++ b/utils/http.go @@ -90,6 +90,13 @@ func GetUserFromSession(db *sql.DB, r *http.Request) *types.User { } +func WaitMinimumTime[T interface{}](waitTime time.Duration, function func() (T, error)) (T, error) { + start := time.Now() + result, err := function() + time.Sleep(waitTime - time.Since(start)) + return result, err +} + func getSessionID(r *http.Request) string { for _, c := range r.Cookies() { if c.Name == "id" {