fix: move implementation to "internal" package
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 4m49s
Build and Push Docker Image / Build-And-Push-Docker-Image (push) Successful in 5m8s

This commit was merged in pull request #138.
This commit is contained in:
2025-05-29 13:23:13 +02:00
parent 9bb0cc475d
commit 6219741634
72 changed files with 245 additions and 230 deletions

237
internal/service/account.go Normal file
View File

@@ -0,0 +1,237 @@
package service
import (
"errors"
"fmt"
"spend-sparrow/internal/db"
"spend-sparrow/internal/log"
"spend-sparrow/internal/types"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
accountMetric = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "spendsparrow_account_total",
Help: "The total of account operations",
},
[]string{"operation"},
)
)
type Account interface {
Add(user *types.User, name string) (*types.Account, error)
UpdateName(user *types.User, id string, name string) (*types.Account, error)
Get(user *types.User, id string) (*types.Account, error)
GetAll(user *types.User) ([]*types.Account, error)
Delete(user *types.User, id string) error
}
type AccountImpl struct {
db *sqlx.DB
clock Clock
random Random
}
func NewAccount(db *sqlx.DB, random Random, clock Clock) Account {
return AccountImpl{
db: db,
clock: clock,
random: random,
}
}
func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) {
accountMetric.WithLabelValues("add").Inc()
if user == nil {
return nil, ErrUnauthorized
}
newId, err := s.random.UUID()
if err != nil {
return nil, types.ErrInternal
}
err = validateString(name, "name")
if err != nil {
return nil, err
}
account := &types.Account{
Id: newId,
UserId: user.Id,
Name: name,
CurrentBalance: 0,
LastTransaction: nil,
OinkBalance: 0,
CreatedAt: s.clock.Now(),
CreatedBy: user.Id,
UpdatedAt: nil,
UpdatedBy: nil,
}
r, err := s.db.NamedExec(`
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account)
err = db.TransformAndLogDbError("account Insert", r, err)
if err != nil {
return nil, err
}
return account, nil
}
func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) {
accountMetric.WithLabelValues("update").Inc()
if user == nil {
return nil, ErrUnauthorized
}
err := validateString(name, "name")
if err != nil {
return nil, err
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("account update: %v", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()
var account types.Account
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
}
return nil, types.ErrInternal
}
timestamp := s.clock.Now()
account.Name = name
account.UpdatedAt = &timestamp
account.UpdatedBy = &user.Id
r, err := tx.NamedExec(`
UPDATE account
SET
name = :name,
updated_at = :updated_at,
updated_by = :updated_by
WHERE id = :id
AND user_id = :user_id`, account)
err = db.TransformAndLogDbError("account Update", r, err)
if err != nil {
return nil, err
}
err = tx.Commit()
err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil {
return nil, err
}
return &account, nil
}
func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
accountMetric.WithLabelValues("get").Inc()
if user == nil {
return nil, ErrUnauthorized
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("account get: %v", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
var account types.Account
err = s.db.Get(&account, `
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Get", nil, err)
if err != nil {
log.Error("account get: %v", err)
return nil, err
}
return &account, nil
}
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
accountMetric.WithLabelValues("get_all").Inc()
if user == nil {
return nil, ErrUnauthorized
}
accounts := make([]*types.Account, 0)
err := s.db.Select(&accounts, `
SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id)
err = db.TransformAndLogDbError("account GetAll", nil, err)
if err != nil {
return nil, err
}
return accounts, nil
}
func (s AccountImpl) Delete(user *types.User, id string) error {
accountMetric.WithLabelValues("delete").Inc()
if user == nil {
return ErrUnauthorized
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("account delete: %v", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("account Delete", nil, err)
if err != nil {
return err
}
defer func() {
_ = tx.Rollback()
}()
transactionsCount := 0
err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Delete", nil, err)
if err != nil {
return err
}
if transactionsCount > 0 {
return fmt.Errorf("account has transactions, cannot delete: %w", ErrBadRequest)
}
res, err := tx.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("account Delete", res, err)
if err != nil {
return err
}
err = tx.Commit()
err = db.TransformAndLogDbError("account Delete", nil, err)
if err != nil {
return err
}
return nil
}

510
internal/service/auth.go Normal file
View File

@@ -0,0 +1,510 @@
package service
import (
"context"
"crypto/subtle"
"errors"
"net/mail"
"spend-sparrow/internal/db"
"spend-sparrow/internal/log"
mailTemplate "spend-sparrow/internal/template/mail"
"spend-sparrow/internal/types"
"strings"
"time"
"github.com/google/uuid"
"golang.org/x/crypto/argon2"
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrInvalidPassword = errors.New("the password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character")
ErrInvalidEmail = errors.New("invalid email")
ErrAccountExists = errors.New("account already exists")
ErrSessionIdInvalid = errors.New("session ID is invalid")
ErrTokenInvalid = errors.New("token is invalid")
)
type Auth interface {
SignUp(email string, password string) (*types.User, error)
SendVerificationMail(userId uuid.UUID, email string)
VerifyUserEmail(token string) error
SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error)
SignInSession(sessionId string) (*types.Session, *types.User, error)
SignInAnonymous() (*types.Session, error)
SignOut(sessionId string) error
DeleteAccount(user *types.User, currPass string) error
ChangePassword(user *types.User, sessionId string, currPass, newPass string) error
SendForgotPasswordMail(email string) error
ForgotPassword(token string, newPass string) error
IsCsrfTokenValid(tokenStr string, sessionId string) bool
GetCsrfToken(session *types.Session) (string, error)
}
type AuthImpl struct {
db db.Auth
random Random
clock Clock
mail Mail
serverSettings *types.Settings
}
func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *types.Settings) *AuthImpl {
return &AuthImpl{
db: db,
random: random,
clock: clock,
mail: mail,
serverSettings: serverSettings,
}
}
func (service AuthImpl) SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) {
user, err := service.db.GetUserByEmail(email)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, nil, ErrInvalidCredentials
} else {
return nil, nil, types.ErrInternal
}
}
hash := GetHashPassword(password, user.Salt)
if subtle.ConstantTimeCompare(hash, user.Password) == 0 {
return nil, nil, ErrInvalidCredentials
}
err = service.cleanUpSessionWithTokens(session)
if err != nil {
return nil, nil, types.ErrInternal
}
session, err = service.createSession(user.Id)
if err != nil {
return nil, nil, types.ErrInternal
}
return session, user, nil
}
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
if sessionId == "" {
return nil, nil, ErrSessionIdInvalid
}
session, err := service.db.GetSession(sessionId)
if err != nil {
return nil, nil, types.ErrInternal
}
if session.ExpiresAt.Before(service.clock.Now()) {
_ = service.db.DeleteSession(sessionId)
return nil, nil, nil
}
if session.UserId == uuid.Nil {
return session, nil, nil
}
user, err := service.db.GetUser(session.UserId)
if err != nil {
return nil, nil, types.ErrInternal
}
return session, user, nil
}
func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
session, err := service.createSession(uuid.Nil)
if err != nil {
return nil, types.ErrInternal
}
log.Info("Anonymous session created: %v", session.Id)
return session, nil
}
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
_, err := mail.ParseAddress(email)
if err != nil {
return nil, ErrInvalidEmail
}
if !isPasswordValid(password) {
return nil, ErrInvalidPassword
}
userId, err := service.random.UUID()
if err != nil {
return nil, types.ErrInternal
}
salt, err := service.random.Bytes(16)
if err != nil {
return nil, types.ErrInternal
}
hash := GetHashPassword(password, salt)
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
err = service.db.InsertUser(user)
if err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
return nil, ErrAccountExists
} else {
return nil, types.ErrInternal
}
}
return user, nil
}
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return
}
var token *types.Token
if len(tokens) > 0 {
token = tokens[0]
}
if token == nil {
newTokenStr, err := service.random.String(32)
if err != nil {
return
}
token = types.NewToken(
userId,
"",
newTokenStr,
types.TokenTypeEmailVerify,
service.clock.Now(),
service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token)
if err != nil {
return
}
}
var w strings.Builder
err = mailTemplate.Register(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &w)
if err != nil {
log.Error("Could not render welcome email: %v", err)
return
}
service.mail.SendMail(email, "Welcome to spend-sparrow", w.String())
}
func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
if tokenStr == "" {
return types.ErrInternal
}
token, err := service.db.GetToken(tokenStr)
if err != nil {
return types.ErrInternal
}
user, err := service.db.GetUser(token.UserId)
if err != nil {
return types.ErrInternal
}
if token.Type != types.TokenTypeEmailVerify {
return types.ErrInternal
}
now := service.clock.Now()
if token.ExpiresAt.Before(now) {
return types.ErrInternal
}
user.EmailVerified = true
user.EmailVerifiedAt = &now
err = service.db.UpdateUser(user)
if err != nil {
return types.ErrInternal
}
_ = service.db.DeleteToken(token.Token)
return nil
}
func (service AuthImpl) SignOut(sessionId string) error {
return service.db.DeleteSession(sessionId)
}
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
userDb, err := service.db.GetUser(user.Id)
if err != nil {
return types.ErrInternal
}
currHash := GetHashPassword(currPass, userDb.Salt)
if subtle.ConstantTimeCompare(currHash, userDb.Password) == 0 {
return ErrInvalidCredentials
}
err = service.db.DeleteUser(user.Id)
if err != nil {
return err
}
service.mail.SendMail(user.Email, "Account deleted", "Your account has been deleted")
return nil
}
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
if !isPasswordValid(newPass) {
return ErrInvalidPassword
}
if currPass == newPass {
return ErrInvalidPassword
}
currHash := GetHashPassword(currPass, user.Salt)
if subtle.ConstantTimeCompare(currHash, user.Password) == 0 {
return ErrInvalidCredentials
}
newHash := GetHashPassword(newPass, user.Salt)
user.Password = newHash
err := service.db.UpdateUser(user)
if err != nil {
return err
}
sessions, err := service.db.GetSessions(user.Id)
if err != nil {
return types.ErrInternal
}
for _, s := range sessions {
if s.Id != sessionId {
err = service.db.DeleteSession(s.Id)
if err != nil {
return types.ErrInternal
}
}
}
return nil
}
func (service AuthImpl) SendForgotPasswordMail(email string) error {
tokenStr, err := service.random.String(32)
if err != nil {
return err
}
user, err := service.db.GetUserByEmail(email)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil
} else {
return types.ErrInternal
}
}
token := types.NewToken(
user.Id,
"",
tokenStr,
types.TokenTypePasswordReset,
service.clock.Now(),
service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(token)
if err != nil {
return types.ErrInternal
}
var mail strings.Builder
err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail)
if err != nil {
log.Error("Could not render reset password email: %v", err)
return types.ErrInternal
}
service.mail.SendMail(email, "Reset Password", mail.String())
return nil
}
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
if !isPasswordValid(newPass) {
return ErrInvalidPassword
}
token, err := service.db.GetToken(tokenStr)
if err != nil {
return ErrTokenInvalid
}
err = service.db.DeleteToken(tokenStr)
if err != nil {
return err
}
if token.Type != types.TokenTypePasswordReset ||
token.ExpiresAt.Before(service.clock.Now()) {
return ErrTokenInvalid
}
user, err := service.db.GetUser(token.UserId)
if err != nil {
log.Error("Could not get user from token: %v", err)
return types.ErrInternal
}
passHash := GetHashPassword(newPass, user.Salt)
user.Password = passHash
err = service.db.UpdateUser(user)
if err != nil {
return err
}
sessions, err := service.db.GetSessions(user.Id)
if err != nil {
return types.ErrInternal
}
for _, session := range sessions {
err = service.db.DeleteSession(session.Id)
if err != nil {
return types.ErrInternal
}
}
return nil
}
func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool {
token, err := service.db.GetToken(tokenStr)
if err != nil {
return false
}
if token.Type != types.TokenTypeCsrf ||
token.SessionId != sessionId ||
token.ExpiresAt.Before(service.clock.Now()) {
return false
}
return true
}
func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
if session == nil {
return "", types.ErrInternal
}
tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
if len(tokens) > 0 {
return tokens[0].Token, nil
}
tokenStr, err := service.random.String(32)
if err != nil {
return "", types.ErrInternal
}
token := types.NewToken(
session.UserId,
session.Id,
tokenStr,
types.TokenTypeCsrf,
service.clock.Now(),
service.clock.Now().Add(8*time.Hour))
err = service.db.InsertToken(token)
if err != nil {
return "", types.ErrInternal
}
log.Info("CSRF-Token created: %v", tokenStr)
return tokenStr, nil
}
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
if session == nil {
return nil
}
err := service.db.DeleteSession(session.Id)
if err != nil {
return types.ErrInternal
}
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
if err != nil {
return types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(token.Token)
if err != nil {
return types.ErrInternal
}
}
return nil
}
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
sessionId, err := service.random.String(32)
if err != nil {
return nil, types.ErrInternal
}
err = service.db.DeleteOldSessions(userId)
if err != nil {
return nil, types.ErrInternal
}
createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour)
session := types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session)
if err != nil {
return nil, types.ErrInternal
}
return session, nil
}
func GetHashPassword(password string, salt []byte) []byte {
return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16)
}
func isPasswordValid(password string) bool {
if len(password) < 8 ||
!strings.ContainsAny(password, "0123456789") ||
!strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") ||
!strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") ||
!strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") {
return false
} else {
return true
}
}

17
internal/service/clock.go Normal file
View File

@@ -0,0 +1,17 @@
package service
import "time"
type Clock interface {
Now() time.Time
}
type ClockImpl struct{}
func NewClock() Clock {
return &ClockImpl{}
}
func (c *ClockImpl) Now() time.Time {
return time.Now()
}

View File

@@ -0,0 +1,25 @@
package service
import (
"fmt"
"regexp"
)
const (
DECIMALS_MULTIPLIER = 100
)
var (
safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9ÄÖÜäöüß,&'" -]+$`)
)
func validateString(value string, fieldName string) error {
switch {
case value == "":
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest)
case !safeInputRegex.MatchString(value):
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest)
default:
return nil
}
}

View File

@@ -0,0 +1,8 @@
package service
import "errors"
var (
ErrBadRequest = errors.New("bad request")
ErrUnauthorized = errors.New("unauthorized")
)

55
internal/service/mail.go Normal file
View File

@@ -0,0 +1,55 @@
package service
import (
"fmt"
"net/smtp"
"spend-sparrow/internal/log"
"spend-sparrow/internal/types"
)
type Mail interface {
// Sending an email is a fire and forget operation. Thus no error handling
SendMail(to string, subject string, message string)
}
type MailImpl struct {
server *types.Settings
}
func NewMail(server *types.Settings) MailImpl {
return MailImpl{server: server}
}
func (m MailImpl) SendMail(to string, subject string, message string) {
go m.internalSendMail(to, subject, message)
}
func (m MailImpl) internalSendMail(to string, subject string, message string) {
if m.server.Smtp == nil {
return
}
s := m.server.Smtp
auth := smtp.PlainAuth("", s.User, s.Pass, s.Host)
msg := fmt.Sprintf(
`From: %v <%v>
To: %v
Subject: %v
MIME-version: 1.0;
Content-Type: text/html; charset="UTF-8";
%v`,
s.FromName,
s.FromMail,
to,
subject,
message)
log.Info("Sending mail to %v", to)
err := smtp.SendMail(s.Host+":"+s.Port, auth, s.FromMail, []string{to}, []byte(msg))
if err != nil {
log.Error("Error sending mail: %v", err)
}
}

View File

@@ -0,0 +1,54 @@
package service
import (
"crypto/rand"
"encoding/base64"
"spend-sparrow/internal/log"
"spend-sparrow/internal/types"
"github.com/google/uuid"
)
type Random interface {
Bytes(size int) ([]byte, error)
String(size int) (string, error)
UUID() (uuid.UUID, error)
}
type RandomImpl struct {
}
func NewRandom() *RandomImpl {
return &RandomImpl{}
}
func (r *RandomImpl) Bytes(size int) ([]byte, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
log.Error("Error generating random bytes: %v", err)
return []byte{}, types.ErrInternal
}
return b, nil
}
func (r *RandomImpl) String(size int) (string, error) {
bytes, err := r.Bytes(size)
if err != nil {
log.Error("Error generating random string: %v", err)
return "", types.ErrInternal
}
return base64.StdEncoding.EncodeToString(bytes), nil
}
func (r *RandomImpl) UUID() (uuid.UUID, error) {
id, err := uuid.NewRandom()
if err != nil {
log.Error("Error generating random UUID: %v", err)
return uuid.Nil, types.ErrInternal
}
return id, nil
}

View File

@@ -0,0 +1,553 @@
package service
import (
"errors"
"fmt"
"spend-sparrow/internal/db"
"spend-sparrow/internal/log"
"spend-sparrow/internal/types"
"time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
transactionMetric = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "spendsparrow_transaction_total",
Help: "The total of transaction operations",
},
[]string{"operation"},
)
)
type Transaction interface {
Add(tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
Update(user *types.User, transaction types.Transaction) (*types.Transaction, error)
Get(user *types.User, id string) (*types.Transaction, error)
GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
Delete(user *types.User, id string) error
RecalculateBalances(user *types.User) error
}
type TransactionImpl struct {
db *sqlx.DB
clock Clock
random Random
}
func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
return TransactionImpl{
db: db,
clock: clock,
random: random,
}
}
func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
transactionMetric.WithLabelValues("add").Inc()
if user == nil {
return nil, ErrUnauthorized
}
var err error
if tx == nil {
tx, err = s.db.Beginx()
err = db.TransformAndLogDbError("transaction Add", nil, err)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()
}
transaction, err := s.validateAndEnrichTransaction(tx, nil, user.Id, transactionInput)
if err != nil {
return nil, err
}
r, err := tx.NamedExec(`
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
party, description, error, created_at, created_by)
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
:party, :description, :error, :created_at, :created_by)`, transaction)
err = db.TransformAndLogDbError("transaction Insert", r, err)
if err != nil {
return nil, err
}
if transaction.Error == nil && transaction.AccountId != nil {
r, err = tx.Exec(`
UPDATE account
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction Add", r, err)
if err != nil {
return nil, err
}
}
if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err = tx.Exec(`
UPDATE treasure_chest
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError("transaction Add", r, err)
if err != nil {
return nil, err
}
}
err = tx.Commit()
err = db.TransformAndLogDbError("transaction Add", nil, err)
if err != nil {
return nil, err
}
return transaction, nil
}
func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*types.Transaction, error) {
transactionMetric.WithLabelValues("update").Inc()
if user == nil {
return nil, ErrUnauthorized
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()
transaction := &types.Transaction{}
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest)
}
return nil, types.ErrInternal
}
if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(`
UPDATE account
SET current_balance = current_balance - ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction Update", r, err)
if err != nil {
return nil, err
}
}
if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(`
UPDATE treasure_chest
SET current_balance = current_balance - ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError("transaction Update", r, err)
if err != nil {
return nil, err
}
}
transaction, err = s.validateAndEnrichTransaction(tx, transaction, user.Id, input)
if err != nil {
return nil, err
}
if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(`
UPDATE account
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction Update", r, err)
if err != nil {
return nil, err
}
}
if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(`
UPDATE treasure_chest
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError("transaction Update", r, err)
if err != nil {
return nil, err
}
}
r, err := tx.NamedExec(`
UPDATE "transaction"
SET
account_id = :account_id,
treasure_chest_id = :treasure_chest_id,
value = :value,
timestamp = :timestamp,
party = :party,
description = :description,
error = :error,
updated_at = :updated_at,
updated_by = :updated_by
WHERE id = :id
AND user_id = :user_id`, transaction)
err = db.TransformAndLogDbError("transaction Update", r, err)
if err != nil {
return nil, err
}
err = tx.Commit()
err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil {
return nil, err
}
return transaction, nil
}
func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, error) {
transactionMetric.WithLabelValues("get").Inc()
if user == nil {
return nil, ErrUnauthorized
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("transaction get: %v", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
var transaction types.Transaction
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Get", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest)
}
return nil, types.ErrInternal
}
return &transaction, nil
}
func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
transactionMetric.WithLabelValues("get_all").Inc()
if user == nil {
return nil, ErrUnauthorized
}
transactions := make([]*types.Transaction, 0)
err := s.db.Select(&transactions, `
SELECT *
FROM "transaction"
WHERE user_id = ?
AND (? = '' OR account_id = ?)
AND (? = '' OR treasure_chest_id = ?)
AND (? = ''
OR (? = "true" AND error IS NOT NULL)
OR (? = "false" AND error IS NULL)
)
ORDER BY timestamp DESC`,
user.Id,
filter.AccountId, filter.AccountId,
filter.TreasureChestId, filter.TreasureChestId,
filter.Error, filter.Error, filter.Error)
err = db.TransformAndLogDbError("transaction GetAll", nil, err)
if err != nil {
return nil, err
}
return transactions, nil
}
func (s TransactionImpl) Delete(user *types.User, id string) error {
transactionMetric.WithLabelValues("delete").Inc()
if user == nil {
return ErrUnauthorized
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("transaction delete: %v", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transaction Delete", nil, err)
if err != nil {
return nil
}
defer func() {
_ = tx.Rollback()
}()
var transaction types.Transaction
err = tx.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Delete", nil, err)
if err != nil {
return err
}
if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(`
UPDATE account
SET current_balance = current_balance - ?
WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
}
if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(`
UPDATE treasure_chest
SET current_balance = current_balance - ?
WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
}
r, err := tx.Exec("DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil {
return err
}
err = tx.Commit()
err = db.TransformAndLogDbError("transaction Delete", nil, err)
if err != nil {
return err
}
return nil
}
func (s TransactionImpl) RecalculateBalances(user *types.User) error {
transactionMetric.WithLabelValues("recalculate").Inc()
if user == nil {
return ErrUnauthorized
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil {
return err
}
defer func() {
_ = tx.Rollback()
}()
r, err := tx.Exec(`
UPDATE account
SET current_balance = 0
WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
r, err = tx.Exec(`
UPDATE treasure_chest
SET current_balance = 0
WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
rows, err := tx.Queryx(`
SELECT *
FROM "transaction"
WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
defer func() {
err := rows.Close()
if err != nil {
log.Error("transaction RecalculateBalances: %v", err)
}
}()
var transaction types.Transaction
for rows.Next() {
err = rows.StructScan(&transaction)
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil {
return err
}
s.updateErrors(&transaction)
r, err = tx.Exec(`
UPDATE "transaction"
SET error = ?
WHERE user_id = ?
AND id = ?`, transaction.Error, user.Id, transaction.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil {
return err
}
if transaction.Error != nil {
continue
}
if transaction.AccountId != nil {
r, err = tx.Exec(`
UPDATE account
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil {
return err
}
}
if transaction.TreasureChestId != nil {
r, err = tx.Exec(`
UPDATE treasure_chest
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil {
return err
}
}
}
err = tx.Commit()
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil {
return err
}
return nil
}
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) {
var (
id uuid.UUID
createdAt time.Time
createdBy uuid.UUID
updatedAt *time.Time
updatedBy uuid.UUID
err error
rowCount int
)
if oldTransaction == nil {
id, err = s.random.UUID()
if err != nil {
return nil, types.ErrInternal
}
createdAt = s.clock.Now()
createdBy = userId
} else {
id = oldTransaction.Id
createdAt = oldTransaction.CreatedAt
createdBy = oldTransaction.CreatedBy
time := s.clock.Now()
updatedAt = &time
updatedBy = userId
}
if input.AccountId != nil {
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId)
err = db.TransformAndLogDbError("transaction validate", nil, err)
if err != nil {
return nil, err
}
if rowCount == 0 {
log.Error("transaction validate: %v", err)
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
}
}
if input.TreasureChestId != nil {
var treasureChest types.TreasureChest
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
err = db.TransformAndLogDbError("transaction validate", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
}
return nil, err
}
if treasureChest.ParentId == nil {
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest)
}
}
if input.Party != "" {
err = validateString(input.Party, "party")
if err != nil {
return nil, err
}
}
if input.Description != "" {
err = validateString(input.Description, "description")
if err != nil {
return nil, err
}
}
transaction := types.Transaction{
Id: id,
UserId: userId,
AccountId: input.AccountId,
TreasureChestId: input.TreasureChestId,
Value: input.Value,
Timestamp: input.Timestamp,
Party: input.Party,
Description: input.Description,
Error: nil,
CreatedAt: createdAt,
CreatedBy: createdBy,
UpdatedAt: updatedAt,
UpdatedBy: &updatedBy,
}
s.updateErrors(&transaction)
return &transaction, nil
}
func (s TransactionImpl) updateErrors(transaction *types.Transaction) {
errorStr := ""
switch {
case transaction.Value < 0:
if transaction.TreasureChestId == nil {
errorStr = "no treasure chest specified"
}
case transaction.Value > 0:
if transaction.AccountId == nil && transaction.TreasureChestId == nil {
errorStr = "either an account or a treasure chest needs to be specified"
} else if transaction.AccountId != nil && transaction.TreasureChestId != nil {
errorStr = "positive amounts can only be applied to either an account or a treasure chest"
}
default:
errorStr = "\"value\" needs to be specified"
}
if errorStr == "" {
transaction.Error = nil
} else {
transaction.Error = &errorStr
}
}

View File

@@ -0,0 +1,537 @@
package service
import (
"errors"
"fmt"
"spend-sparrow/internal/db"
"spend-sparrow/internal/log"
"spend-sparrow/internal/types"
"strconv"
"time"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
transactionRecurringMetric = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "spendsparrow_transaction_recurring_total",
Help: "The total of transactionRecurring operations",
},
[]string{"operation"},
)
)
type TransactionRecurring interface {
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
GetAll(user *types.User) ([]*types.TransactionRecurring, error)
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error)
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(user *types.User, id string) error
GenerateTransactions(user *types.User) error
}
type TransactionRecurringImpl struct {
db *sqlx.DB
clock Clock
random Random
transaction Transaction
}
func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transaction Transaction) TransactionRecurring {
return TransactionRecurringImpl{
db: db,
clock: clock,
random: random,
transaction: transaction,
}
}
func (s TransactionRecurringImpl) Add(
user *types.User,
transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("add").Inc()
if user == nil {
return nil, ErrUnauthorized
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transactionRecurring Add", nil, err)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()
transactionRecurring, err := s.validateAndEnrichTransactionRecurring(tx, nil, user.Id, transactionRecurringInput)
if err != nil {
return nil, err
}
r, err := tx.NamedExec(`
INSERT INTO "transaction_recurring" (id, user_id, interval_months,
next_execution, party, description, account_id, treasure_chest_id, value, created_at, created_by)
VALUES (:id, :user_id, :interval_months,
:next_execution, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`,
transactionRecurring)
err = db.TransformAndLogDbError("transactionRecurring Insert", r, err)
if err != nil {
return nil, err
}
err = tx.Commit()
err = db.TransformAndLogDbError("transactionRecurring Add", nil, err)
if err != nil {
return nil, err
}
return transactionRecurring, nil
}
func (s TransactionRecurringImpl) Update(
user *types.User,
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("update").Inc()
if user == nil {
return nil, ErrUnauthorized
}
uuid, err := uuid.Parse(input.Id)
if err != nil {
log.Error("transactionRecurring update: %v", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()
transactionRecurring := &types.TransactionRecurring{}
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest)
}
return nil, types.ErrInternal
}
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(tx, transactionRecurring, user.Id, input)
if err != nil {
return nil, err
}
r, err := tx.NamedExec(`
UPDATE transaction_recurring
SET
interval_months = :interval_months,
next_execution = :next_execution,
party = :party,
description = :description,
account_id = :account_id,
treasure_chest_id = :treasure_chest_id,
value = :value,
updated_at = :updated_at,
updated_by = :updated_by
WHERE id = :id
AND user_id = :user_id`, transactionRecurring)
err = db.TransformAndLogDbError("transactionRecurring Update", r, err)
if err != nil {
return nil, err
}
err = tx.Commit()
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil {
return nil, err
}
return transactionRecurring, nil
}
func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("get_all_by_account").Inc()
if user == nil {
return nil, ErrUnauthorized
}
transactionRecurrings := make([]*types.TransactionRecurring, 0)
err := s.db.Select(&transactionRecurrings, `
SELECT *
FROM transaction_recurring
WHERE user_id = ?
ORDER BY created_at DESC`,
user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAll", nil, err)
if err != nil {
return nil, err
}
return transactionRecurrings, nil
}
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("get_all_by_account").Inc()
if user == nil {
return nil, ErrUnauthorized
}
accountUuid, err := uuid.Parse(accountId)
if err != nil {
log.Error("transactionRecurring GetAllByAccount: %v", err)
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()
var rowCount int
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest)
}
return nil, types.ErrInternal
}
transactionRecurrings := make([]*types.TransactionRecurring, 0)
err = tx.Select(&transactionRecurrings, `
SELECT *
FROM transaction_recurring
WHERE user_id = ?
AND account_id = ?
ORDER BY created_at DESC`,
user.Id, accountUuid)
err = db.TransformAndLogDbError("transactionRecurring GetAll", nil, err)
if err != nil {
return nil, err
}
err = tx.Commit()
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil {
return nil, err
}
return transactionRecurrings, nil
}
func (s TransactionRecurringImpl) GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("get_all_by_treasurechest").Inc()
if user == nil {
return nil, ErrUnauthorized
}
treasureChestUuid, err := uuid.Parse(treasureChestId)
if err != nil {
log.Error("transactionRecurring GetAllByTreasureChest: %v", err)
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()
var rowCount int
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest)
}
return nil, types.ErrInternal
}
transactionRecurrings := make([]*types.TransactionRecurring, 0)
err = tx.Select(&transactionRecurrings, `
SELECT *
FROM transaction_recurring
WHERE user_id = ?
AND treasure_chest_id = ?
ORDER BY created_at DESC`,
user.Id, treasureChestUuid)
err = db.TransformAndLogDbError("transactionRecurring GetAll", nil, err)
if err != nil {
return nil, err
}
err = tx.Commit()
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil {
return nil, err
}
return transactionRecurrings, nil
}
func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
transactionRecurringMetric.WithLabelValues("delete").Inc()
if user == nil {
return ErrUnauthorized
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("transactionRecurring delete: %v", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
if err != nil {
return nil
}
defer func() {
_ = tx.Rollback()
}()
var transactionRecurring types.TransactionRecurring
err = tx.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
if err != nil {
return err
}
r, err := tx.Exec("DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("transactionRecurring Delete", r, err)
if err != nil {
return err
}
err = tx.Commit()
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
if err != nil {
return err
}
return nil
}
func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
if user == nil {
return ErrUnauthorized
}
now := s.clock.Now()
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
if err != nil {
return err
}
defer func() {
_ = tx.Rollback()
}()
recurringTransactions := make([]*types.TransactionRecurring, 0)
err = tx.Select(&recurringTransactions, `
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
user.Id, now)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
if err != nil {
return err
}
for _, transactionRecurring := range recurringTransactions {
transaction := types.Transaction{
Timestamp: *transactionRecurring.NextExecution,
Party: transactionRecurring.Party,
Description: transactionRecurring.Description,
TreasureChestId: transactionRecurring.TreasureChestId,
Value: transactionRecurring.Value,
}
_, err = s.transaction.Add(tx, user, transaction)
if err != nil {
return err
}
nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0)
r, err := tx.Exec(`UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`,
nextExecution, transactionRecurring.Id, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err)
if err != nil {
return err
}
}
err = tx.Commit()
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
if err != nil {
return err
}
return nil
}
func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
tx *sqlx.Tx,
oldTransactionRecurring *types.TransactionRecurring,
userId uuid.UUID,
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
var (
id uuid.UUID
accountUuid *uuid.UUID
treasureChestUuid *uuid.UUID
createdAt time.Time
createdBy uuid.UUID
updatedAt *time.Time
updatedBy uuid.UUID
intervalMonths int64
err error
rowCount int
)
if oldTransactionRecurring == nil {
id, err = s.random.UUID()
if err != nil {
return nil, types.ErrInternal
}
createdAt = s.clock.Now()
createdBy = userId
} else {
id = oldTransactionRecurring.Id
createdAt = oldTransactionRecurring.CreatedAt
createdBy = oldTransactionRecurring.CreatedBy
time := s.clock.Now()
updatedAt = &time
updatedBy = userId
}
hasAccount := false
hasTreasureChest := false
if input.AccountId != "" {
temp, err := uuid.Parse(input.AccountId)
if err != nil {
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
}
accountUuid = &temp
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
if err != nil {
return nil, err
}
if rowCount == 0 {
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
}
hasAccount = true
}
if input.TreasureChestId != "" {
temp, err := uuid.Parse(input.TreasureChestId)
if err != nil {
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
}
treasureChestUuid = &temp
var treasureChest types.TreasureChest
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
}
return nil, err
}
if treasureChest.ParentId == nil {
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest)
}
hasTreasureChest = true
}
if !hasAccount && !hasTreasureChest {
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("either account or treasure chest is required: %w", ErrBadRequest)
}
if hasAccount && hasTreasureChest {
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", ErrBadRequest)
}
valueFloat, err := strconv.ParseFloat(input.Value, 64)
if err != nil {
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
}
valueInt := int64(valueFloat * DECIMALS_MULTIPLIER)
if input.Party != "" {
err = validateString(input.Party, "party")
if err != nil {
return nil, err
}
}
if input.Description != "" {
err = validateString(input.Description, "description")
if err != nil {
return nil, err
}
}
intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0)
if err != nil {
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("could not parse intervalMonths: %w", ErrBadRequest)
}
if intervalMonths < 1 {
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", ErrBadRequest)
}
var nextExecution *time.Time = nil
if input.NextExecution != "" {
t, err := time.Parse("2006-01-02", input.NextExecution)
if err != nil {
log.Error("transaction validate: %v", err)
return nil, fmt.Errorf("could not parse timestamp: %w", ErrBadRequest)
}
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
nextExecution = &t
}
transactionRecurring := types.TransactionRecurring{
Id: id,
UserId: userId,
IntervalMonths: intervalMonths,
NextExecution: nextExecution,
Party: input.Party,
Description: input.Description,
AccountId: accountUuid,
TreasureChestId: treasureChestUuid,
Value: valueInt,
CreatedAt: createdAt,
CreatedBy: createdBy,
UpdatedAt: updatedAt,
UpdatedBy: &updatedBy,
}
return &transactionRecurring, nil
}

View File

@@ -0,0 +1,327 @@
package service
import (
"errors"
"fmt"
"slices"
"spend-sparrow/internal/db"
"spend-sparrow/internal/log"
"spend-sparrow/internal/types"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
treasureChestMetric = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "spendsparrow_treasurechest_total",
Help: "The total of treasurechest operations",
},
[]string{"operation"},
)
)
type TreasureChest interface {
Add(user *types.User, parentId, name string) (*types.TreasureChest, error)
Update(user *types.User, id, parentId, name string) (*types.TreasureChest, error)
Get(user *types.User, id string) (*types.TreasureChest, error)
GetAll(user *types.User) ([]*types.TreasureChest, error)
Delete(user *types.User, id string) error
}
type TreasureChestImpl struct {
db *sqlx.DB
clock Clock
random Random
}
func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
return TreasureChestImpl{
db: db,
clock: clock,
random: random,
}
}
func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.TreasureChest, error) {
treasureChestMetric.WithLabelValues("add").Inc()
if user == nil {
return nil, ErrUnauthorized
}
newId, err := s.random.UUID()
if err != nil {
return nil, types.ErrInternal
}
err = validateString(name, "name")
if err != nil {
return nil, err
}
var parentUuid *uuid.UUID
if parentId != "" {
parent, err := s.Get(user, parentId)
if err != nil {
return nil, err
}
if parent.ParentId != nil {
return nil, fmt.Errorf("only a depth of 1 allowed: %w", ErrBadRequest)
}
parentUuid = &parent.Id
}
treasureChest := &types.TreasureChest{
Id: newId,
ParentId: parentUuid,
UserId: user.Id,
Name: name,
CurrentBalance: 0,
CreatedAt: s.clock.Now(),
CreatedBy: user.Id,
UpdatedAt: nil,
UpdatedBy: nil,
}
r, err := s.db.NamedExec(`
INSERT INTO treasure_chest (id, parent_id, user_id, name, current_balance, created_at, created_by)
VALUES (:id, :parent_id, :user_id, :name, :current_balance, :created_at, :created_by)`, treasureChest)
err = db.TransformAndLogDbError("treasureChest Insert", r, err)
if err != nil {
return nil, err
}
return treasureChest, nil
}
func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
treasureChestMetric.WithLabelValues("update").Inc()
if user == nil {
return nil, ErrUnauthorized
}
err := validateString(name, "name")
if err != nil {
return nil, err
}
id, err := uuid.Parse(idStr)
if err != nil {
log.Error("treasureChest update: %v", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil {
return nil, err
}
defer func() {
_ = tx.Rollback()
}()
treasureChest := &types.TreasureChest{}
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
}
return nil, types.ErrInternal
}
var parentUuid *uuid.UUID
if parentId != "" {
parent, err := s.Get(user, parentId)
if err != nil {
return nil, err
}
var childCount int
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil {
return nil, err
}
if parent.ParentId != nil || childCount > 0 {
return nil, fmt.Errorf("only one level allowed: %w", ErrBadRequest)
}
parentUuid = &parent.Id
}
timestamp := s.clock.Now()
treasureChest.Name = name
treasureChest.ParentId = parentUuid
treasureChest.UpdatedAt = &timestamp
treasureChest.UpdatedBy = &user.Id
r, err := tx.NamedExec(`
UPDATE treasure_chest
SET
parent_id = :parent_id,
name = :name,
current_balance = :current_balance,
updated_at = :updated_at,
updated_by = :updated_by
WHERE id = :id
AND user_id = :user_id`, treasureChest)
err = db.TransformAndLogDbError("treasureChest Update", r, err)
if err != nil {
return nil, err
}
err = tx.Commit()
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil {
return nil, err
}
return treasureChest, nil
}
func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChest, error) {
treasureChestMetric.WithLabelValues("get").Inc()
if user == nil {
return nil, ErrUnauthorized
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("treasureChest get: %v", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
var treasureChest types.TreasureChest
err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("treasureChest Get", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
}
return nil, types.ErrInternal
}
return &treasureChest, nil
}
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) {
treasureChestMetric.WithLabelValues("get_all").Inc()
if user == nil {
return nil, ErrUnauthorized
}
treasureChests := make([]*types.TreasureChest, 0)
err := s.db.Select(&treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("treasureChest GetAll", nil, err)
if err != nil {
return nil, err
}
return sortTree(treasureChests), nil
}
func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
treasureChestMetric.WithLabelValues("delete").Inc()
if user == nil {
return ErrUnauthorized
}
id, err := uuid.Parse(idStr)
if err != nil {
log.Error("treasureChest delete: %v", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil {
return nil
}
defer func() {
_ = tx.Rollback()
}()
childCount := 0
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil {
return err
}
if childCount > 0 {
return fmt.Errorf("treasure chest has children: %w", ErrBadRequest)
}
transactionsCount := 0
err = tx.Get(&transactionsCount,
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil {
return err
}
if transactionsCount > 0 {
return fmt.Errorf("treasure chest has transactions: %w", ErrBadRequest)
}
r, err := tx.Exec(`DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
err = db.TransformAndLogDbError("treasureChest Delete", r, err)
if err != nil {
return err
}
err = tx.Commit()
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil {
return err
}
return nil
}
func sortTree(nodes []*types.TreasureChest) []*types.TreasureChest {
var (
roots []*types.TreasureChest
)
children := make(map[uuid.UUID][]*types.TreasureChest)
result := make([]*types.TreasureChest, 0)
for _, node := range nodes {
if node.ParentId == nil {
roots = append(roots, node)
} else {
children[*node.ParentId] = append(children[*node.ParentId], node)
}
}
slices.SortFunc(roots, func(a, b *types.TreasureChest) int {
return compareStrings(a.Name, b.Name)
})
for _, root := range roots {
result = append(result, root)
childList := children[root.Id]
slices.SortFunc(childList, func(a, b *types.TreasureChest) int {
return compareStrings(a.Name, b.Name)
})
result = append(result, childList...)
}
return result
}
func compareStrings(a, b string) int {
if a == b {
return 0
}
if a < b {
return -1
}
return 1
}

View File

@@ -0,0 +1,10 @@
package service_test
// import (
// "spned-sparrow"
// )
//
// func TestTreasureChestProhibitDeleteIfTransactionRecurringExists(t *testing.T) {
// service := main.Setup
//
// }