diff --git a/db/account.go b/db/account.go index 5079e5e..1023de2 100644 --- a/db/account.go +++ b/db/account.go @@ -8,12 +8,13 @@ import ( "github.com/jmoiron/sqlx" ) +// While it may be duplicated to check for groupIds in the database access, it serves as a security layer type Account interface { - Insert(account *types.Account) error - Update(account *types.Account) error + Insert(groupId uuid.UUID, account *types.Account) error + Update(groupId uuid.UUID, account *types.Account) error GetAll(groupId uuid.UUID) ([]*types.Account, error) Get(groupId uuid.UUID, id uuid.UUID) (*types.Account, error) - Delete(id uuid.UUID) error + Delete(groupId uuid.UUID, id uuid.UUID) error } type AccountSqlite struct { @@ -24,11 +25,11 @@ func NewAccountSqlite(db *sqlx.DB) *AccountSqlite { return &AccountSqlite{db: db} } -func (db AccountSqlite) Insert(account *types.Account) error { +func (db AccountSqlite) Insert(groupId uuid.UUID, account *types.Account) error { _, err := db.db.Exec(` INSERT INTO account (id, group_id, name, current_balance, oink_balance, created_at, created_by) - VALUES (?,?,?,?,?,?,?)`, account.Id, account.GroupId, 0, 0, account.CreatedAt, account.CreatedBy) + VALUES (?,?,?,?,?,?,?)`, account.Id, groupId, account.Name, 0, 0, account.CreatedAt, account.CreatedBy) if err != nil { log.Error("Error inserting account: %v", err) return types.ErrInternal @@ -37,22 +38,34 @@ func (db AccountSqlite) Insert(account *types.Account) error { return nil } -func (db AccountSqlite) Update(account *types.Account) error { +func (db AccountSqlite) Update(groupId uuid.UUID, account *types.Account) error { - _, err := db.db.Exec(` + log.Info("Updating account: %v", account) + r, err := db.db.Exec(` UPDATE account + SET name = ?, current_balance = ?, last_transaction = ?, oink_balance = ?, updated_at = ?, - updated_by = ?, + updated_by = ? WHERE id = ? - AND group_id = ?`, account.Name, account.CurrentBalance, account.LastTransaction, account.OinkBalance, account.UpdatedAt, account.UpdatedBy, account.Id, account.GroupId) + AND group_id = ?`, account.Name, account.CurrentBalance, account.LastTransaction, account.OinkBalance, account.UpdatedAt, account.UpdatedBy, account.Id, groupId) if err != nil { log.Error("Error updating account: %v", err) return types.ErrInternal } + rows, err := r.RowsAffected() + if err != nil { + log.Error("Error deleting account, getting rows affected: %v", err) + return types.ErrInternal + } + + if rows == 0 { + log.Error("Error deleting account, rows affected: %v", rows) + return ErrNotFound + } return nil } @@ -62,7 +75,7 @@ func (db AccountSqlite) GetAll(groupId uuid.UUID) ([]*types.Account, error) { accounts := make([]*types.Account, 0) err := db.db.Select(&accounts, ` SELECT - id, name, + id, group_id, name, current_balance, last_transaction, oink_balance, created_at, created_by, updated_at, updated_by FROM account @@ -81,7 +94,7 @@ func (db AccountSqlite) Get(groupId uuid.UUID, id uuid.UUID) (*types.Account, er account := &types.Account{} err := db.db.Get(account, ` SELECT - id, name, + id, group_id, name, current_balance, last_transaction, oink_balance, created_at, created_by, updated_at, updated_by FROM account @@ -95,9 +108,9 @@ func (db AccountSqlite) Get(groupId uuid.UUID, id uuid.UUID) (*types.Account, er return account, nil } -func (db AccountSqlite) Delete(id uuid.UUID) error { +func (db AccountSqlite) Delete(groupId uuid.UUID, id uuid.UUID) error { - res, err := db.db.Exec("DELETE FROM account WHERE id = ?", id) + res, err := db.db.Exec("DELETE FROM account WHERE id = ? and group_id = ?", id, groupId) if err != nil { log.Error("Error deleting account: %v", err) return types.ErrInternal @@ -110,6 +123,7 @@ func (db AccountSqlite) Delete(id uuid.UUID) error { } if rows == 0 { + log.Error("Error deleting account, rows affected: %v", rows) return ErrNotFound } diff --git a/handler/account.go b/handler/account.go index 65e4cad..6525735 100644 --- a/handler/account.go +++ b/handler/account.go @@ -2,11 +2,16 @@ package handler import ( "spend-sparrow/handler/middleware" + "spend-sparrow/log" "spend-sparrow/service" - "spend-sparrow/template/account" + t "spend-sparrow/template/account" + "spend-sparrow/types" "spend-sparrow/utils" "net/http" + + "github.com/a-h/templ" + "github.com/google/uuid" ) type Account interface { @@ -14,27 +19,27 @@ type Account interface { } type AccountImpl struct { - service service.Account - auth service.Auth - render *Render + s service.Account + a service.Auth + r *Render } -func NewAccount(service service.Account, auth service.Auth, render *Render) Account { +func NewAccount(s service.Account, a service.Auth, r *Render) Account { return AccountImpl{ - service: service, - auth: auth, - render: render, + s: s, + a: a, + r: r, } } -func (handler AccountImpl) Handle(router *http.ServeMux) { - router.Handle("/account", handler.handleAccountPage()) - // router.Handle("POST /account", handler.handleAddAccount()) - // router.Handle("GET /account", handler.handleGetAccount()) - // router.Handle("DELETE /account/{id}", handler.handleDeleteAccount()) +func (h AccountImpl) Handle(r *http.ServeMux) { + r.Handle("GET /account", h.handleAccountPage()) + r.Handle("GET /account/{id}", h.handleAccountItemComp()) + r.Handle("POST /account/{id}", h.handleUpdateAccount()) + r.Handle("DELETE /account/{id}", h.handleDeleteAccount()) } -func (handler AccountImpl) handleAccountPage() http.HandlerFunc { +func (h AccountImpl) handleAccountPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { user := middleware.GetUser(r) if user == nil { @@ -42,85 +47,111 @@ func (handler AccountImpl) handleAccountPage() http.HandlerFunc { return } - comp := account.Account() - handler.render.RenderLayout(r, w, comp, user) + accounts, err := h.s.GetAll(user) + if err != nil { + utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) + return + } + + comp := t.Account(accounts) + h.r.RenderLayout(r, w, comp, user) } } -// func (handler AccountImpl) handleAddAccount() http.HandlerFunc { -// return func(w http.ResponseWriter, r *http.Request) { -// user := middleware.GetUser(r) -// if user == nil { -// utils.DoRedirect(w, r, "/auth/signin") -// return -// } -// -// var dateStr = r.FormValue("date") -// var typeStr = r.FormValue("type") -// var setsStr = r.FormValue("sets") -// var repsStr = r.FormValue("reps") -// -// wo := service.NewAccountDto("", dateStr, typeStr, setsStr, repsStr) -// wo, err := handler.service.AddAccount(user, wo) -// if err != nil { -// utils.TriggerToast(w, r, "error", "Invalid input values", http.StatusBadRequest) -// http.Error(w, "Invalid input values", http.StatusBadRequest) -// return -// } -// wor := account.Account{Id: wo.RowId, Date: wo.Date, Type: wo.Type, Sets: wo.Sets, Reps: wo.Reps} -// -// comp := account.AccountItemComp(wor, true) -// handler.render.Render(r, w, comp) -// } -// } -// -// func (handler AccountImpl) handleGetAccount() http.HandlerFunc { -// return func(w http.ResponseWriter, r *http.Request) { -// user := middleware.GetUser(r) -// if user == nil { -// utils.DoRedirect(w, r, "/auth/signin") -// return -// } -// -// workouts, err := handler.service.GetAccounts(user) -// if err != nil { -// return -// } -// -// wos := make([]*types.Account, 0) -// for _, wo := range workouts { -// wos = append(wos, *types.Account{Id: wo.RowId, Date: wo.Date, Type: wo.Type, Sets: wo.Sets, Reps: wo.Reps}) -// } -// -// comp := account.AccountListComp(wos) -// handler.render.Render(r, w, comp) -// } -// } -// -// func (handler AccountImpl) handleDeleteAccount() http.HandlerFunc { -// return func(w http.ResponseWriter, r *http.Request) { -// user := middleware.GetUser(r) -// if user == nil { -// utils.DoRedirect(w, r, "/auth/signin") -// return -// } -// -// rowId := r.PathValue("id") -// if rowId == "" { -// utils.TriggerToast(w, r, "error", "Missing ID field", http.StatusBadRequest) -// return -// } -// -// rowIdInt, err := strconv.Atoi(rowId) -// if err != nil { -// utils.TriggerToast(w, r, "error", "Invalid ID", http.StatusBadRequest) -// return -// } -// -// err = handler.service.DeleteAccount(user, rowIdInt) -// if err != nil { -// utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) -// return -// } -// } -// } +func (h AccountImpl) handleAccountItemComp() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := middleware.GetUser(r) + if user == nil { + utils.DoRedirect(w, r, "/auth/signin") + return + } + + idStr := r.PathValue("id") + if idStr == "new" { + comp := t.EditAccount(nil) + log.Info("Component: %v", comp) + h.r.Render(r, w, comp) + return + } + + id, err := uuid.Parse(idStr) + if err != nil { + utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest) + return + } + + account, err := h.s.Get(user, id) + if err != nil { + utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) + return + } + + var comp templ.Component + if r.URL.Query().Get("edit") == "true" { + comp = t.EditAccount(account) + } else { + comp = t.AccountItem(account) + } + h.r.Render(r, w, comp) + } +} + +func (h AccountImpl) handleUpdateAccount() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := middleware.GetUser(r) + if user == nil { + utils.DoRedirect(w, r, "/auth/signin") + return + } + + var ( + account *types.Account + err error + ) + idStr := r.PathValue("id") + name := r.FormValue("name") + if idStr == "new" { + account, err = h.s.Add(user, name) + if err != nil { + utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError) + return + } + } else { + id, err := uuid.Parse(idStr) + if err != nil { + utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest) + return + } + account, err = h.s.Update(user, id, name) + if err != nil { + utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError) + return + } + } + + comp := t.AccountItem(account) + h.r.Render(r, w, comp) + } +} + +func (h AccountImpl) handleDeleteAccount() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := middleware.GetUser(r) + if user == nil { + utils.DoRedirect(w, r, "/auth/signin") + return + } + + id, err := uuid.Parse(r.PathValue("id")) + if err != nil { + utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest) + return + } + + err = h.s.Delete(user, id) + if err != nil { + utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError) + return + } + } +} diff --git a/handler/auth.go b/handler/auth.go index 962002d..ca0f9e1 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -96,9 +96,9 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc { if err != nil { if err == service.ErrInvalidCredentials { - utils.TriggerToast(w, r, "error", "Invalid email or password", http.StatusUnauthorized) + utils.TriggerToastWithStatus(w, r, "error", "Invalid email or password", http.StatusUnauthorized) } else { - utils.TriggerToast(w, r, "error", "An error occurred", http.StatusInternalServerError) + utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError) } return } @@ -204,19 +204,19 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc { if err != nil { if errors.Is(err, types.ErrInternal) { - utils.TriggerToast(w, r, "error", "An error occurred", http.StatusInternalServerError) + utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError) return } else if errors.Is(err, service.ErrInvalidEmail) { - utils.TriggerToast(w, r, "error", "The email provided is invalid", http.StatusBadRequest) + utils.TriggerToastWithStatus(w, r, "error", "The email provided is invalid", http.StatusBadRequest) return } else if errors.Is(err, service.ErrInvalidPassword) { - utils.TriggerToast(w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest) + utils.TriggerToastWithStatus(w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest) return } // If err is "service.ErrAccountExists", then just continue } - utils.TriggerToast(w, r, "success", "An activation link has been send to your email", http.StatusOK) + utils.TriggerToastWithStatus(w, r, "success", "An activation link has been send to your email", http.StatusOK) } } @@ -273,9 +273,9 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { err := handler.service.DeleteAccount(user, password) if err != nil { if err == service.ErrInvalidCredentials { - utils.TriggerToast(w, r, "error", "Password not correct", http.StatusBadRequest) + utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest) } else { - utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) + utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) } return } @@ -307,7 +307,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { session := middleware.GetSession(r) user := middleware.GetUser(r) if session == nil || user == nil { - utils.TriggerToast(w, r, "error", "Unathorized", http.StatusUnauthorized) + utils.TriggerToastWithStatus(w, r, "error", "Unathorized", http.StatusUnauthorized) return } @@ -316,11 +316,11 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { err := handler.service.ChangePassword(user, session.Id, currPass, newPass) if err != nil { - utils.TriggerToast(w, r, "error", "Password not correct", http.StatusBadRequest) + utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest) return } - utils.TriggerToast(w, r, "success", "Password changed", http.StatusOK) + utils.TriggerToastWithStatus(w, r, "success", "Password changed", http.StatusOK) } } @@ -343,7 +343,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { email := r.FormValue("email") if email == "" { - utils.TriggerToast(w, r, "error", "Please enter an email", http.StatusBadRequest) + utils.TriggerToastWithStatus(w, r, "error", "Please enter an email", http.StatusBadRequest) return } @@ -353,9 +353,9 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { }) if err != nil { - utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) + utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) } else { - utils.TriggerToast(w, r, "info", "If the address exists, an email has been sent.", http.StatusOK) + utils.TriggerToastWithStatus(w, r, "info", "If the address exists, an email has been sent.", http.StatusOK) } } } @@ -365,7 +365,7 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { pageUrl, err := url.Parse(r.Header.Get("HX-Current-URL")) if err != nil { log.Error("Could not get current URL: %v", err) - utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) + utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) return } @@ -374,9 +374,9 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { err = handler.service.ForgotPassword(token, newPass) if err != nil { - utils.TriggerToast(w, r, "error", err.Error(), http.StatusBadRequest) + utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest) } else { - utils.TriggerToast(w, r, "success", "Password changed", http.StatusOK) + utils.TriggerToastWithStatus(w, r, "success", "Password changed", http.StatusOK) } } } diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 1c0da74..663a2fc 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -59,7 +59,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { log.Info("CSRF-Token not correct") if r.Header.Get("HX-Request") == "true" { - utils.TriggerToast(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest) + utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest) } else { http.Error(w, "CSRF-Token not correct", http.StatusBadRequest) } diff --git a/handler/middleware/gzip.go b/handler/middleware/gzip.go new file mode 100644 index 0000000..1def7fd --- /dev/null +++ b/handler/middleware/gzip.go @@ -0,0 +1,32 @@ +package middleware + +import ( + "compress/gzip" + "io" + "net/http" + "strings" +) + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w gzipResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} + +func Gzip(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + next.ServeHTTP(w, r) + return + } + + w.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(w) + defer gz.Close() + gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w} + next.ServeHTTP(gzr, r) + }) +} diff --git a/handler/middleware/wrapper.go b/handler/middleware/wrapper.go index 80dcb20..cd5f9af 100644 --- a/handler/middleware/wrapper.go +++ b/handler/middleware/wrapper.go @@ -2,10 +2,11 @@ package middleware import "net/http" +// Chain list of handlers together func Wrapper(next http.Handler, handlers ...func(http.Handler) http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { lastHandler := next - for i := len(handlers) - 1; i >= 0; i-- { + for i := 0; i < len(handlers); i++ { lastHandler = handlers[i](lastHandler) } lastHandler.ServeHTTP(w, r) diff --git a/handler/root_and_404.go b/handler/root_and_404.go index 4d836dd..569b927 100644 --- a/handler/root_and_404.go +++ b/handler/root_and_404.go @@ -28,6 +28,7 @@ func NewIndex(service service.Auth, render *Render) Index { func (handler IndexImpl) Handle(router *http.ServeMux) { router.Handle("/", handler.handleRootAnd404()) + router.Handle("/empty", handler.handleEmpty()) } func (handler IndexImpl) handleRootAnd404() http.HandlerFunc { @@ -52,3 +53,9 @@ func (handler IndexImpl) handleRootAnd404() http.HandlerFunc { handler.render.RenderLayoutWithStatus(r, w, comp, user, status) } } + +func (handler IndexImpl) handleEmpty() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Return nothing + } +} diff --git a/main.go b/main.go index e4e6321..4f271ed 100644 --- a/main.go +++ b/main.go @@ -130,10 +130,12 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler { return middleware.Wrapper( router, - middleware.Log, - middleware.CacheControl, middleware.SecurityHeaders(serverSettings), - middleware.Authenticate(authService), + middleware.CacheControl, middleware.CrossSiteRequestForgery(authService), + middleware.Authenticate(authService), + middleware.Log, + // Gzip last, as it compresses the body + middleware.Gzip, ) } diff --git a/service/account.go b/service/account.go index 5e1223e..d009dda 100644 --- a/service/account.go +++ b/service/account.go @@ -12,13 +12,14 @@ import ( ) var ( - safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9-]+$`) + safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9äöüß -]+$`) ) type Account interface { Add(user *types.User, name string) (*types.Account, error) Update(user *types.User, id uuid.UUID, name string) (*types.Account, error) - Get(user *types.User) ([]*types.Account, error) + Get(user *types.User, id uuid.UUID) (*types.Account, error) + GetAll(user *types.User) ([]*types.Account, error) Delete(user *types.User, id uuid.UUID) error } @@ -69,7 +70,7 @@ func (service AccountImpl) Add(user *types.User, name string) (*types.Account, e UpdatedBy: nil, } - err = service.db.Insert(account) + err = service.db.Insert(user.Id, account) if err != nil { return nil, err } @@ -103,7 +104,7 @@ func (service AccountImpl) Update(user *types.User, id uuid.UUID, name string) ( account.UpdatedAt = ×tamp account.UpdatedBy = &user.Id - err = service.db.Update(account) + err = service.db.Update(user.Id, account) if err != nil { return nil, err } @@ -111,13 +112,27 @@ func (service AccountImpl) Update(user *types.User, id uuid.UUID, name string) ( return account, nil } -func (service AccountImpl) Get(user *types.User) ([]*types.Account, error) { +func (service AccountImpl) Get(user *types.User, id uuid.UUID) (*types.Account, error) { if user == nil { return nil, types.ErrInternal } - accounts, err := service.db.GetAll(user.GroupId) + account, err := service.db.Get(user.Id, id) + if err != nil { + return nil, types.ErrInternal + } + + return account, nil +} + +func (service AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { + + if user == nil { + return nil, types.ErrInternal + } + + accounts, err := service.db.GetAll(user.Id) if err != nil { return nil, types.ErrInternal } @@ -130,16 +145,16 @@ func (service AccountImpl) Delete(user *types.User, id uuid.UUID) error { return types.ErrInternal } - account, err := service.db.Get(user.GroupId, id) + account, err := service.db.Get(user.Id, id) if err != nil { return err } - if account.GroupId != user.GroupId { + if account.GroupId != user.Id { return types.ErrUnauthorized } - err = service.db.Delete(account.Id) + err = service.db.Delete(user.Id, account.Id) if err != nil { return err } diff --git a/static/js/toast.js b/static/js/toast.js index 2f92b18..ac624ba 100644 --- a/static/js/toast.js +++ b/static/js/toast.js @@ -1,5 +1,4 @@ - function getClass(type) { switch (type) { case "error": diff --git a/template/account/account.templ b/template/account/account.templ index 03858f6..b0fb7b9 100644 --- a/template/account/account.templ +++ b/template/account/account.templ @@ -2,48 +2,112 @@ package account import "fmt" import "spend-sparrow/template/svg" +import "spend-sparrow/types" -templ Account() { +templ Account(accounts []*types.Account) {