chore(auth): add test for retrieving session from db #181
All checks were successful
Build Docker Image / Explore-Gitea-Actions (push) Successful in 46s
All checks were successful
Build Docker Image / Explore-Gitea-Actions (push) Successful in 46s
This commit is contained in:
@@ -17,7 +17,7 @@ func authUi(db *sql.DB) http.Handler {
|
|||||||
router.Handle("/auth/verify-email", service.HandleSignUpVerifyResponsePage(db)) // The link contained in the email
|
router.Handle("/auth/verify-email", service.HandleSignUpVerifyResponsePage(db)) // The link contained in the email
|
||||||
router.Handle("/auth/change-password", service.HandleChangePasswordPage(db))
|
router.Handle("/auth/change-password", service.HandleChangePasswordPage(db))
|
||||||
router.Handle("/auth/reset-password", service.HandleResetPasswordPage(db))
|
router.Handle("/auth/reset-password", service.HandleResetPasswordPage(db))
|
||||||
router.Handle("/", service.HandleIndexAnd404(db))
|
router.Handle("/", handleNotFound(db))
|
||||||
|
|
||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,17 +3,19 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"me-fit/middleware"
|
"me-fit/middleware"
|
||||||
"me-fit/service"
|
"me-fit/service"
|
||||||
|
"me-fit/template"
|
||||||
|
"me-fit/utils"
|
||||||
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetHandler(db *sql.DB) http.Handler {
|
func GetHandler(db *sql.DB) http.Handler {
|
||||||
var router = http.NewServeMux()
|
router := http.NewServeMux()
|
||||||
|
|
||||||
router.HandleFunc("/", service.HandleIndexAnd404(db))
|
router.HandleFunc("/$", handleIndex(db))
|
||||||
|
router.HandleFunc("/", handleNotFound(db))
|
||||||
|
|
||||||
// Serve static files (CSS, JS and images)
|
|
||||||
router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/"))))
|
router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/"))))
|
||||||
|
|
||||||
router.Handle("/auth/", authUi(db))
|
router.Handle("/auth/", authUi(db))
|
||||||
@@ -32,3 +34,33 @@ func GetHandler(db *sql.DB) http.Handler {
|
|||||||
func auth(db *sql.DB, h http.Handler) http.Handler {
|
func auth(db *sql.DB, h http.Handler) http.Handler {
|
||||||
return middleware.EnsureValidSession(db, h)
|
return middleware.EnsureValidSession(db, h)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleIndex(db *sql.DB) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user := service.GetUserFromRequest(db, r)
|
||||||
|
|
||||||
|
err := template.
|
||||||
|
Layout(template.Index(), service.UserInfoComp(user)).
|
||||||
|
Render(r.Context(), w)
|
||||||
|
if err != nil {
|
||||||
|
utils.LogError("Failed to render index", err)
|
||||||
|
http.Error(w, "Failed to render index", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleNotFound(db *sql.DB) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user := service.GetUserFromRequest(db, r)
|
||||||
|
|
||||||
|
err := template.
|
||||||
|
Layout(template.NotFound(), service.UserInfoComp(user)).
|
||||||
|
Render(r.Context(), w)
|
||||||
|
if err != nil {
|
||||||
|
utils.LogError("Failed to render index", err)
|
||||||
|
http.Error(w, "Failed to render index", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"me-fit/service"
|
||||||
"me-fit/utils"
|
"me-fit/utils"
|
||||||
|
|
||||||
"context"
|
"context"
|
||||||
@@ -12,7 +13,7 @@ func EnsureValidSession(db *sql.DB, next http.Handler) http.Handler {
|
|||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := service.GetUserFromRequest(db, r)
|
||||||
if user == nil {
|
if user == nil {
|
||||||
utils.DoRedirect(w, r, "/auth/signin")
|
utils.DoRedirect(w, r, "/auth/signin")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -24,9 +24,42 @@ import (
|
|||||||
"golang.org/x/crypto/argon2"
|
"golang.org/x/crypto/argon2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TESTED
|
||||||
|
|
||||||
|
func GetUserFromSessionId(db *sql.DB, sessionId types.SessionId) *types.User {
|
||||||
|
if sessionId == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
createdAt time.Time
|
||||||
|
userId uuid.UUID
|
||||||
|
email string
|
||||||
|
emailVerified bool
|
||||||
|
)
|
||||||
|
|
||||||
|
err := db.QueryRow(`
|
||||||
|
SELECT u.user_uuid, u.email, u.email_verified, s.created_at
|
||||||
|
FROM session s
|
||||||
|
INNER JOIN user u ON s.user_uuid = u.user_uuid
|
||||||
|
WHERE session_id = ?`, sessionId).Scan(&userId, &email, &emailVerified, &createdAt)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("Could not verify session: " + err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if createdAt.Add(time.Duration(8 * time.Hour)).Before(time.Now()) {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return types.NewUser(userId, email, sessionId, emailVerified)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOT TESTED
|
||||||
|
|
||||||
func HandleSignInPage(db *sql.DB) http.HandlerFunc {
|
func HandleSignInPage(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
userComp := UserInfoComp(nil)
|
userComp := UserInfoComp(nil)
|
||||||
@@ -48,7 +81,7 @@ func HandleSignInPage(db *sql.DB) http.HandlerFunc {
|
|||||||
|
|
||||||
func HandleSignUpPage(db *sql.DB) http.HandlerFunc {
|
func HandleSignUpPage(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
userComp := UserInfoComp(nil)
|
userComp := UserInfoComp(nil)
|
||||||
@@ -70,7 +103,7 @@ func HandleSignUpPage(db *sql.DB) http.HandlerFunc {
|
|||||||
|
|
||||||
func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc {
|
func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
if user == nil {
|
if user == nil {
|
||||||
utils.DoRedirect(w, r, "/auth/signin")
|
utils.DoRedirect(w, r, "/auth/signin")
|
||||||
} else if user.EmailVerified {
|
} else if user.EmailVerified {
|
||||||
@@ -90,7 +123,7 @@ func HandleSignUpVerifyPage(db *sql.DB) http.HandlerFunc {
|
|||||||
func HandleDeleteAccountPage(db *sql.DB) http.HandlerFunc {
|
func HandleDeleteAccountPage(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
// An unverified email should be able to delete their account
|
// An unverified email should be able to delete their account
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
if user == nil {
|
if user == nil {
|
||||||
utils.DoRedirect(w, r, "/auth/signin")
|
utils.DoRedirect(w, r, "/auth/signin")
|
||||||
} else {
|
} else {
|
||||||
@@ -152,7 +185,7 @@ func HandleChangePasswordPage(db *sql.DB) http.HandlerFunc {
|
|||||||
|
|
||||||
isPasswordReset := r.URL.Query().Has("token")
|
isPasswordReset := r.URL.Query().Has("token")
|
||||||
|
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
if user == nil && !isPasswordReset {
|
if user == nil && !isPasswordReset {
|
||||||
utils.DoRedirect(w, r, "/auth/signin")
|
utils.DoRedirect(w, r, "/auth/signin")
|
||||||
} else {
|
} else {
|
||||||
@@ -170,7 +203,7 @@ func HandleChangePasswordPage(db *sql.DB) http.HandlerFunc {
|
|||||||
func HandleResetPasswordPage(db *sql.DB) http.HandlerFunc {
|
func HandleResetPasswordPage(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
if user != nil {
|
if user != nil {
|
||||||
utils.DoRedirect(w, r, "/auth/signin")
|
utils.DoRedirect(w, r, "/auth/signin")
|
||||||
} else {
|
} else {
|
||||||
@@ -314,7 +347,7 @@ func HandleSignInComp(db *sql.DB) http.HandlerFunc {
|
|||||||
|
|
||||||
func HandleSignOutComp(db *sql.DB) http.HandlerFunc {
|
func HandleSignOutComp(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
_, err := db.Exec("DELETE FROM session WHERE session_id = ?", user.SessionId)
|
_, err := db.Exec("DELETE FROM session WHERE session_id = ?", user.SessionId)
|
||||||
@@ -343,7 +376,7 @@ func HandleSignOutComp(db *sql.DB) http.HandlerFunc {
|
|||||||
|
|
||||||
func HandleDeleteAccountComp(db *sql.DB) http.HandlerFunc {
|
func HandleDeleteAccountComp(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
if user == nil {
|
if user == nil {
|
||||||
utils.DoRedirect(w, r, "/auth/signin")
|
utils.DoRedirect(w, r, "/auth/signin")
|
||||||
return
|
return
|
||||||
@@ -409,7 +442,7 @@ func HandleDeleteAccountComp(db *sql.DB) http.HandlerFunc {
|
|||||||
|
|
||||||
func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc {
|
func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
if user == nil || user.EmailVerified {
|
if user == nil || user.EmailVerified {
|
||||||
utils.DoRedirect(w, r, "/auth/signin")
|
utils.DoRedirect(w, r, "/auth/signin")
|
||||||
return
|
return
|
||||||
@@ -424,7 +457,7 @@ func HandleVerifyResendComp(db *sql.DB) http.HandlerFunc {
|
|||||||
func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc {
|
func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
user := utils.GetUserFromSession(db, r)
|
user := GetUserFromRequest(db, r)
|
||||||
if user == nil {
|
if user == nil {
|
||||||
utils.DoRedirect(w, r, "/auth/signin")
|
utils.DoRedirect(w, r, "/auth/signin")
|
||||||
return
|
return
|
||||||
@@ -669,3 +702,19 @@ func checkPassword(password string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//TODO: delete
|
||||||
|
|
||||||
|
func getSessionID(r *http.Request) types.SessionId {
|
||||||
|
for _, c := range r.Cookies() {
|
||||||
|
if c.Name == "id" {
|
||||||
|
return types.SessionId(c.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserFromRequest(db *sql.DB, r *http.Request) *types.User {
|
||||||
|
sessionId := getSessionID(r)
|
||||||
|
return GetUserFromSessionId(db, sessionId)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,119 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"me-fit/types"
|
||||||
|
"me-fit/utils"
|
||||||
|
|
||||||
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func mustSetup(t *testing.T) *sql.DB {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not open Database data.db: %v", err)
|
||||||
|
}
|
||||||
|
utils.MustRunMigrationsTest(db, "../")
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserFromSessionIfSessionNotExpired(t *testing.T) {
|
||||||
|
db := mustSetup(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
expected := types.NewUser(uuid.New(), "email", "session_id", true)
|
||||||
|
|
||||||
|
db.Exec(`INSERT INTO user (
|
||||||
|
user_uuid, email, email_verified, email_verified_at,
|
||||||
|
is_admin, password, salt, created_at)
|
||||||
|
VAlUES (
|
||||||
|
?, ?, 1, datetime(),
|
||||||
|
0, "password", "salt", datetime())`, expected.Id, expected.Email)
|
||||||
|
db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '-2 hour'))`, expected.SessionId, expected.Id)
|
||||||
|
|
||||||
|
actual := GetUserFromSessionId(db, expected.SessionId)
|
||||||
|
|
||||||
|
if *actual != *expected {
|
||||||
|
t.Errorf("Expected %v, got %v", *expected, *actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserFromSessionIfSessionInFuture(t *testing.T) {
|
||||||
|
db := mustSetup(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
expected := types.NewUser(uuid.New(), "email", "session_id", true)
|
||||||
|
|
||||||
|
db.Exec(`INSERT INTO user (
|
||||||
|
user_uuid, email, email_verified, email_verified_at,
|
||||||
|
is_admin, password, salt, created_at)
|
||||||
|
VAlUES (
|
||||||
|
?, ?, 1, datetime(),
|
||||||
|
0, "password", "salt", datetime())`, expected.Id, expected.Email)
|
||||||
|
db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '+2 hour'))`, expected.SessionId, expected.Id)
|
||||||
|
|
||||||
|
actual := GetUserFromSessionId(db, expected.SessionId)
|
||||||
|
|
||||||
|
if *actual != *expected {
|
||||||
|
t.Errorf("Expected %v, got %v", *expected, *actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFailGetUserFromSessionIfSessionExpired(t *testing.T) {
|
||||||
|
db := mustSetup(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
expected := types.NewUser(uuid.New(), "email", "session_id", true)
|
||||||
|
|
||||||
|
db.Exec(`INSERT INTO user (
|
||||||
|
user_uuid, email, email_verified, email_verified_at,
|
||||||
|
is_admin, password, salt, created_at)
|
||||||
|
VAlUES (
|
||||||
|
?, ?, 1, datetime(),
|
||||||
|
0, "password", "salt", datetime())`, expected.Id, expected.Email)
|
||||||
|
db.Exec(`INSERT INTO session (session_id, user_uuid, created_at) VALUES (?, ?, datetime('now', '-8 hour', '-1 minute'))`, expected.SessionId, expected.Id)
|
||||||
|
|
||||||
|
actual := GetUserFromSessionId(db, expected.SessionId)
|
||||||
|
|
||||||
|
if actual != nil {
|
||||||
|
t.Errorf("Expected nil, got %v", *actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUserFromSessionShouldFindCorrectUserBySessionId(t *testing.T) {
|
||||||
|
db := mustSetup(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
expected := types.NewUser(uuid.New(), "email", "session_id", true)
|
||||||
|
userId2 := uuid.New()
|
||||||
|
|
||||||
|
db.Exec(`INSERT INTO user (
|
||||||
|
user_uuid, email, email_verified, email_verified_at,
|
||||||
|
is_admin, password, salt, created_at)
|
||||||
|
VAlUES (
|
||||||
|
?, ?, 1, datetime(),
|
||||||
|
0, "password", "salt", datetime()),
|
||||||
|
(
|
||||||
|
?, ?, 1, datetime(),
|
||||||
|
0, "password", "salt", datetime())
|
||||||
|
`, expected.Id, expected.Email, userId2, "email2")
|
||||||
|
db.Exec(`
|
||||||
|
INSERT INTO session (
|
||||||
|
session_id, user_uuid, created_at)
|
||||||
|
VALUES
|
||||||
|
(?, ?, datetime('now')),
|
||||||
|
(?, ?, datetime('now'))
|
||||||
|
`, expected.SessionId, expected.Id, expected.SessionId+"x", userId2)
|
||||||
|
|
||||||
|
actual := GetUserFromSessionId(db, expected.SessionId)
|
||||||
|
|
||||||
|
if *actual != *expected {
|
||||||
|
t.Errorf("Expected %v, got %v", *expected, *actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestValidPasswords(t *testing.T) {
|
func TestValidPasswords(t *testing.T) {
|
||||||
passwords := []string{
|
passwords := []string{
|
||||||
"aB!'2d2y", //normal
|
"aB!'2d2y", //normal
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"me-fit/template"
|
|
||||||
"me-fit/utils"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/a-h/templ"
|
|
||||||
)
|
|
||||||
|
|
||||||
func HandleIndexAnd404(db *sql.DB) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := utils.GetUserFromSession(db, r)
|
|
||||||
|
|
||||||
var comp templ.Component = nil
|
|
||||||
userComp := UserInfoComp(user)
|
|
||||||
|
|
||||||
if r.URL.Path != "/" {
|
|
||||||
comp = template.Layout(template.NotFound(), userComp)
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
} else {
|
|
||||||
comp = template.Layout(template.Index(), userComp)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := comp.Render(r.Context(), w)
|
|
||||||
if err != nil {
|
|
||||||
utils.LogError("Failed to render index", err)
|
|
||||||
http.Error(w, "Failed to render index", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,9 +2,20 @@ package types
|
|||||||
|
|
||||||
import "github.com/google/uuid"
|
import "github.com/google/uuid"
|
||||||
|
|
||||||
|
type SessionId string
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Id uuid.UUID
|
Id uuid.UUID
|
||||||
Email string
|
Email string
|
||||||
SessionId string
|
SessionId SessionId
|
||||||
EmailVerified bool
|
EmailVerified bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewUser(id uuid.UUID, email string, sessionId SessionId, emailVerified bool) *User {
|
||||||
|
return &User{
|
||||||
|
Id: id,
|
||||||
|
Email: email,
|
||||||
|
SessionId: sessionId,
|
||||||
|
EmailVerified: emailVerified,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,13 +10,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func MustRunMigrations(db *sql.DB) {
|
func MustRunMigrations(db *sql.DB) {
|
||||||
|
mustRunMigrationsInternal(db, "")
|
||||||
|
}
|
||||||
|
func MustRunMigrationsTest(db *sql.DB, pathPrefix string) {
|
||||||
|
mustRunMigrationsInternal(db, "../")
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustRunMigrationsInternal(db *sql.DB, pathPrefix string) {
|
||||||
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := migrate.NewWithDatabaseInstance(
|
m, err := migrate.NewWithDatabaseInstance(
|
||||||
"file://./migration/",
|
"file://./"+pathPrefix+"migration/",
|
||||||
"",
|
"",
|
||||||
driver)
|
driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"me-fit/types"
|
"me-fit/types"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
@@ -61,39 +59,10 @@ func GetUser(r *http.Request) *types.User {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserFromSession(db *sql.DB, r *http.Request) *types.User {
|
func GetSessionID(r *http.Request) types.SessionId {
|
||||||
sessionId := getSessionID(r)
|
|
||||||
if sessionId == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var user types.User
|
|
||||||
var createdAt time.Time
|
|
||||||
|
|
||||||
user.SessionId = sessionId
|
|
||||||
|
|
||||||
err := db.QueryRow(`
|
|
||||||
SELECT u.user_uuid, u.email, u.email_verified, s.created_at
|
|
||||||
FROM session s
|
|
||||||
INNER JOIN user u ON s.user_uuid = u.user_uuid
|
|
||||||
WHERE session_id = ?`, sessionId).Scan(&user.Id, &user.Email, &user.EmailVerified, &createdAt)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("Could not verify session: " + err.Error())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if createdAt.Add(time.Duration(8 * time.Hour)).Before(time.Now()) {
|
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
return &user
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSessionID(r *http.Request) string {
|
|
||||||
for _, c := range r.Cookies() {
|
for _, c := range r.Cookies() {
|
||||||
if c.Name == "id" {
|
if c.Name == "id" {
|
||||||
return c.Value
|
return types.SessionId(c.Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
Reference in New Issue
Block a user