chore(auth): add test for retrieving session from db #181
All checks were successful
Build Docker Image / Explore-Gitea-Actions (push) Successful in 46s

This commit is contained in:
2024-09-18 23:07:01 +02:00
parent dbe687c105
commit bb9381433b
9 changed files with 229 additions and 82 deletions

View File

@@ -17,7 +17,7 @@ func authUi(db *sql.DB) http.Handler {
router.Handle("/auth/verify-email", service.HandleSignUpVerifyResponsePage(db)) // The link contained in the email router.Handle("/auth/verify-email", service.HandleSignUpVerifyResponsePage(db)) // The link contained in the email
router.Handle("/auth/change-password", service.HandleChangePasswordPage(db)) router.Handle("/auth/change-password", service.HandleChangePasswordPage(db))
router.Handle("/auth/reset-password", service.HandleResetPasswordPage(db)) router.Handle("/auth/reset-password", service.HandleResetPasswordPage(db))
router.Handle("/", service.HandleIndexAnd404(db)) router.Handle("/", handleNotFound(db))
return router return router
} }

View File

@@ -3,17 +3,19 @@ package handler
import ( import (
"me-fit/middleware" "me-fit/middleware"
"me-fit/service" "me-fit/service"
"me-fit/template"
"me-fit/utils"
"database/sql" "database/sql"
"net/http" "net/http"
) )
func GetHandler(db *sql.DB) http.Handler { func GetHandler(db *sql.DB) http.Handler {
var router = http.NewServeMux() router := http.NewServeMux()
router.HandleFunc("/", service.HandleIndexAnd404(db)) router.HandleFunc("/$", handleIndex(db))
router.HandleFunc("/", handleNotFound(db))
// Serve static files (CSS, JS and images)
router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/")))) router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/"))))
router.Handle("/auth/", authUi(db)) router.Handle("/auth/", authUi(db))
@@ -32,3 +34,33 @@ func GetHandler(db *sql.DB) http.Handler {
func auth(db *sql.DB, h http.Handler) http.Handler { func auth(db *sql.DB, h http.Handler) http.Handler {
return middleware.EnsureValidSession(db, h) return middleware.EnsureValidSession(db, h)
} }
func handleIndex(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user := service.GetUserFromRequest(db, r)
err := template.
Layout(template.Index(), service.UserInfoComp(user)).
Render(r.Context(), w)
if err != nil {
utils.LogError("Failed to render index", err)
http.Error(w, "Failed to render index", http.StatusInternalServerError)
}
}
}
func handleNotFound(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user := service.GetUserFromRequest(db, r)
err := template.
Layout(template.NotFound(), service.UserInfoComp(user)).
Render(r.Context(), w)
if err != nil {
utils.LogError("Failed to render index", err)
http.Error(w, "Failed to render index", http.StatusInternalServerError)
}
w.WriteHeader(http.StatusNotFound)
}
}

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"me-fit/service"
"me-fit/utils" "me-fit/utils"
"context" "context"
@@ -12,7 +13,7 @@ func EnsureValidSession(db *sql.DB, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := service.GetUserFromRequest(db, r)
if user == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return

View File

@@ -24,9 +24,42 @@ import (
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
) )
// TESTED
func GetUserFromSessionId(db *sql.DB, sessionId types.SessionId) *types.User {
if sessionId == "" {
return nil
}
var (
createdAt time.Time
userId uuid.UUID
email string
emailVerified bool
)
err := db.QueryRow(`
SELECT u.user_uuid, u.email, u.email_verified, s.created_at
FROM session s
INNER JOIN user u ON s.user_uuid = u.user_uuid
WHERE session_id = ?`, sessionId).Scan(&userId, &email, &emailVerified, &createdAt)
if err != nil {
slog.Warn("Could not verify session: " + err.Error())
return nil
}
if createdAt.Add(time.Duration(8 * time.Hour)).Before(time.Now()) {
return nil
} else {
return types.NewUser(userId, email, sessionId, emailVerified)
}
}
// NOT TESTED
func HandleSignInPage(db *sql.DB) http.HandlerFunc { func HandleSignInPage(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user == nil { if user == nil {
userComp := UserInfoComp(nil) userComp := UserInfoComp(nil)
@@ -48,7 +81,7 @@ func HandleSignInPage(db *sql.DB) http.HandlerFunc {
func HandleSignUpPage(db *sql.DB) http.HandlerFunc { func HandleSignUpPage(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user == nil { if user == nil {
userComp := UserInfoComp(nil) userComp := UserInfoComp(nil)
@@ -70,7 +103,7 @@ func HandleSignUpPage(db *sql.DB) http.HandlerFunc {
func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc { func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
} else if user.EmailVerified { } else if user.EmailVerified {
@@ -90,7 +123,7 @@ func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc {
func HandleDeleteAccountPage(db *sql.DB) http.HandlerFunc { func HandleDeleteAccountPage(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// An unverified email should be able to delete their account // An unverified email should be able to delete their account
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
} else { } else {
@@ -152,7 +185,7 @@ func HandleChangePasswordPage(db *sql.DB) http.HandlerFunc {
isPasswordReset := r.URL.Query().Has("token") isPasswordReset := r.URL.Query().Has("token")
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user == nil && !isPasswordReset { if user == nil && !isPasswordReset {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
} else { } else {
@@ -170,7 +203,7 @@ func HandleChangePasswordPage(db *sql.DB) http.HandlerFunc {
func HandleResetPasswordPage(db *sql.DB) http.HandlerFunc { func HandleResetPasswordPage(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user != nil { if user != nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
} else { } else {
@@ -314,7 +347,7 @@ func HandleSignInComp(db *sql.DB) http.HandlerFunc {
func HandleSignOutComp(db *sql.DB) http.HandlerFunc { func HandleSignOutComp(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user != nil { if user != nil {
_, err := db.Exec("DELETE FROM session WHERE session_id = ?", user.SessionId) _, err := db.Exec("DELETE FROM session WHERE session_id = ?", user.SessionId)
@@ -343,7 +376,7 @@ func HandleSignOutComp(db *sql.DB) http.HandlerFunc {
func HandleDeleteAccountComp(db *sql.DB) http.HandlerFunc { func HandleDeleteAccountComp(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return
@@ -409,7 +442,7 @@ func HandleDeleteAccountComp(db *sql.DB) http.HandlerFunc {
func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc { func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user == nil || user.EmailVerified { if user == nil || user.EmailVerified {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return
@@ -424,7 +457,7 @@ func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc {
func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc { func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r) user := GetUserFromRequest(db, r)
if user == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return
@@ -669,3 +702,19 @@ func checkPassword(password string) error {
return nil return nil
} }
} }
//TODO: delete
func getSessionID(r *http.Request) types.SessionId {
for _, c := range r.Cookies() {
if c.Name == "id" {
return types.SessionId(c.Value)
}
}
return ""
}
func GetUserFromRequest(db *sql.DB, r *http.Request) *types.User {
sessionId := getSessionID(r)
return GetUserFromSessionId(db, sessionId)
}

View File

@@ -1,9 +1,119 @@
package service package service
import ( import (
"me-fit/types"
"me-fit/utils"
"database/sql"
"testing" "testing"
"github.com/google/uuid"
) )
func mustSetup(t *testing.T) *sql.DB {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Could not open Database data.db: %v", err)
}
utils.MustRunMigrationsTest(db, "../")
return db
}
func TestGetUserFromSessionIfSessionNotExpired(t *testing.T) {
db := mustSetup(t)
defer db.Close()
expected := types.NewUser(uuid.New(), "email", "session_id", true)
db.Exec(`INSERT INTO user (
user_uuid, email, email_verified, email_verified_at,
is_admin, password, salt, created_at)
VAlUES (
?, ?, 1, datetime(),
0, "password", "salt", datetime())`, expected.Id, expected.Email)
db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '-2 hour'))`, expected.SessionId, expected.Id)
actual := GetUserFromSessionId(db, expected.SessionId)
if *actual != *expected {
t.Errorf("Expected %v, got %v", *expected, *actual)
}
}
func TestGetUserFromSessionIfSessionInFuture(t *testing.T) {
db := mustSetup(t)
defer db.Close()
expected := types.NewUser(uuid.New(), "email", "session_id", true)
db.Exec(`INSERT INTO user (
user_uuid, email, email_verified, email_verified_at,
is_admin, password, salt, created_at)
VAlUES (
?, ?, 1, datetime(),
0, "password", "salt", datetime())`, expected.Id, expected.Email)
db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '+2 hour'))`, expected.SessionId, expected.Id)
actual := GetUserFromSessionId(db, expected.SessionId)
if *actual != *expected {
t.Errorf("Expected %v, got %v", *expected, *actual)
}
}
func TestFailGetUserFromSessionIfSessionExpired(t *testing.T) {
db := mustSetup(t)
defer db.Close()
expected := types.NewUser(uuid.New(), "email", "session_id", true)
db.Exec(`INSERT INTO user (
user_uuid, email, email_verified, email_verified_at,
is_admin, password, salt, created_at)
VAlUES (
?, ?, 1, datetime(),
0, "password", "salt", datetime())`, expected.Id, expected.Email)
db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '-8 hour', '-1 minute'))`, expected.SessionId, expected.Id)
actual := GetUserFromSessionId(db, expected.SessionId)
if actual != nil {
t.Errorf("Expected nil, got %v", *actual)
}
}
func TestGetUserFromSessionShouldFindCorrectUserBySessionId(t *testing.T) {
db := mustSetup(t)
defer db.Close()
expected := types.NewUser(uuid.New(), "email", "session_id", true)
userId2 := uuid.New()
db.Exec(`INSERT INTO user (
user_uuid, email, email_verified, email_verified_at,
is_admin, password, salt, created_at)
VAlUES (
?, ?, 1, datetime(),
0, "password", "salt", datetime()),
(
?, ?, 1, datetime(),
0, "password", "salt", datetime())
`, expected.Id, expected.Email, userId2, "email2")
db.Exec(`
INSERT INTO session (
session_id, user_uuid, created_at)
VALUES
(?, ?, datetime('now')),
(?, ?, datetime('now'))
`, expected.SessionId, expected.Id, expected.SessionId+"x", userId2)
actual := GetUserFromSessionId(db, expected.SessionId)
if *actual != *expected {
t.Errorf("Expected %v, got %v", *expected, *actual)
}
}
func TestValidPasswords(t *testing.T) { func TestValidPasswords(t *testing.T) {
passwords := []string{ passwords := []string{
"aB!'2d2y", //normal "aB!'2d2y", //normal

View File

@@ -1,32 +0,0 @@
package service
import (
"database/sql"
"me-fit/template"
"me-fit/utils"
"net/http"
"github.com/a-h/templ"
)
func HandleIndexAnd404(db *sql.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user := utils.GetUserFromSession(db, r)
var comp templ.Component = nil
userComp := UserInfoComp(user)
if r.URL.Path != "/" {
comp = template.Layout(template.NotFound(), userComp)
w.WriteHeader(http.StatusNotFound)
} else {
comp = template.Layout(template.Index(), userComp)
}
err := comp.Render(r.Context(), w)
if err != nil {
utils.LogError("Failed to render index", err)
http.Error(w, "Failed to render index", http.StatusInternalServerError)
}
}
}

View File

@@ -2,9 +2,20 @@ package types
import "github.com/google/uuid" import "github.com/google/uuid"
type SessionId string
type User struct { type User struct {
Id uuid.UUID Id uuid.UUID
Email string Email string
SessionId string SessionId SessionId
EmailVerified bool EmailVerified bool
} }
func NewUser(id uuid.UUID, email string, sessionId SessionId, emailVerified bool) *User {
return &User{
Id: id,
Email: email,
SessionId: sessionId,
EmailVerified: emailVerified,
}
}

View File

@@ -10,13 +10,20 @@ import (
) )
func MustRunMigrations(db *sql.DB) { func MustRunMigrations(db *sql.DB) {
mustRunMigrationsInternal(db, "")
}
func MustRunMigrationsTest(db *sql.DB, pathPrefix string) {
mustRunMigrationsInternal(db, "../")
}
func mustRunMigrationsInternal(db *sql.DB, pathPrefix string) {
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
m, err := migrate.NewWithDatabaseInstance( m, err := migrate.NewWithDatabaseInstance(
"file://./migration/", "file://./"+pathPrefix+"migration/",
"", "",
driver) driver)
if err != nil { if err != nil {

View File

@@ -1,12 +1,10 @@
package utils package utils
import ( import (
"database/sql"
"fmt" "fmt"
"log/slog" "log/slog"
"me-fit/types" "me-fit/types"
"net/http" "net/http"
"time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
@@ -61,39 +59,10 @@ func GetUser(r *http.Request) *types.User {
} }
} }
func GetUserFromSession(db *sql.DB, r *http.Request) *types.User { func GetSessionID(r *http.Request) types.SessionId {
sessionId := getSessionID(r)
if sessionId == "" {
return nil
}
var user types.User
var createdAt time.Time
user.SessionId = sessionId
err := db.QueryRow(`
SELECT u.user_uuid, u.email, u.email_verified, s.created_at
FROM session s
INNER JOIN user u ON s.user_uuid = u.user_uuid
WHERE session_id = ?`, sessionId).Scan(&user.Id, &user.Email, &user.EmailVerified, &createdAt)
if err != nil {
slog.Warn("Could not verify session: " + err.Error())
return nil
}
if createdAt.Add(time.Duration(8 * time.Hour)).Before(time.Now()) {
return nil
} else {
return &user
}
}
func getSessionID(r *http.Request) string {
for _, c := range r.Cookies() { for _, c := range r.Cookies() {
if c.Name == "id" { if c.Name == "id" {
return c.Value return types.SessionId(c.Value)
} }
} }
return "" return ""