feat(account): #66 remove db interface for accounts #73
134
db/account.go
134
db/account.go
@@ -1,134 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"spend-sparrow/log"
|
||||
"spend-sparrow/types"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// While it may be duplicated to check for userId in the database access, it serves as a security layer
|
||||
type Account interface {
|
||||
Insert(userId uuid.UUID, account *types.Account) error
|
||||
Update(userId uuid.UUID, account *types.Account) error
|
||||
GetAll(userId uuid.UUID) ([]*types.Account, error)
|
||||
Get(userId uuid.UUID, id uuid.UUID) (*types.Account, error)
|
||||
Delete(userId uuid.UUID, id uuid.UUID) error
|
||||
}
|
||||
|
||||
type AccountSqlite struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func NewAccountSqlite(db *sqlx.DB) *AccountSqlite {
|
||||
return &AccountSqlite{db: db}
|
||||
}
|
||||
|
||||
func (db AccountSqlite) Insert(userId uuid.UUID, account *types.Account) error {
|
||||
|
||||
_, err := db.db.Exec(`
|
||||
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
|
||||
VALUES (?,?,?,?,?,?,?)`, account.Id, userId, account.Name, 0, 0, account.CreatedAt, account.CreatedBy)
|
||||
if err != nil {
|
||||
log.Error("account Insert: %v", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AccountSqlite) Update(userId uuid.UUID, account *types.Account) error {
|
||||
|
||||
r, err := db.db.Exec(`
|
||||
UPDATE account
|
||||
SET
|
||||
name = ?,
|
||||
current_balance = ?,
|
||||
last_transaction = ?,
|
||||
oink_balance = ?,
|
||||
updated_at = ?,
|
||||
updated_by = ?
|
||||
WHERE id = ?
|
||||
AND user_id = ?`, account.Name, account.CurrentBalance, account.LastTransaction, account.OinkBalance, account.UpdatedAt, account.UpdatedBy, account.Id, userId)
|
||||
if err != nil {
|
||||
log.Error("account Update: %v", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
rows, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
log.Error("account Update: %v", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
log.Info("account Update: not found")
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AccountSqlite) GetAll(userId uuid.UUID) ([]*types.Account, error) {
|
||||
|
||||
accounts := make([]*types.Account, 0)
|
||||
err := db.db.Select(&accounts, `
|
||||
SELECT
|
||||
id, user_id, name,
|
||||
current_balance, last_transaction, oink_balance,
|
||||
created_at, created_by, updated_at, updated_by
|
||||
FROM account
|
||||
WHERE user_id = ?
|
||||
ORDER BY name`, userId)
|
||||
if err != nil {
|
||||
log.Error("account GetAll: %v", err)
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func (db AccountSqlite) Get(userId uuid.UUID, id uuid.UUID) (*types.Account, error) {
|
||||
|
||||
account := &types.Account{}
|
||||
err := db.db.Get(account, `
|
||||
SELECT
|
||||
id, user_id, name,
|
||||
current_balance, last_transaction, oink_balance,
|
||||
created_at, created_by, updated_at, updated_by
|
||||
FROM account
|
||||
WHERE user_id = ?
|
||||
AND id = ?`, userId, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
log.Error("account Get: %v", err)
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (db AccountSqlite) Delete(userId uuid.UUID, id uuid.UUID) error {
|
||||
|
||||
res, err := db.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", id, userId)
|
||||
if err != nil {
|
||||
log.Error("account Delete: %v", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
log.Error("account Delete: %v", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
log.Info("account Delete: not found")
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -107,7 +107,7 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
account, err = h.s.Update(user, id, name)
|
||||
account, err = h.s.UpdateName(user, id, name)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"spend-sparrow/db"
|
||||
"spend-sparrow/service"
|
||||
"spend-sparrow/utils"
|
||||
"strings"
|
||||
@@ -15,6 +16,9 @@ func handleError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
} else if errors.Is(err, service.ErrBadRequest) {
|
||||
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
|
||||
return
|
||||
} else if errors.Is(err, db.ErrNotFound) {
|
||||
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
||||
|
||||
3
main.go
3
main.go
@@ -107,7 +107,6 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||
var router = http.NewServeMux()
|
||||
|
||||
authDb := db.NewAuthSqlite(d)
|
||||
accountDb := db.NewAccountSqlite(d)
|
||||
treasureChestDb := db.NewTreasureChestSqlite(d)
|
||||
|
||||
randomService := service.NewRandom()
|
||||
@@ -115,7 +114,7 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||
mailService := service.NewMail(serverSettings)
|
||||
|
||||
authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings)
|
||||
accountService := service.NewAccount(accountDb, randomService, clockService, serverSettings)
|
||||
accountService := service.NewAccount(d, randomService, clockService, serverSettings)
|
||||
treasureChestService := service.NewTreasureChest(treasureChestDb, randomService, clockService, serverSettings)
|
||||
transactionService := service.NewTransaction(d, randomService, clockService, serverSettings)
|
||||
|
||||
|
||||
@@ -1754,7 +1754,7 @@ func TestIntegrationAccount(t *testing.T) {
|
||||
req.Header.Set("Cookie", "id="+sessionId)
|
||||
resp, err = httpClient.Do(req)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
assert.NotContains(t, readBody(t, resp.Body), expectedNewName)
|
||||
})
|
||||
t.Run(`should not be able to see other users content`, func(t *testing.T) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"spend-sparrow/types"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
@@ -27,20 +28,20 @@ var (
|
||||
|
||||
type Account interface {
|
||||
Add(user *types.User, name string) (*types.Account, error)
|
||||
Update(user *types.User, id string, name string) (*types.Account, error)
|
||||
UpdateName(user *types.User, id string, name string) (*types.Account, error)
|
||||
Get(user *types.User, id string) (*types.Account, error)
|
||||
GetAll(user *types.User) ([]*types.Account, error)
|
||||
Delete(user *types.User, id string) error
|
||||
}
|
||||
|
||||
type AccountImpl struct {
|
||||
db db.Account
|
||||
db *sqlx.DB
|
||||
clock Clock
|
||||
random Random
|
||||
settings *types.Settings
|
||||
}
|
||||
|
||||
func NewAccount(db db.Account, random Random, clock Clock, settings *types.Settings) Account {
|
||||
func NewAccount(db *sqlx.DB, random Random, clock Clock, settings *types.Settings) Account {
|
||||
return AccountImpl{
|
||||
db: db,
|
||||
clock: clock,
|
||||
@@ -82,20 +83,18 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
|
||||
UpdatedBy: nil,
|
||||
}
|
||||
|
||||
err = s.db.Insert(user.Id, account)
|
||||
r, err := s.db.NamedExec(`
|
||||
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
|
||||
VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account)
|
||||
err = db.TransformAndLogDbError("account Insert", r, err)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedAccount, err := s.db.Get(user.Id, newId)
|
||||
if err != nil {
|
||||
log.Error("account %v not found after insert: %v", newId, err)
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
return savedAccount, nil
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Account, error) {
|
||||
func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) {
|
||||
accountMetric.WithLabelValues("update").Inc()
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
@@ -110,7 +109,9 @@ func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Ac
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
account, err := s.db.Get(user.Id, uuid)
|
||||
var account types.Account
|
||||
err = s.db.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("account Update", nil, err)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
|
||||
@@ -123,12 +124,20 @@ func (s AccountImpl) Update(user *types.User, id string, name string) (*types.Ac
|
||||
account.UpdatedAt = ×tamp
|
||||
account.UpdatedBy = &user.Id
|
||||
|
||||
err = s.db.Update(user.Id, account)
|
||||
r, err := s.db.NamedExec(`
|
||||
UPDATE account
|
||||
SET
|
||||
name = :name,
|
||||
updated_at = :updated_at,
|
||||
updated_by = :updated_by
|
||||
WHERE id = :id
|
||||
AND user_id = :user_id`, account)
|
||||
err = db.TransformAndLogDbError("account Update", r, err)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
||||
@@ -143,12 +152,13 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
account, err := s.db.Get(user.Id, uuid)
|
||||
account := &types.Account{}
|
||||
err = s.db.Get(account, `
|
||||
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("account Get", nil, err)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
log.Error("account get: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
@@ -160,9 +170,12 @@ func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
accounts, err := s.db.GetAll(user.Id)
|
||||
accounts := make([]*types.Account, 0)
|
||||
err := s.db.Select(&accounts, `
|
||||
SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id)
|
||||
err = db.TransformAndLogDbError("account GetAll", nil, err)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
@@ -179,22 +192,10 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
|
||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
account, err := s.db.Get(user.Id, uuid)
|
||||
res, err := s.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
|
||||
err = db.TransformAndLogDbError("account Delete", res, err)
|
||||
if err != nil {
|
||||
if err == db.ErrNotFound {
|
||||
return fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
|
||||
return err
|
||||
}
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
if account.UserId != user.Id {
|
||||
return types.ErrUnauthorized
|
||||
}
|
||||
|
||||
err = s.db.Delete(user.Id, account.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user