feat(account): #66 remove db interface for accounts
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 4m56s
Build and Push Docker Image / Build-And-Push-Docker-Image (push) Successful in 5m33s

This commit was merged in pull request #73.
This commit is contained in:
2025-05-16 11:37:46 +02:00
parent dbf272e3f3
commit 7e244ccc07
6 changed files with 46 additions and 176 deletions

View File

@@ -1,134 +0,0 @@
package db
import (
"database/sql"
"spend-sparrow/log"
"spend-sparrow/types"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
)
// While it may be duplicated to check for userId in the database access, it serves as a security layer
type Account interface {
Insert(userId uuid.UUID, account *types.Account) error
Update(userId uuid.UUID, account *types.Account) error
GetAll(userId uuid.UUID) ([]*types.Account, error)
Get(userId uuid.UUID, id uuid.UUID) (*types.Account, error)
Delete(userId uuid.UUID, id uuid.UUID) error
}
type AccountSqlite struct {
db *sqlx.DB
}
func NewAccountSqlite(db *sqlx.DB) *AccountSqlite {
return &AccountSqlite{db: db}
}
func (db AccountSqlite) Insert(userId uuid.UUID, account *types.Account) error {
_, err := db.db.Exec(`
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
VALUES (?,?,?,?,?,?,?)`, account.Id, userId, account.Name, 0, 0, account.CreatedAt, account.CreatedBy)
if err != nil {
log.Error("account Insert: %v", err)
return types.ErrInternal
}
return nil
}
func (db AccountSqlite) Update(userId uuid.UUID, account *types.Account) error {
r, err := db.db.Exec(`
UPDATE account
SET
name = ?,
current_balance = ?,
last_transaction = ?,
oink_balance = ?,
updated_at = ?,
updated_by = ?
WHERE id = ?
AND user_id = ?`, account.Name, account.CurrentBalance, account.LastTransaction, account.OinkBalance, account.UpdatedAt, account.UpdatedBy, account.Id, userId)
if err != nil {
log.Error("account Update: %v", err)
return types.ErrInternal
}
rows, err := r.RowsAffected()
if err != nil {
log.Error("account Update: %v", err)
return types.ErrInternal
}
if rows == 0 {
log.Info("account Update: not found")
return ErrNotFound
}
return nil
}
func (db AccountSqlite) GetAll(userId uuid.UUID) ([]*types.Account, error) {
accounts := make([]*types.Account, 0)
err := db.db.Select(&accounts, `
SELECT
id, user_id, name,
current_balance, last_transaction, oink_balance,
created_at, created_by, updated_at, updated_by
FROM account
WHERE user_id = ?
ORDER BY name`, userId)
if err != nil {
log.Error("account GetAll: %v", err)
return nil, types.ErrInternal
}
return accounts, nil
}
func (db AccountSqlite) Get(userId uuid.UUID, id uuid.UUID) (*types.Account, error) {
account := &types.Account{}
err := db.db.Get(account, `
SELECT
id, user_id, name,
current_balance, last_transaction, oink_balance,
created_at, created_by, updated_at, updated_by
FROM account
WHERE user_id = ?
AND id = ?`, userId, id)
if err != nil {
if err == sql.ErrNoRows {
return nil, ErrNotFound
}
log.Error("account Get: %v", err)
return nil, types.ErrInternal
}
return account, nil
}
func (db AccountSqlite) Delete(userId uuid.UUID, id uuid.UUID) error {
res, err := db.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", id, userId)
if err != nil {
log.Error("account Delete: %v", err)
return types.ErrInternal
}
rows, err := res.RowsAffected()
if err != nil {
log.Error("account Delete: %v", err)
return types.ErrInternal
}
if rows == 0 {
log.Info("account Delete: not found")
return ErrNotFound
}
return nil
}

View File

@@ -107,7 +107,7 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc {
return return
} }
} else { } else {
account, err = h.s.Update(user, id, name) account, err = h.s.UpdateName(user, id, name)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"errors" "errors"
"net/http" "net/http"
"spend-sparrow/db"
"spend-sparrow/service" "spend-sparrow/service"
"spend-sparrow/utils" "spend-sparrow/utils"
"strings" "strings"
@@ -15,6 +16,9 @@ func handleError(w http.ResponseWriter, r *http.Request, err error) {
} else if errors.Is(err, service.ErrBadRequest) { } else if errors.Is(err, service.ErrBadRequest) {
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
return return
} else if errors.Is(err, db.ErrNotFound) {
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound)
return
} }
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)

View File

@@ -107,7 +107,6 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
var router = http.NewServeMux() var router = http.NewServeMux()
authDb := db.NewAuthSqlite(d) authDb := db.NewAuthSqlite(d)
accountDb := db.NewAccountSqlite(d)
treasureChestDb := db.NewTreasureChestSqlite(d) treasureChestDb := db.NewTreasureChestSqlite(d)
randomService := service.NewRandom() randomService := service.NewRandom()
@@ -115,7 +114,7 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
mailService := service.NewMail(serverSettings) mailService := service.NewMail(serverSettings)
authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings) authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings)
accountService := service.NewAccount(accountDb, randomService, clockService, serverSettings) accountService := service.NewAccount(d, randomService, clockService, serverSettings)
treasureChestService := service.NewTreasureChest(treasureChestDb, randomService, clockService, serverSettings) treasureChestService := service.NewTreasureChest(treasureChestDb, randomService, clockService, serverSettings)
transactionService := service.NewTransaction(d, randomService, clockService, serverSettings) transactionService := service.NewTransaction(d, randomService, clockService, serverSettings)

View File

@@ -1754,7 +1754,7 @@ func TestIntegrationAccount(t *testing.T) {
req.Header.Set("Cookie", "id="+sessionId) req.Header.Set("Cookie", "id="+sessionId)
resp, err = httpClient.Do(req) resp, err = httpClient.Do(req)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusNotFound, resp.StatusCode)
assert.NotContains(t, readBody(t, resp.Body), expectedNewName) assert.NotContains(t, readBody(t, resp.Body), expectedNewName)
}) })
t.Run(`should not be able to see other users content`, func(t *testing.T) { t.Run(`should not be able to see other users content`, func(t *testing.T) {

View File

@@ -9,6 +9,7 @@ import (
"spend-sparrow/types" "spend-sparrow/types"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
) )
@@ -27,20 +28,20 @@ var (
type Account interface { type Account interface {
Add(user *types.User, name string) (*types.Account, error) Add(user *types.User, name string) (*types.Account, error)
Update(user *types.User, id string, name string) (*types.Account, error) UpdateName(user *types.User, id string, name string) (*types.Account, error)
Get(user *types.User, id string) (*types.Account, error) Get(user *types.User, id string) (*types.Account, error)
GetAll(user *types.User) ([]*types.Account, error) GetAll(user *types.User) ([]*types.Account, error)
Delete(user *types.User, id string) error Delete(user *types.User, id string) error
} }
type AccountImpl struct { type AccountImpl struct {
db db.Account db *sqlx.DB
clock Clock clock Clock
random Random random Random
settings *types.Settings settings *types.Settings
} }
func NewAccount(db db.Account, random Random, clock Clock, settings *types.Settings) Account { func NewAccount(db *sqlx.DB, random Random, clock Clock, settings *types.Settings) Account {
return AccountImpl{ return AccountImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -82,20 +83,18 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
UpdatedBy: nil, UpdatedBy: nil,
} }
err = s.db.Insert(user.Id, account) r, err := s.db.NamedExec(`
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account)
err = db.TransformAndLogDbError("account Insert", r, err)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, err
} }
savedAccount, err := s.db.Get(user.Id, newId) return account, nil
if err != nil {
log.Error("account %v not found after insert: %v", newId, err)
return nil, types.ErrInternal
}
return savedAccount, nil
} }
func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Account, error) { func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) {
accountMetric.WithLabelValues("update").Inc() accountMetric.WithLabelValues("update").Inc()
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
@@ -110,7 +109,9 @@ func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Ac
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
account, err := s.db.Get(user.Id, uuid) var account types.Account
err = s.db.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if err == db.ErrNotFound {
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest) return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
@@ -123,12 +124,20 @@ func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Ac
account.UpdatedAt = &timestamp account.UpdatedAt = &timestamp
account.UpdatedBy = &user.Id account.UpdatedBy = &user.Id
err = s.db.Update(user.Id, account) r, err := s.db.NamedExec(`
UPDATE account
SET
name = :name,
updated_at = :updated_at,
updated_by = :updated_by
WHERE id = :id
AND user_id = :user_id`, account)
err = db.TransformAndLogDbError("account Update", r, err)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, err
} }
return account, nil return &account, nil
} }
func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
@@ -143,12 +152,13 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
account, err := s.db.Get(user.Id, uuid) account := &types.Account{}
err = s.db.Get(account, `
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Get", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { log.Error("account get: %v", err)
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest) return nil, err
}
return nil, types.ErrInternal
} }
return account, nil return account, nil
@@ -160,9 +170,12 @@ func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
accounts, err := s.db.GetAll(user.Id) accounts := make([]*types.Account, 0)
err := s.db.Select(&accounts, `
SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id)
err = db.TransformAndLogDbError("account GetAll", nil, err)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, err
} }
return accounts, nil return accounts, nil
@@ -179,22 +192,10 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
account, err := s.db.Get(user.Id, uuid) res, err := s.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("account Delete", res, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { return err
return fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
} }
return types.ErrInternal
}
if account.UserId != user.Id {
return types.ErrUnauthorized
}
err = s.db.Delete(user.Id, account.Id)
if err != nil {
return types.ErrInternal
}
return nil return nil
} }