feat(security): #286 implement csrf middleware

This commit is contained in:
2024-12-04 23:15:40 +01:00
parent bbcdbf7a01
commit 57989c9b03
18 changed files with 484 additions and 204 deletions

View File

@@ -1,6 +1,7 @@
package handler
import (
"me-fit/handler/middleware"
"me-fit/log"
"me-fit/service"
"me-fit/template/auth"
@@ -58,9 +59,9 @@ var (
func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, _ := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
if user != nil {
if !user.EmailVerified {
session := middleware.GetSession(r)
if session != nil {
if !session.User.EmailVerified {
utils.DoRedirect(w, r, "/auth/verify")
} else {
utils.DoRedirect(w, r, "/")
@@ -121,10 +122,10 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
func (handler AuthImpl) handleSignUpPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, _ := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
session := middleware.GetSession(r)
if user != nil {
if !user.EmailVerified {
if session != nil {
if !session.User.EmailVerified {
utils.DoRedirect(w, r, "/auth/verify")
} else {
utils.DoRedirect(w, r, "/")
@@ -139,33 +140,34 @@ func (handler AuthImpl) handleSignUpPage() http.HandlerFunc {
func (handler AuthImpl) handleSignUpVerifyPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, _ := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
if user == nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
if user.EmailVerified {
if session.User.EmailVerified {
utils.DoRedirect(w, r, "/")
return
}
signIn := auth.VerifyComp()
handler.render.RenderLayout(r, w, signIn, user)
handler.render.RenderLayout(r, w, signIn, session.User)
}
}
func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
user := session.User
go handler.service.SendVerificationMail(user.Id, user.Email)
_, err = w.Write([]byte("<p class=\"mt-8\">Verification email sent</p>"))
_, err := w.Write([]byte("<p class=\"mt-8\">Verification email sent</p>"))
if err != nil {
log.Error("Could not write response: %v", err)
}
@@ -219,11 +221,14 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
func (handler AuthImpl) handleSignOut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := handler.service.SignOut(utils.GetSessionID(r))
if err != nil {
utils.TriggerToast(w, r, "error", "Internal Server Error")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
session := middleware.GetSession(r)
if session != nil {
err := handler.service.SignOut(session.Id)
if err != nil {
http.Error(w, "An error occurred", http.StatusInternalServerError)
return
}
}
c := http.Cookie{
@@ -243,34 +248,34 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc {
func (handler AuthImpl) handleDeleteAccountPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// An unverified email should be able to delete their account
user, err := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
comp := auth.DeleteAccountComp()
handler.render.RenderLayout(r, w, comp, user)
handler.render.RenderLayout(r, w, comp, session.User)
}
}
func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
password := r.FormValue("password")
_, err = handler.service.SignIn(user.Email, password)
_, err := handler.service.SignIn(session.User.Email, password)
if err != nil {
utils.TriggerToast(w, r, "error", "Password not correct")
return
}
err = handler.service.DeleteAccount(user)
err = handler.service.DeleteAccount(session.User)
if err != nil {
utils.TriggerToast(w, r, "error", "Internal Server Error")
return
@@ -285,23 +290,23 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
isPasswordReset := r.URL.Query().Has("token")
user, _ := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
session := middleware.GetSession(r)
if user == nil && !isPasswordReset {
if session == nil && !isPasswordReset {
utils.DoRedirect(w, r, "/auth/signin")
return
}
comp := auth.ChangePasswordComp(isPasswordReset)
handler.render.RenderLayout(r, w, comp, user)
handler.render.RenderLayout(r, w, comp, session.User)
}
}
func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
@@ -309,7 +314,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
currPass := r.FormValue("current-password")
newPass := r.FormValue("new-password")
err = handler.service.ChangePassword(user, currPass, newPass)
err := handler.service.ChangePassword(session.User, currPass, newPass)
if err != nil {
utils.TriggerToast(w, r, "error", "Password not correct")
return
@@ -322,14 +327,14 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
func (handler AuthImpl) handleResetPasswordPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
comp := auth.ResetPasswordComp()
handler.render.RenderLayout(r, w, comp, user)
handler.render.RenderLayout(r, w, comp, session.User)
}
}

View File

@@ -1,9 +1,9 @@
package handler
import (
"me-fit/handler/middleware"
"me-fit/service"
"me-fit/template"
"me-fit/utils"
"net/http"
@@ -32,7 +32,11 @@ func (handler IndexImpl) Handle(router *http.ServeMux) {
func (handler IndexImpl) handleIndexAnd404() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, _ := handler.service.GetUserFromSessionId(utils.GetSessionID(r))
session := middleware.GetSession(r)
var user *service.User
if session != nil {
user = session.User
}
var comp templ.Component

View File

@@ -0,0 +1,47 @@
package middleware
import (
"context"
"me-fit/service"
"net/http"
)
type ContextKey string
var SessionKey ContextKey = "session"
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionId := getSessionID(r)
session, _ := service.SignInSession(sessionId)
if session != nil {
ctx := context.WithValue(r.Context(), SessionKey, session)
next.ServeHTTP(w, r.WithContext(ctx))
} else {
next.ServeHTTP(w, r)
}
})
}
}
func GetSession(r *http.Request) *service.Session {
obj := r.Context().Value(SessionKey)
if obj == nil {
return nil
}
return obj.(*service.Session)
}
func getSessionID(r *http.Request) string {
cookie, err := r.Cookie("id")
if err != nil {
return ""
}
return cookie.Name
}

View File

@@ -0,0 +1,63 @@
package middleware
import (
"fmt"
"me-fit/service"
"strings"
"net/http"
)
type csrfResponseWriter struct {
http.ResponseWriter
auth service.Auth
session *service.Session
}
func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *service.Session) *csrfResponseWriter {
return &csrfResponseWriter{
ResponseWriter: w,
auth: auth,
session: session,
}
}
TODO: Create session for CSRF token
func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
dataStr := string(data)
if strings.Contains(dataStr, "</form>") {
csrfToken, err := rr.auth.GetCsrfToken(rr.session)
if err == nil {
csrfField := fmt.Sprintf(`<input type="hidden" name="csrf-token" value="%s">`, csrfToken)
dataStr = strings.ReplaceAll(dataStr, "</form>", csrfField+"</form>")
}
}
return rr.ResponseWriter.Write([]byte(dataStr))
}
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := GetSession(r)
if r.Method == http.MethodPost ||
r.Method == http.MethodPut ||
r.Method == http.MethodDelete ||
r.Method == http.MethodPatch {
csrfToken := r.FormValue("csrf-token")
if csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
http.Error(w, "", http.StatusForbidden)
return
}
}
responseWriter := newCsrfResponseWriter(w, auth, session)
next.ServeHTTP(responseWriter, r)
})
}
}

View File

@@ -1,6 +1,7 @@
package handler
import (
"me-fit/handler/middleware"
"me-fit/log"
"me-fit/service"
"me-fit/template/workout"
@@ -38,22 +39,22 @@ func (handler WorkoutImpl) Handle(router *http.ServeMux) {
func (handler WorkoutImpl) handleWorkoutPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := handler.auth.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
currentDate := time.Now().Format("2006-01-02")
comp := workout.WorkoutComp(currentDate)
handler.render.RenderLayout(r, w, comp, user)
handler.render.RenderLayout(r, w, comp, session.User)
}
}
func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := handler.auth.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
@@ -64,7 +65,7 @@ func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc {
var repsStr = r.FormValue("reps")
wo := service.NewWorkoutDto("", dateStr, typeStr, setsStr, repsStr)
wo, err = handler.service.AddWorkout(user, wo)
wo, err := handler.service.AddWorkout(session.User, wo)
if err != nil {
utils.TriggerToast(w, r, "error", "Invalid input values")
http.Error(w, "Invalid input values", http.StatusBadRequest)
@@ -79,13 +80,13 @@ func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc {
func (handler WorkoutImpl) handleGetWorkout() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := handler.auth.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
workouts, err := handler.service.GetWorkouts(user)
workouts, err := handler.service.GetWorkouts(session.User)
if err != nil {
return
}
@@ -102,8 +103,8 @@ func (handler WorkoutImpl) handleGetWorkout() http.HandlerFunc {
func (handler WorkoutImpl) handleDeleteWorkout() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := handler.auth.GetUserFromSessionId(utils.GetSessionID(r))
if err != nil {
session := middleware.GetSession(r)
if session == nil {
utils.DoRedirect(w, r, "/auth/signin")
return
}
@@ -124,7 +125,7 @@ func (handler WorkoutImpl) handleDeleteWorkout() http.HandlerFunc {
return
}
err = handler.service.DeleteWorkout(user, rowIdInt)
err = handler.service.DeleteWorkout(session.User, rowIdInt)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
log.Error("Could not delete workout: %v", err.Error())