diff --git a/handler/auth.go b/handler/auth.go index 71fa20b..8d28f66 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -17,7 +17,7 @@ func authUi(db *sql.DB) http.Handler { router.Handle("/auth/verify-email", service.HandleSignUpVerifyResponsePage(db)) // The link contained in the email router.Handle("/auth/change-password", service.HandleChangePasswordPage(db)) router.Handle("/auth/reset-password", service.HandleResetPasswordPage(db)) - router.Handle("/", service.HandleIndexAnd404(db)) + router.Handle("/", handleNotFound(db)) return router } diff --git a/handler/default.go b/handler/default.go index d4194b9..0f6aa66 100644 --- a/handler/default.go +++ b/handler/default.go @@ -3,17 +3,19 @@ package handler import ( "me-fit/middleware" "me-fit/service" + "me-fit/template" + "me-fit/utils" "database/sql" "net/http" ) func GetHandler(db *sql.DB) http.Handler { - var router = http.NewServeMux() + router := http.NewServeMux() - router.HandleFunc("/", service.HandleIndexAnd404(db)) + router.HandleFunc("/$", handleIndex(db)) + router.HandleFunc("/", handleNotFound(db)) - // Serve static files (CSS, JS and images) router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/")))) router.Handle("/auth/", authUi(db)) @@ -32,3 +34,33 @@ func GetHandler(db *sql.DB) http.Handler { func auth(db *sql.DB, h http.Handler) http.Handler { return middleware.EnsureValidSession(db, h) } + +func handleIndex(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := service.GetUserFromRequest(db, r) + + err := template. + Layout(template.Index(), service.UserInfoComp(user)). + Render(r.Context(), w) + if err != nil { + utils.LogError("Failed to render index", err) + http.Error(w, "Failed to render index", http.StatusInternalServerError) + } + } +} + +func handleNotFound(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := service.GetUserFromRequest(db, r) + + err := template. + Layout(template.NotFound(), service.UserInfoComp(user)). + Render(r.Context(), w) + if err != nil { + utils.LogError("Failed to render index", err) + http.Error(w, "Failed to render index", http.StatusInternalServerError) + } + + w.WriteHeader(http.StatusNotFound) + } +} diff --git a/middleware/auth.go b/middleware/auth.go index e48c010..4c7786a 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,6 +1,7 @@ package middleware import ( + "me-fit/service" "me-fit/utils" "context" @@ -12,7 +13,7 @@ func EnsureValidSession(db *sql.DB, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := service.GetUserFromRequest(db, r) if user == nil { utils.DoRedirect(w, r, "/auth/signin") return diff --git a/service/auth.go b/service/auth.go index e069846..ce45e59 100644 --- a/service/auth.go +++ b/service/auth.go @@ -24,9 +24,42 @@ import ( "golang.org/x/crypto/argon2" ) +// TESTED + +func GetUserFromSessionId(db *sql.DB, sessionId types.SessionId) *types.User { + if sessionId == "" { + return nil + } + + var ( + createdAt time.Time + userId uuid.UUID + email string + emailVerified bool + ) + + err := db.QueryRow(` + SELECT u.user_uuid, u.email, u.email_verified, s.created_at + FROM session s + INNER JOIN user u ON s.user_uuid = u.user_uuid + WHERE session_id = ?`, sessionId).Scan(&userId, &email, &emailVerified, &createdAt) + if err != nil { + slog.Warn("Could not verify session: " + err.Error()) + return nil + } + + if createdAt.Add(time.Duration(8 * time.Hour)).Before(time.Now()) { + return nil + } else { + return types.NewUser(userId, email, sessionId, emailVerified) + } +} + +// NOT TESTED + func HandleSignInPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user == nil { userComp := UserInfoComp(nil) @@ -48,7 +81,7 @@ func HandleSignInPage(db *sql.DB) http.HandlerFunc { func HandleSignUpPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user == nil { userComp := UserInfoComp(nil) @@ -70,7 +103,7 @@ func HandleSignUpPage(db *sql.DB) http.HandlerFunc { func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user == nil { utils.DoRedirect(w, r, "/auth/signin") } else if user.EmailVerified { @@ -90,7 +123,7 @@ func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc { func HandleDeleteAccountPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // An unverified email should be able to delete their account - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user == nil { utils.DoRedirect(w, r, "/auth/signin") } else { @@ -152,7 +185,7 @@ func HandleChangePasswordPage(db *sql.DB) http.HandlerFunc { isPasswordReset := r.URL.Query().Has("token") - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user == nil && !isPasswordReset { utils.DoRedirect(w, r, "/auth/signin") } else { @@ -170,7 +203,7 @@ func HandleChangePasswordPage(db *sql.DB) http.HandlerFunc { func HandleResetPasswordPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user != nil { utils.DoRedirect(w, r, "/auth/signin") } else { @@ -314,7 +347,7 @@ func HandleSignInComp(db *sql.DB) http.HandlerFunc { func HandleSignOutComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user != nil { _, err := db.Exec("DELETE FROM session WHERE session_id = ?", user.SessionId) @@ -343,7 +376,7 @@ func HandleSignOutComp(db *sql.DB) http.HandlerFunc { func HandleDeleteAccountComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user == nil { utils.DoRedirect(w, r, "/auth/signin") return @@ -409,7 +442,7 @@ func HandleDeleteAccountComp(db *sql.DB) http.HandlerFunc { func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user == nil || user.EmailVerified { utils.DoRedirect(w, r, "/auth/signin") return @@ -424,7 +457,7 @@ func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc { func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) + user := GetUserFromRequest(db, r) if user == nil { utils.DoRedirect(w, r, "/auth/signin") return @@ -669,3 +702,19 @@ func checkPassword(password string) error { return nil } } + +//TODO: delete + +func getSessionID(r *http.Request) types.SessionId { + for _, c := range r.Cookies() { + if c.Name == "id" { + return types.SessionId(c.Value) + } + } + return "" +} + +func GetUserFromRequest(db *sql.DB, r *http.Request) *types.User { + sessionId := getSessionID(r) + return GetUserFromSessionId(db, sessionId) +} diff --git a/service/auth_test.go b/service/auth_test.go index 7b62f3b..1665b6d 100644 --- a/service/auth_test.go +++ b/service/auth_test.go @@ -1,9 +1,119 @@ package service import ( + "me-fit/types" + "me-fit/utils" + + "database/sql" "testing" + + "github.com/google/uuid" ) +func mustSetup(t *testing.T) *sql.DB { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Could not open Database data.db: %v", err) + } + utils.MustRunMigrationsTest(db, "../") + return db +} + +func TestGetUserFromSessionIfSessionNotExpired(t *testing.T) { + db := mustSetup(t) + defer db.Close() + + expected := types.NewUser(uuid.New(), "email", "session_id", true) + + db.Exec(`INSERT INTO user ( + user_uuid, email, email_verified, email_verified_at, + is_admin, password, salt, created_at) + VAlUES ( + ?, ?, 1, datetime(), + 0, "password", "salt", datetime())`, expected.Id, expected.Email) + db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '-2 hour'))`, expected.SessionId, expected.Id) + + actual := GetUserFromSessionId(db, expected.SessionId) + + if *actual != *expected { + t.Errorf("Expected %v, got %v", *expected, *actual) + } +} + +func TestGetUserFromSessionIfSessionInFuture(t *testing.T) { + db := mustSetup(t) + defer db.Close() + + expected := types.NewUser(uuid.New(), "email", "session_id", true) + + db.Exec(`INSERT INTO user ( + user_uuid, email, email_verified, email_verified_at, + is_admin, password, salt, created_at) + VAlUES ( + ?, ?, 1, datetime(), + 0, "password", "salt", datetime())`, expected.Id, expected.Email) + db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '+2 hour'))`, expected.SessionId, expected.Id) + + actual := GetUserFromSessionId(db, expected.SessionId) + + if *actual != *expected { + t.Errorf("Expected %v, got %v", *expected, *actual) + } +} + +func TestFailGetUserFromSessionIfSessionExpired(t *testing.T) { + db := mustSetup(t) + defer db.Close() + + expected := types.NewUser(uuid.New(), "email", "session_id", true) + + db.Exec(`INSERT INTO user ( + user_uuid, email, email_verified, email_verified_at, + is_admin, password, salt, created_at) + VAlUES ( + ?, ?, 1, datetime(), + 0, "password", "salt", datetime())`, expected.Id, expected.Email) + db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '-8 hour', '-1 minute'))`, expected.SessionId, expected.Id) + + actual := GetUserFromSessionId(db, expected.SessionId) + + if actual != nil { + t.Errorf("Expected nil, got %v", *actual) + } +} + +func TestGetUserFromSessionShouldFindCorrectUserBySessionId(t *testing.T) { + db := mustSetup(t) + defer db.Close() + + expected := types.NewUser(uuid.New(), "email", "session_id", true) + userId2 := uuid.New() + + db.Exec(`INSERT INTO user ( + user_uuid, email, email_verified, email_verified_at, + is_admin, password, salt, created_at) + VAlUES ( + ?, ?, 1, datetime(), + 0, "password", "salt", datetime()), + ( + ?, ?, 1, datetime(), + 0, "password", "salt", datetime()) + `, expected.Id, expected.Email, userId2, "email2") + db.Exec(` + INSERT INTO session ( + session_id, user_uuid, created_at) + VALUES + (?, ?, datetime('now')), + (?, ?, datetime('now')) + `, expected.SessionId, expected.Id, expected.SessionId+"x", userId2) + + actual := GetUserFromSessionId(db, expected.SessionId) + + if *actual != *expected { + t.Errorf("Expected %v, got %v", *expected, *actual) + } +} + func TestValidPasswords(t *testing.T) { passwords := []string{ "aB!'2d2y", //normal diff --git a/service/index_and_404.go b/service/index_and_404.go deleted file mode 100644 index 480ca6d..0000000 --- a/service/index_and_404.go +++ /dev/null @@ -1,32 +0,0 @@ -package service - -import ( - "database/sql" - "me-fit/template" - "me-fit/utils" - "net/http" - - "github.com/a-h/templ" -) - -func HandleIndexAnd404(db *sql.DB) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - user := utils.GetUserFromSession(db, r) - - var comp templ.Component = nil - userComp := UserInfoComp(user) - - if r.URL.Path != "/" { - comp = template.Layout(template.NotFound(), userComp) - w.WriteHeader(http.StatusNotFound) - } else { - comp = template.Layout(template.Index(), userComp) - } - - err := comp.Render(r.Context(), w) - if err != nil { - utils.LogError("Failed to render index", err) - http.Error(w, "Failed to render index", http.StatusInternalServerError) - } - } -} diff --git a/types/types.go b/types/types.go index 2c5324f..013f1e7 100644 --- a/types/types.go +++ b/types/types.go @@ -2,9 +2,20 @@ package types import "github.com/google/uuid" +type SessionId string + type User struct { Id uuid.UUID Email string - SessionId string + SessionId SessionId EmailVerified bool } + +func NewUser(id uuid.UUID, email string, sessionId SessionId, emailVerified bool) *User { + return &User{ + Id: id, + Email: email, + SessionId: sessionId, + EmailVerified: emailVerified, + } +} diff --git a/utils/db.go b/utils/db.go index 90f782e..57e54c9 100644 --- a/utils/db.go +++ b/utils/db.go @@ -10,13 +10,20 @@ import ( ) func MustRunMigrations(db *sql.DB) { + mustRunMigrationsInternal(db, "") +} +func MustRunMigrationsTest(db *sql.DB, pathPrefix string) { + mustRunMigrationsInternal(db, "../") +} + +func mustRunMigrationsInternal(db *sql.DB, pathPrefix string) { driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { log.Fatal(err) } m, err := migrate.NewWithDatabaseInstance( - "file://./migration/", + "file://./"+pathPrefix+"migration/", "", driver) if err != nil { diff --git a/utils/http.go b/utils/http.go index 406d0de..5a3a517 100644 --- a/utils/http.go +++ b/utils/http.go @@ -1,12 +1,10 @@ package utils import ( - "database/sql" "fmt" "log/slog" "me-fit/types" "net/http" - "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -61,39 +59,10 @@ func GetUser(r *http.Request) *types.User { } } -func GetUserFromSession(db *sql.DB, r *http.Request) *types.User { - sessionId := getSessionID(r) - if sessionId == "" { - return nil - } - - var user types.User - var createdAt time.Time - - user.SessionId = sessionId - - err := db.QueryRow(` - SELECT u.user_uuid, u.email, u.email_verified, s.created_at - FROM session s - INNER JOIN user u ON s.user_uuid = u.user_uuid - WHERE session_id = ?`, sessionId).Scan(&user.Id, &user.Email, &user.EmailVerified, &createdAt) - if err != nil { - slog.Warn("Could not verify session: " + err.Error()) - return nil - } - - if createdAt.Add(time.Duration(8 * time.Hour)).Before(time.Now()) { - return nil - } else { - return &user - } - -} - -func getSessionID(r *http.Request) string { +func GetSessionID(r *http.Request) types.SessionId { for _, c := range r.Cookies() { if c.Name == "id" { - return c.Value + return types.SessionId(c.Value) } } return ""