From 5cfea4e2d38939e61d04a4efba1e52aa2dd06662 Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Thu, 8 May 2025 21:45:53 +0200 Subject: [PATCH] feat(account): #49 refactor error handling --- db/account.go | 23 ++-- db/error.go | 10 ++ db/{default.go => migration.go} | 5 - handler/account.go | 22 ++-- handler/error.go | 30 +++++ .../middleware/cross_site_request_forgery.go | 4 - handler/middleware/gzip.go | 8 +- handler/middleware/logger.go | 2 +- handler/root_and_404.go | 3 +- main.go | 3 +- service/account.go | 105 ++++++++++-------- service/error.go | 8 ++ service/random_generator.go | 8 +- template/account/account.templ | 30 ----- 14 files changed, 142 insertions(+), 119 deletions(-) create mode 100644 db/error.go rename db/{default.go => migration.go} (87%) create mode 100644 handler/error.go create mode 100644 service/error.go diff --git a/db/account.go b/db/account.go index c3e891c..95e90e9 100644 --- a/db/account.go +++ b/db/account.go @@ -1,6 +1,7 @@ package db import ( + "database/sql" "spend-sparrow/log" "spend-sparrow/types" @@ -31,7 +32,7 @@ func (db AccountSqlite) Insert(userId uuid.UUID, account *types.Account) error { INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by) VALUES (?,?,?,?,?,?,?)`, account.Id, userId, account.Name, 0, 0, account.CreatedAt, account.CreatedBy) if err != nil { - log.Error("Error inserting account: %v", err) + log.Error("account Insert: %v", err) return types.ErrInternal } @@ -40,7 +41,6 @@ func (db AccountSqlite) Insert(userId uuid.UUID, account *types.Account) error { func (db AccountSqlite) Update(userId uuid.UUID, account *types.Account) error { - log.Info("Updating account: %v", account) r, err := db.db.Exec(` UPDATE account SET @@ -53,17 +53,17 @@ func (db AccountSqlite) Update(userId uuid.UUID, account *types.Account) error { WHERE id = ? AND user_id = ?`, account.Name, account.CurrentBalance, account.LastTransaction, account.OinkBalance, account.UpdatedAt, account.UpdatedBy, account.Id, userId) if err != nil { - log.Error("Error updating account: %v", err) + log.Error("account Update: %v", err) return types.ErrInternal } rows, err := r.RowsAffected() if err != nil { - log.Error("Error deleting account, getting rows affected: %v", err) + log.Error("account Update: %v", err) return types.ErrInternal } if rows == 0 { - log.Error("Error deleting account, rows affected: %v", rows) + log.Info("account Update: not found") return ErrNotFound } @@ -82,7 +82,7 @@ func (db AccountSqlite) GetAll(userId uuid.UUID) ([]*types.Account, error) { WHERE user_id = ? ORDER BY name`, userId) if err != nil { - log.Error("Could not getAll accounts: %v", err) + log.Error("account GetAll: %v", err) return nil, types.ErrInternal } @@ -101,7 +101,10 @@ func (db AccountSqlite) Get(userId uuid.UUID, id uuid.UUID) (*types.Account, err WHERE user_id = ? AND id = ?`, userId, id) if err != nil { - log.Error("Could not get accounts: %v", err) + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + log.Error("account Get: %v", err) return nil, types.ErrInternal } @@ -112,18 +115,18 @@ func (db AccountSqlite) Delete(userId uuid.UUID, id uuid.UUID) error { res, err := db.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", id, userId) if err != nil { - log.Error("Error deleting account: %v", err) + log.Error("account Delete: %v", err) return types.ErrInternal } rows, err := res.RowsAffected() if err != nil { - log.Error("Error deleting account, getting rows affected: %v", err) + log.Error("account Delete: %v", err) return types.ErrInternal } if rows == 0 { - log.Error("Error deleting account, rows affected: %v", rows) + log.Info("account Delete: not found") return ErrNotFound } diff --git a/db/error.go b/db/error.go new file mode 100644 index 0000000..d427c11 --- /dev/null +++ b/db/error.go @@ -0,0 +1,10 @@ +package db + +import ( + "errors" +) + +var ( + ErrNotFound = errors.New("the value does not exist") + ErrAlreadyExists = errors.New("row already exists") +) diff --git a/db/default.go b/db/migration.go similarity index 87% rename from db/default.go rename to db/migration.go index c4a1b11..e4d8357 100644 --- a/db/default.go +++ b/db/migration.go @@ -12,11 +12,6 @@ import ( "github.com/jmoiron/sqlx" ) -var ( - ErrNotFound = errors.New("the value does not exist") - ErrAlreadyExists = errors.New("row already exists") -) - func RunMigrations(db *sqlx.DB, pathPrefix string) error { driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) if err != nil { diff --git a/handler/account.go b/handler/account.go index 6525735..49d451e 100644 --- a/handler/account.go +++ b/handler/account.go @@ -1,15 +1,14 @@ package handler import ( + "fmt" + "net/http" "spend-sparrow/handler/middleware" - "spend-sparrow/log" "spend-sparrow/service" t "spend-sparrow/template/account" "spend-sparrow/types" "spend-sparrow/utils" - "net/http" - "github.com/a-h/templ" "github.com/google/uuid" ) @@ -49,7 +48,7 @@ func (h AccountImpl) handleAccountPage() http.HandlerFunc { accounts, err := h.s.GetAll(user) if err != nil { - utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) + handleError(w, r, err) return } @@ -69,20 +68,19 @@ func (h AccountImpl) handleAccountItemComp() http.HandlerFunc { idStr := r.PathValue("id") if idStr == "new" { comp := t.EditAccount(nil) - log.Info("Component: %v", comp) h.r.Render(r, w, comp) return } id, err := uuid.Parse(idStr) if err != nil { - utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest) + handleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)) return } account, err := h.s.Get(user, id) if err != nil { - utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) + handleError(w, r, err) return } @@ -113,18 +111,18 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc { if idStr == "new" { account, err = h.s.Add(user, name) if err != nil { - utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError) + handleError(w, r, err) return } } else { id, err := uuid.Parse(idStr) if err != nil { - utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest) + handleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)) return } account, err = h.s.Update(user, id, name) if err != nil { - utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError) + handleError(w, r, err) return } } @@ -144,13 +142,13 @@ func (h AccountImpl) handleDeleteAccount() http.HandlerFunc { id, err := uuid.Parse(r.PathValue("id")) if err != nil { - utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest) + handleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)) return } err = h.s.Delete(user, id) if err != nil { - utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError) + handleError(w, r, err) return } } diff --git a/handler/error.go b/handler/error.go new file mode 100644 index 0000000..6d8e043 --- /dev/null +++ b/handler/error.go @@ -0,0 +1,30 @@ +package handler + +import ( + "errors" + "net/http" + "spend-sparrow/service" + "spend-sparrow/utils" + "strings" +) + +func handleError(w http.ResponseWriter, r *http.Request, err error) { + if errors.Is(err, service.ErrUnauthorized) { + utils.TriggerToastWithStatus(w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized) + return + } else if errors.Is(err, service.ErrBadRequest) { + utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest) + return + } + + utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) +} + +func extractErrorMessage(err error) string { + errMsg := err.Error() + if errMsg == "" { + return "" + } + + return strings.SplitN(errMsg, ":", 2)[0] +} diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 663a2fc..0f68067 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -37,10 +37,6 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) { return rr.ResponseWriter.Write([]byte(dataStr)) } -func (rr *csrfResponseWriter) WriteHeader(statusCode int) { - rr.ResponseWriter.WriteHeader(statusCode) -} - func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/handler/middleware/gzip.go b/handler/middleware/gzip.go index 2d5ea62..77f5e45 100644 --- a/handler/middleware/gzip.go +++ b/handler/middleware/gzip.go @@ -27,11 +27,13 @@ func Gzip(next http.Handler) http.Handler { w.Header().Set("Content-Encoding", "gzip") gz := gzip.NewWriter(w) - gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w} - next.ServeHTTP(gzr, r) + wrapper := gzipResponseWriter{Writer: gz, ResponseWriter: w} + + next.ServeHTTP(wrapper, r) err := gz.Close() - if err != nil { + if err != nil && err != http.ErrBodyNotAllowed { + // if err != nil { log.Error("Gzip: could not close Writer: %v", err) } }) diff --git a/handler/middleware/logger.go b/handler/middleware/logger.go index 126d38e..f0ec9a0 100644 --- a/handler/middleware/logger.go +++ b/handler/middleware/logger.go @@ -27,8 +27,8 @@ type WrappedWriter struct { } func (w *WrappedWriter) WriteHeader(code int) { - w.ResponseWriter.WriteHeader(code) w.StatusCode = code + w.ResponseWriter.WriteHeader(code) } func Log(next http.Handler) http.Handler { diff --git a/handler/root_and_404.go b/handler/root_and_404.go index 569b927..f4f66d3 100644 --- a/handler/root_and_404.go +++ b/handler/root_and_404.go @@ -1,12 +1,11 @@ package handler import ( + "net/http" "spend-sparrow/handler/middleware" "spend-sparrow/service" "spend-sparrow/template" - "net/http" - "github.com/a-h/templ" ) diff --git a/main.go b/main.go index 4f271ed..3649a8c 100644 --- a/main.go +++ b/main.go @@ -134,8 +134,7 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler { middleware.CacheControl, middleware.CrossSiteRequestForgery(authService), middleware.Authenticate(authService), - middleware.Log, - // Gzip last, as it compresses the body middleware.Gzip, + middleware.Log, ) } diff --git a/service/account.go b/service/account.go index 60be33e..48ac0cd 100644 --- a/service/account.go +++ b/service/account.go @@ -1,7 +1,7 @@ package service import ( - "errors" + "fmt" "regexp" "spend-sparrow/db" @@ -39,17 +39,17 @@ func NewAccountImpl(db db.Account, random Random, clock Clock, settings *types.S } } -func (service AccountImpl) Add(user *types.User, name string) (*types.Account, error) { +func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) { if user == nil { - return nil, types.ErrInternal + return nil, ErrUnauthorized } - newId, err := service.random.UUID() + newId, err := s.random.UUID() if err != nil { return nil, types.ErrInternal } - err = service.validateAccount(name) + err = s.validateAccount(name) if err != nil { return nil, err } @@ -64,61 +64,48 @@ func (service AccountImpl) Add(user *types.User, name string) (*types.Account, e LastTransaction: nil, OinkBalance: 0, - CreatedAt: service.clock.Now(), + CreatedAt: s.clock.Now(), CreatedBy: user.Id, UpdatedAt: nil, UpdatedBy: nil, } - err = service.db.Insert(user.Id, account) + err = s.db.Insert(user.Id, account) if err != nil { - return nil, err + return nil, types.ErrInternal } - savedAccount, err := service.db.Get(user.Id, newId) + savedAccount, err := s.db.Get(user.Id, newId) if err != nil { - if errors.Is(err, db.ErrNotFound) { - log.Error("Account not found after insert: %v", err) - } + log.Error("account %v not found after insert: %v", newId, err) return nil, types.ErrInternal } return savedAccount, nil } -func (service AccountImpl) Update(user *types.User, id uuid.UUID, name string) (*types.Account, error) { +func (s AccountImpl) Update(user *types.User, id uuid.UUID, name string) (*types.Account, error) { if user == nil { + return nil, ErrUnauthorized + } + err := s.validateAccount(name) + if err != nil { + return nil, err + } + + account, err := s.db.Get(user.Id, id) + if err != nil { + if err == db.ErrNotFound { + return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest) + } return nil, types.ErrInternal } - err := service.validateAccount(name) - if err != nil { - return nil, err - } - account, err := service.db.Get(user.Id, id) - if err != nil { - return nil, err - } - - timestamp := service.clock.Now() + timestamp := s.clock.Now() account.Name = name account.UpdatedAt = ×tamp account.UpdatedBy = &user.Id - err = service.db.Update(user.Id, account) - if err != nil { - return nil, err - } - - return account, nil -} - -func (service AccountImpl) Get(user *types.User, id uuid.UUID) (*types.Account, error) { - - if user == nil { - return nil, types.ErrInternal - } - - account, err := service.db.Get(user.Id, id) + err = s.db.Update(user.Id, account) if err != nil { return nil, types.ErrInternal } @@ -126,13 +113,30 @@ func (service AccountImpl) Get(user *types.User, id uuid.UUID) (*types.Account, return account, nil } -func (service AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { +func (s AccountImpl) Get(user *types.User, id uuid.UUID) (*types.Account, error) { if user == nil { + return nil, ErrUnauthorized + } + + account, err := s.db.Get(user.Id, id) + if err != nil { + if err == db.ErrNotFound { + return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest) + } return nil, types.ErrInternal } - accounts, err := service.db.GetAll(user.Id) + return account, nil +} + +func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { + + if user == nil { + return nil, ErrUnauthorized + } + + accounts, err := s.db.GetAll(user.Id) if err != nil { return nil, types.ErrInternal } @@ -140,33 +144,36 @@ func (service AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { return accounts, nil } -func (service AccountImpl) Delete(user *types.User, id uuid.UUID) error { +func (s AccountImpl) Delete(user *types.User, id uuid.UUID) error { if user == nil { - return types.ErrInternal + return ErrUnauthorized } - account, err := service.db.Get(user.Id, id) + account, err := s.db.Get(user.Id, id) if err != nil { - return err + if err == db.ErrNotFound { + return fmt.Errorf("account %v not found: %w", id, ErrBadRequest) + } + return types.ErrInternal } if account.UserId != user.Id { return types.ErrUnauthorized } - err = service.db.Delete(user.Id, account.Id) + err = s.db.Delete(user.Id, account.Id) if err != nil { - return err + return types.ErrInternal } return nil } -func (service AccountImpl) validateAccount(name string) error { +func (s AccountImpl) validateAccount(name string) error { if name == "" { - return errors.New("please enter a value for the \"name\" field") + return fmt.Errorf("field \"name\" needs to be set: %w", ErrBadRequest) } else if !safeInputRegex.MatchString(name) { - return errors.New("please use only letters, dashes or numbers for \"name\"") + return fmt.Errorf("use only letters, dashes and spaces for \"name\": %w", ErrBadRequest) } else { return nil } diff --git a/service/error.go b/service/error.go new file mode 100644 index 0000000..4f5da50 --- /dev/null +++ b/service/error.go @@ -0,0 +1,8 @@ +package service + +import "errors" + +var ( + ErrBadRequest = errors.New("bad request") + ErrUnauthorized = errors.New("unauthorized") +) diff --git a/service/random_generator.go b/service/random_generator.go index e564b60..8e768ed 100644 --- a/service/random_generator.go +++ b/service/random_generator.go @@ -45,5 +45,11 @@ func (r *RandomImpl) String(size int) (string, error) { } func (r *RandomImpl) UUID() (uuid.UUID, error) { - return uuid.NewRandom() + id, err := uuid.NewRandom() + if err != nil { + log.Error("Error generating random UUID: %v", err) + return uuid.Nil, types.ErrInternal + } + + return id, nil } diff --git a/template/account/account.templ b/template/account/account.templ index 8050f9b..b0fb7b9 100644 --- a/template/account/account.templ +++ b/template/account/account.templ @@ -18,36 +18,6 @@ templ Account(accounts []*types.Account) {
for _, account := range accounts { @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) - @AccountItem(account) }