diff --git a/handler.go b/handler.go index 59b5124..d6c28ac 100644 --- a/handler.go +++ b/handler.go @@ -11,21 +11,21 @@ import ( func getHandler(db *sql.DB) http.Handler { var router = http.NewServeMux() - router.HandleFunc("/", service.HandleIndexAnd404) + router.HandleFunc("/", service.HandleIndexAnd404(db)) // Serve static files (CSS, JS and images) router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/")))) - router.HandleFunc("/app", service.WorkoutPage) - router.HandleFunc("POST /api/workout", service.NewWorkout(db)) - router.HandleFunc("GET /api/workout", service.GetWorkouts(db)) - router.HandleFunc("DELETE /api/workout", service.DeleteWorkout(db)) + router.HandleFunc("/app", service.HandleWorkoutPage(db)) + router.HandleFunc("POST /api/workout", service.HandleNewWorkout(db)) + router.HandleFunc("GET /api/workout", service.HandleGetWorkouts(db)) + router.HandleFunc("DELETE /api/workout", service.HandleDeleteWorkout(db)) - router.HandleFunc("/auth/signin", service.SignInPage) - router.HandleFunc("/auth/signup", service.SignUpPage) - router.HandleFunc("/api/auth/signup", service.SignUp(db)) - router.HandleFunc("/api/auth/signin", service.SignIn(db)) - router.HandleFunc("/api/auth/userinfo", service.UserInfoComp(db)) + router.HandleFunc("/auth/signin", service.HandleSignInPage(db)) + router.HandleFunc("/auth/signup", service.HandleSignUpPage(db)) + router.HandleFunc("/api/auth/signup", service.HandleSignUp(db)) + router.HandleFunc("/api/auth/signin", service.HandleSignIn(db)) + router.HandleFunc("/api/auth/signout", service.HandleSignOutComp(db)) return middleware.Logging(middleware.EnableCors(router)) } diff --git a/service/auth.go b/service/auth.go index ae0783f..494229d 100644 --- a/service/auth.go +++ b/service/auth.go @@ -14,27 +14,42 @@ import ( "me-fit/template" "me-fit/template/auth" + "github.com/a-h/templ" "github.com/google/uuid" "golang.org/x/crypto/argon2" ) -func SignInPage(w http.ResponseWriter, r *http.Request) { - signIn := auth.SignInOrUp(true) - template.Layout(signIn).Render(r.Context(), w) +type User struct { + user_uuid uuid.UUID + email string } -func SignUpPage(w http.ResponseWriter, r *http.Request) { - signIn := auth.SignInOrUp(false) - template.Layout(signIn).Render(r.Context(), w) -} - -func UserInfoComp(db *sql.DB) http.HandlerFunc { +func HandleSignInPage(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - //TODO + user_comp := UserInfoComp(verifySessionAndReturnUser(db, r)) + signIn := auth.SignInOrUp(true) + template.Layout(signIn, user_comp).Render(r.Context(), w) } } -func SignUp(db *sql.DB) http.HandlerFunc { +func HandleSignUpPage(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user_comp := UserInfoComp(verifySessionAndReturnUser(db, r)) + signIn := auth.SignInOrUp(false) + template.Layout(signIn, user_comp).Render(r.Context(), w) + } +} + +func UserInfoComp(user *User) templ.Component { + + if user != nil { + return auth.UserComp(user.email) + } else { + return auth.UserComp("") + } +} + +func HandleSignUp(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") @@ -78,16 +93,16 @@ func SignUp(db *sql.DB) http.HandlerFunc { return } - result := tryCreateSessionAndSetCookie(r, w, db, user_uuid) + result := tryCreateSessionAndSetCookie(w, db, user_uuid) if !result { return } - w.WriteHeader(http.StatusOK) + http.Redirect(w, r, "/", http.StatusSeeOther) } } -func SignIn(db *sql.DB) http.HandlerFunc { +func HandleSignIn(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var email = r.FormValue("email") var password = r.FormValue("password") @@ -112,14 +127,14 @@ func SignIn(db *sql.DB) http.HandlerFunc { } if result { - result := tryCreateSessionAndSetCookie(r, w, db, user_uuid) + result := tryCreateSessionAndSetCookie(w, db, user_uuid) if !result { return } } duration := time.Since(start) - time_to_wait := 300 - duration.Milliseconds() + time_to_wait := 100 - duration.Milliseconds() // It is important to sleep for a while to prevent timing attacks // If the email is correct, the server will calculate the hash, which will take some time // This way an attacker could guess emails when comparing the response time @@ -127,13 +142,39 @@ func SignIn(db *sql.DB) http.HandlerFunc { time.Sleep(time.Duration(time_to_wait) * time.Millisecond) if result { - w.WriteHeader(http.StatusOK) + http.Redirect(w, r, "/", http.StatusSeeOther) } else { http.Error(w, "Unauthorized", http.StatusUnauthorized) } } } +func HandleSignOutComp(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id := getSessionID(r) + + _, err := db.Exec("DELETE FROM session WHERE session_id = ?", id) + if err != nil { + log.Printf("Could not delete session: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + c := http.Cookie{ + Name: "id", + Value: "", + MaxAge: -1, + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Path: "/", + } + + http.SetCookie(w, &c) + auth.UserComp("").Render(r.Context(), w) + } +} + // var ( // metricsAuthSignUp = promauto.NewCounterVec( // prometheus.CounterOpts{ @@ -159,7 +200,7 @@ func SignIn(db *sql.DB) http.HandlerFunc { // // ) // ) -func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { +func tryCreateSessionAndSetCookie(w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) bool { var session_id_bytes []byte = make([]byte, 32) _, err := rand.Reader.Read(session_id_bytes) if err != nil { @@ -190,6 +231,42 @@ func tryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sq return true } +func getSessionID(r *http.Request) string { + for _, c := range r.Cookies() { + if c.Name == "id" { + return c.Value + } + } + return "" +} + +func verifySessionAndReturnUser(db *sql.DB, r *http.Request) *User { + session_id := getSessionID(r) + if session_id == "" { + return nil + } + + var user User + var created_at time.Time + + err := db.QueryRow(` + SELECT u.user_uuid, u.email, s.created_at + FROM session s + INNER JOIN user u ON s.user_uuid = u.user_uuid + WHERE session_id = ?`, session_id).Scan(&user.user_uuid, &user.email, &created_at) + if err != nil { + log.Printf("Could not verify session: %v", err) + return nil + } + + // TODO: Test + // if created_at.Add(time.Duration(8 * time.Hour)).Before(time.Now()) { + // return nil + // } + + return &user +} + func getHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) } diff --git a/service/static_ui.go b/service/static_ui.go index f99720e..d86efe8 100644 --- a/service/static_ui.go +++ b/service/static_ui.go @@ -1,20 +1,25 @@ package service import ( + "database/sql" "me-fit/template" "net/http" "github.com/a-h/templ" ) -func HandleIndexAnd404(w http.ResponseWriter, r *http.Request) { - var comp templ.Component = nil - if r.URL.Path != "/" { - comp = template.Layout(template.NotFound()) - w.WriteHeader(http.StatusNotFound) - } else { - comp = template.Layout(template.Index()) - } +func HandleIndexAnd404(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var comp templ.Component = nil + user_comp := UserInfoComp(verifySessionAndReturnUser(db, r)) - comp.Render(r.Context(), w) + if r.URL.Path != "/" { + comp = template.Layout(template.NotFound(), user_comp) + w.WriteHeader(http.StatusNotFound) + } else { + comp = template.Layout(template.Index(), user_comp) + } + + comp.Render(r.Context(), w) + } } diff --git a/service/workout.go b/service/workout.go index 3d7650f..b8634fa 100644 --- a/service/workout.go +++ b/service/workout.go @@ -23,13 +23,16 @@ var ( ) ) -func WorkoutPage(w http.ResponseWriter, r *http.Request) { - comp := template.App() - layout := template.Layout(comp) - layout.Render(r.Context(), w) +func HandleWorkoutPage(db *sql.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + inner := template.App() + user_comp := UserInfoComp(verifySessionAndReturnUser(db, r)) + layout := template.Layout(inner, user_comp) + layout.Render(r.Context(), w) + } } -func NewWorkout(db *sql.DB) http.HandlerFunc { +func HandleNewWorkout(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { metrics.WithLabelValues("new").Inc() @@ -73,7 +76,7 @@ func NewWorkout(db *sql.DB) http.HandlerFunc { } } -func GetWorkouts(db *sql.DB) http.HandlerFunc { +func HandleGetWorkouts(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { metrics.WithLabelValues("get").Inc() @@ -115,7 +118,7 @@ func GetWorkouts(db *sql.DB) http.HandlerFunc { } } -func DeleteWorkout(db *sql.DB) http.HandlerFunc { +func HandleDeleteWorkout(db *sql.DB) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { metrics.WithLabelValues("delete").Inc() diff --git a/template/auth/sign_in_or_up.templ b/template/auth/sign_in_or_up.templ index 0ecebce..929ecf9 100644 --- a/template/auth/sign_in_or_up.templ +++ b/template/auth/sign_in_or_up.templ @@ -5,9 +5,9 @@ templ SignInOrUp(isSignIn bool) { class="max-w-xl px-2 mx-auto flex flex-col gap-4 h-full justify-center" method="POST" if isSignIn { - hx-post="/api/auth/signin" + action="/api/auth/signin" } else { - hx-post="/api/auth/signup" + action="/api/auth/signup" } >