diff --git a/db/account.go b/db/account.go deleted file mode 100644 index 95e90e9..0000000 --- a/db/account.go +++ /dev/null @@ -1,134 +0,0 @@ -package db - -import ( - "database/sql" - "spend-sparrow/log" - "spend-sparrow/types" - - "github.com/google/uuid" - "github.com/jmoiron/sqlx" -) - -// While it may be duplicated to check for userId in the database access, it serves as a security layer -type Account interface { - Insert(userId uuid.UUID, account *types.Account) error - Update(userId uuid.UUID, account *types.Account) error - GetAll(userId uuid.UUID) ([]*types.Account, error) - Get(userId uuid.UUID, id uuid.UUID) (*types.Account, error) - Delete(userId uuid.UUID, id uuid.UUID) error -} - -type AccountSqlite struct { - db *sqlx.DB -} - -func NewAccountSqlite(db *sqlx.DB) *AccountSqlite { - return &AccountSqlite{db: db} -} - -func (db AccountSqlite) Insert(userId uuid.UUID, account *types.Account) error { - - _, err := db.db.Exec(` - INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by) - VALUES (?,?,?,?,?,?,?)`, account.Id, userId, account.Name, 0, 0, account.CreatedAt, account.CreatedBy) - if err != nil { - log.Error("account Insert: %v", err) - return types.ErrInternal - } - - return nil -} - -func (db AccountSqlite) Update(userId uuid.UUID, account *types.Account) error { - - r, err := db.db.Exec(` - UPDATE account - SET - name = ?, - current_balance = ?, - last_transaction = ?, - oink_balance = ?, - updated_at = ?, - updated_by = ? - WHERE id = ? - AND user_id = ?`, account.Name, account.CurrentBalance, account.LastTransaction, account.OinkBalance, account.UpdatedAt, account.UpdatedBy, account.Id, userId) - if err != nil { - log.Error("account Update: %v", err) - return types.ErrInternal - } - rows, err := r.RowsAffected() - if err != nil { - log.Error("account Update: %v", err) - return types.ErrInternal - } - - if rows == 0 { - log.Info("account Update: not found") - return ErrNotFound - } - - return nil -} - -func (db AccountSqlite) GetAll(userId uuid.UUID) ([]*types.Account, error) { - - accounts := make([]*types.Account, 0) - err := db.db.Select(&accounts, ` - SELECT - id, user_id, name, - current_balance, last_transaction, oink_balance, - created_at, created_by, updated_at, updated_by - FROM account - WHERE user_id = ? - ORDER BY name`, userId) - if err != nil { - log.Error("account GetAll: %v", err) - return nil, types.ErrInternal - } - - return accounts, nil -} - -func (db AccountSqlite) Get(userId uuid.UUID, id uuid.UUID) (*types.Account, error) { - - account := &types.Account{} - err := db.db.Get(account, ` - SELECT - id, user_id, name, - current_balance, last_transaction, oink_balance, - created_at, created_by, updated_at, updated_by - FROM account - WHERE user_id = ? - AND id = ?`, userId, id) - if err != nil { - if err == sql.ErrNoRows { - return nil, ErrNotFound - } - log.Error("account Get: %v", err) - return nil, types.ErrInternal - } - - return account, nil -} - -func (db AccountSqlite) Delete(userId uuid.UUID, id uuid.UUID) error { - - res, err := db.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", id, userId) - if err != nil { - log.Error("account Delete: %v", err) - return types.ErrInternal - } - - rows, err := res.RowsAffected() - if err != nil { - log.Error("account Delete: %v", err) - return types.ErrInternal - } - - if rows == 0 { - log.Info("account Delete: not found") - return ErrNotFound - } - - return nil -} diff --git a/handler/account.go b/handler/account.go index c7294ed..5b2216e 100644 --- a/handler/account.go +++ b/handler/account.go @@ -107,7 +107,7 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc { return } } else { - account, err = h.s.Update(user, id, name) + account, err = h.s.UpdateName(user, id, name) if err != nil { handleError(w, r, err) return diff --git a/handler/error.go b/handler/error.go index 6d8e043..959291a 100644 --- a/handler/error.go +++ b/handler/error.go @@ -3,6 +3,7 @@ package handler import ( "errors" "net/http" + "spend-sparrow/db" "spend-sparrow/service" "spend-sparrow/utils" "strings" @@ -15,6 +16,9 @@ func handleError(w http.ResponseWriter, r *http.Request, err error) { } else if errors.Is(err, service.ErrBadRequest) { utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest) return + } else if errors.Is(err, db.ErrNotFound) { + utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound) + return } utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) diff --git a/main.go b/main.go index 69f4c2a..5e37929 100644 --- a/main.go +++ b/main.go @@ -107,7 +107,6 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler { var router = http.NewServeMux() authDb := db.NewAuthSqlite(d) - accountDb := db.NewAccountSqlite(d) treasureChestDb := db.NewTreasureChestSqlite(d) randomService := service.NewRandom() @@ -115,7 +114,7 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler { mailService := service.NewMail(serverSettings) authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings) - accountService := service.NewAccount(accountDb, randomService, clockService, serverSettings) + accountService := service.NewAccount(d, randomService, clockService, serverSettings) treasureChestService := service.NewTreasureChest(treasureChestDb, randomService, clockService, serverSettings) transactionService := service.NewTransaction(d, randomService, clockService, serverSettings) diff --git a/main_test.go b/main_test.go index 83919c8..77370f2 100644 --- a/main_test.go +++ b/main_test.go @@ -1754,7 +1754,7 @@ func TestIntegrationAccount(t *testing.T) { req.Header.Set("Cookie", "id="+sessionId) resp, err = httpClient.Do(req) assert.Nil(t, err) - assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) assert.NotContains(t, readBody(t, resp.Body), expectedNewName) }) t.Run(`should not be able to see other users content`, func(t *testing.T) { diff --git a/service/account.go b/service/account.go index 1b1d53f..068d8f6 100644 --- a/service/account.go +++ b/service/account.go @@ -9,6 +9,7 @@ import ( "spend-sparrow/types" "github.com/google/uuid" + "github.com/jmoiron/sqlx" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) @@ -27,20 +28,20 @@ var ( type Account interface { Add(user *types.User, name string) (*types.Account, error) - Update(user *types.User, id string, name string) (*types.Account, error) + UpdateName(user *types.User, id string, name string) (*types.Account, error) Get(user *types.User, id string) (*types.Account, error) GetAll(user *types.User) ([]*types.Account, error) Delete(user *types.User, id string) error } type AccountImpl struct { - db db.Account + db *sqlx.DB clock Clock random Random settings *types.Settings } -func NewAccount(db db.Account, random Random, clock Clock, settings *types.Settings) Account { +func NewAccount(db *sqlx.DB, random Random, clock Clock, settings *types.Settings) Account { return AccountImpl{ db: db, clock: clock, @@ -82,20 +83,18 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) UpdatedBy: nil, } - err = s.db.Insert(user.Id, account) + r, err := s.db.NamedExec(` + INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by) + VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account) + err = db.TransformAndLogDbError("account Insert", r, err) if err != nil { - return nil, types.ErrInternal + return nil, err } - savedAccount, err := s.db.Get(user.Id, newId) - if err != nil { - log.Error("account %v not found after insert: %v", newId, err) - return nil, types.ErrInternal - } - return savedAccount, nil + return account, nil } -func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Account, error) { +func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) { accountMetric.WithLabelValues("update").Inc() if user == nil { return nil, ErrUnauthorized @@ -110,7 +109,9 @@ func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Ac return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - account, err := s.db.Get(user.Id, uuid) + var account types.Account + err = s.db.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = db.TransformAndLogDbError("account Update", nil, err) if err != nil { if err == db.ErrNotFound { return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest) @@ -123,12 +124,20 @@ func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Ac account.UpdatedAt = ×tamp account.UpdatedBy = &user.Id - err = s.db.Update(user.Id, account) + r, err := s.db.NamedExec(` + UPDATE account + SET + name = :name, + updated_at = :updated_at, + updated_by = :updated_by + WHERE id = :id + AND user_id = :user_id`, account) + err = db.TransformAndLogDbError("account Update", r, err) if err != nil { - return nil, types.ErrInternal + return nil, err } - return account, nil + return &account, nil } func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { @@ -143,12 +152,13 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - account, err := s.db.Get(user.Id, uuid) + account := &types.Account{} + err = s.db.Get(account, ` + SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = db.TransformAndLogDbError("account Get", nil, err) if err != nil { - if err == db.ErrNotFound { - return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest) - } - return nil, types.ErrInternal + log.Error("account get: %v", err) + return nil, err } return account, nil @@ -160,9 +170,12 @@ func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { return nil, ErrUnauthorized } - accounts, err := s.db.GetAll(user.Id) + accounts := make([]*types.Account, 0) + err := s.db.Select(&accounts, ` + SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id) + err = db.TransformAndLogDbError("account GetAll", nil, err) if err != nil { - return nil, types.ErrInternal + return nil, err } return accounts, nil @@ -179,22 +192,10 @@ func (s AccountImpl) Delete(user *types.User, id string) error { return fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - account, err := s.db.Get(user.Id, uuid) + res, err := s.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id) + err = db.TransformAndLogDbError("account Delete", res, err) if err != nil { - if err == db.ErrNotFound { - return fmt.Errorf("account %v not found: %w", id, ErrBadRequest) - } - return types.ErrInternal + return err } - - if account.UserId != user.Id { - return types.ErrUnauthorized - } - - err = s.db.Delete(user.Id, account.Id) - if err != nil { - return types.ErrInternal - } - return nil }