From 76da3ca70362a0621f45826c2bc3d1fdda57bdef Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Thu, 29 May 2025 00:00:19 +0200 Subject: [PATCH] feat(transaction-recurring): #100 generate transactions --- db/migration.go | 20 +++-- .../generate_recurring_transactions.go | 22 ++++++ handler/transaction.go | 67 ++++++++++++++--- main.go | 3 +- service/transaction.go | 74 +++++-------------- service/transaction_recurring.go | 74 +++++++++++++++++-- types/transaction.go | 10 --- 7 files changed, 183 insertions(+), 87 deletions(-) create mode 100644 handler/middleware/generate_recurring_transactions.go diff --git a/db/migration.go b/db/migration.go index e4d8357..9f56bcb 100644 --- a/db/migration.go +++ b/db/migration.go @@ -12,6 +12,15 @@ import ( "github.com/jmoiron/sqlx" ) +type migrationLogger struct{} + +func (l migrationLogger) Printf(format string, v ...interface{}) { + log.Info(format, v...) +} +func (l migrationLogger) Verbose() bool { + return false +} + func RunMigrations(db *sqlx.DB, pathPrefix string) error { driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) if err != nil { @@ -28,12 +37,11 @@ func RunMigrations(db *sqlx.DB, pathPrefix string) error { return types.ErrInternal } - err = m.Up() - if err != nil { - if !errors.Is(err, migrate.ErrNoChange) { - log.Error("Could not run migrations: %v", err) - return types.ErrInternal - } + m.Log = migrationLogger{} + + if err = m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { + log.Error("Could not run migrations: %v", err) + return types.ErrInternal } return nil diff --git a/handler/middleware/generate_recurring_transactions.go b/handler/middleware/generate_recurring_transactions.go new file mode 100644 index 0000000..fe8fdee --- /dev/null +++ b/handler/middleware/generate_recurring_transactions.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "net/http" + "spend-sparrow/service" +) + +func GenerateRecurringTransactions(transactionRecurring service.TransactionRecurring) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := GetUser(r) + if user == nil || r.Method != http.MethodGet { + next.ServeHTTP(w, r) + return + } + + _ = transactionRecurring.GenerateTransactions(user) + + next.ServeHTTP(w, r) + }) + } +} diff --git a/handler/transaction.go b/handler/transaction.go index 4fdf94e..3d5c1b8 100644 --- a/handler/transaction.go +++ b/handler/transaction.go @@ -1,12 +1,15 @@ package handler import ( + "fmt" "net/http" "spend-sparrow/handler/middleware" "spend-sparrow/service" t "spend-sparrow/template/transaction" "spend-sparrow/types" "spend-sparrow/utils" + "strconv" + "time" "github.com/a-h/templ" "github.com/google/uuid" @@ -137,20 +140,66 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc { } var ( - transaction *types.Transaction - err error + id uuid.UUID + err error ) - input := types.TransactionInput{ - Id: r.PathValue("id"), - AccountId: r.FormValue("account-id"), - TreasureChestId: r.FormValue("treasure-chest-id"), - Value: r.FormValue("value"), - Timestamp: r.FormValue("timestamp"), + + idStr := r.PathValue("id") + if idStr != "new" { + id, err = uuid.Parse(idStr) + if err != nil { + handleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)) + return + } + } + + accountIdStr := r.FormValue("account-id") + var accountId *uuid.UUID + if accountIdStr != "" { + i, err := uuid.Parse(accountIdStr) + if err != nil { + handleError(w, r, fmt.Errorf("could not parse account id: %w", service.ErrBadRequest)) + return + } + accountId = &i + } + + treasureChestIdStr := r.FormValue("treasure-chest-id") + var treasureChestId *uuid.UUID + if treasureChestIdStr != "" { + i, err := uuid.Parse(treasureChestIdStr) + if err != nil { + handleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", service.ErrBadRequest)) + return + } + treasureChestId = &i + } + + valueF, err := strconv.ParseFloat(r.FormValue("value"), 64) + if err != nil { + handleError(w, r, fmt.Errorf("could not parse value: %w", service.ErrBadRequest)) + return + } + value := int64(valueF * service.DECIMALS_MULTIPLIER) + + timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp")) + if err != nil { + handleError(w, r, fmt.Errorf("could not parse timestamp: %w", service.ErrBadRequest)) + return + } + + input := types.Transaction{ + Id: id, + AccountId: accountId, + TreasureChestId: treasureChestId, + Value: value, + Timestamp: timestamp, Party: r.FormValue("party"), Description: r.FormValue("description"), } - if input.Id == "new" { + var transaction *types.Transaction + if idStr == "new" { transaction, err = h.s.Add(user, input) if err != nil { handleError(w, r, err) diff --git a/main.go b/main.go index 67a4283..760b984 100644 --- a/main.go +++ b/main.go @@ -126,7 +126,7 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler { accountService := service.NewAccount(d, randomService, clockService) treasureChestService := service.NewTreasureChest(d, randomService, clockService) transactionService := service.NewTransaction(d, randomService, clockService) - transactionRecurringService := service.NewTransactionRecurring(d, randomService, clockService) + transactionRecurringService := service.NewTransactionRecurring(d, randomService, clockService, transactionService) render := handler.NewRender() indexHandler := handler.NewIndex(render) @@ -148,6 +148,7 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler { return middleware.Wrapper( router, + middleware.GenerateRecurringTransactions(transactionRecurringService), middleware.SecurityHeaders(serverSettings), middleware.CacheControl, middleware.CrossSiteRequestForgery(authService), diff --git a/service/transaction.go b/service/transaction.go index d02a9b3..a17e9aa 100644 --- a/service/transaction.go +++ b/service/transaction.go @@ -3,12 +3,10 @@ package service import ( "errors" "fmt" - "strconv" - "time" - "spend-sparrow/db" "spend-sparrow/log" "spend-sparrow/types" + "time" "github.com/google/uuid" "github.com/jmoiron/sqlx" @@ -27,8 +25,8 @@ var ( ) type Transaction interface { - Add(user *types.User, transaction types.TransactionInput) (*types.Transaction, error) - Update(user *types.User, transaction types.TransactionInput) (*types.Transaction, error) + Add(user *types.User, transaction types.Transaction) (*types.Transaction, error) + Update(user *types.User, transaction types.Transaction) (*types.Transaction, error) Get(user *types.User, id string) (*types.Transaction, error) GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) Delete(user *types.User, id string) error @@ -50,7 +48,7 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction { } } -func (s TransactionImpl) Add(user *types.User, transactionInput types.TransactionInput) (*types.Transaction, error) { +func (s TransactionImpl) Add(user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { transactionMetric.WithLabelValues("add").Inc() if user == nil { @@ -112,16 +110,11 @@ func (s TransactionImpl) Add(user *types.User, transactionInput types.Transactio return transaction, nil } -func (s TransactionImpl) Update(user *types.User, input types.TransactionInput) (*types.Transaction, error) { +func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*types.Transaction, error) { transactionMetric.WithLabelValues("update").Inc() if user == nil { return nil, ErrUnauthorized } - uuid, err := uuid.Parse(input.Id) - if err != nil { - log.Error("transaction update: %v", err) - return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) - } tx, err := s.db.Beginx() err = db.TransformAndLogDbError("transaction Update", nil, err) @@ -133,7 +126,7 @@ func (s TransactionImpl) Update(user *types.User, input types.TransactionInput) }() transaction := &types.Transaction{} - err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id) err = db.TransformAndLogDbError("transaction Update", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -440,15 +433,13 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { return nil } -func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.TransactionInput) (*types.Transaction, error) { +func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) { var ( - id uuid.UUID - accountUuid *uuid.UUID - treasureChestUuid *uuid.UUID - createdAt time.Time - createdBy uuid.UUID - updatedAt *time.Time - updatedBy uuid.UUID + id uuid.UUID + createdAt time.Time + createdBy uuid.UUID + updatedAt *time.Time + updatedBy uuid.UUID err error rowCount int @@ -470,14 +461,8 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio updatedBy = userId } - if input.AccountId != "" { - temp, err := uuid.Parse(input.AccountId) - if err != nil { - log.Error("transaction validate: %v", err) - return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) - } - accountUuid = &temp - err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId) + if input.AccountId != nil { + err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId) err = db.TransformAndLogDbError("transaction validate", nil, err) if err != nil { return nil, err @@ -488,15 +473,9 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio } } - if input.TreasureChestId != "" { - temp, err := uuid.Parse(input.TreasureChestId) - if err != nil { - log.Error("transaction validate: %v", err) - return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) - } - treasureChestUuid = &temp + if input.TreasureChestId != nil { var treasureChest types.TreasureChest - err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) + err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId) err = db.TransformAndLogDbError("transaction validate", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -509,19 +488,6 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio } } - valueFloat, err := strconv.ParseFloat(input.Value, 64) - if err != nil { - log.Error("transaction validate: %v", err) - return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest) - } - valueInt := int64(valueFloat * DECIMALS_MULTIPLIER) - - timestamp, err := time.Parse("2006-01-02", input.Timestamp) - if err != nil { - log.Error("transaction validate: %v", err) - return nil, fmt.Errorf("could not parse timestamp: %w", ErrBadRequest) - } - if input.Party != "" { err = validateString(input.Party, "party") if err != nil { @@ -539,10 +505,10 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio Id: id, UserId: userId, - AccountId: accountUuid, - TreasureChestId: treasureChestUuid, - Value: valueInt, - Timestamp: timestamp, + AccountId: input.AccountId, + TreasureChestId: input.TreasureChestId, + Value: input.Value, + Timestamp: input.Timestamp, Party: input.Party, Description: input.Description, Error: nil, diff --git a/service/transaction_recurring.go b/service/transaction_recurring.go index e140127..c1ebae6 100644 --- a/service/transaction_recurring.go +++ b/service/transaction_recurring.go @@ -33,19 +33,23 @@ type TransactionRecurring interface { GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) Delete(user *types.User, id string) error + + GenerateTransactions(user *types.User) error } type TransactionRecurringImpl struct { - db *sqlx.DB - clock Clock - random Random + db *sqlx.DB + clock Clock + random Random + transaction Transaction } -func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock) TransactionRecurring { +func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transaction Transaction) TransactionRecurring { return TransactionRecurringImpl{ - db: db, - clock: clock, - random: random, + db: db, + clock: clock, + random: random, + transaction: transaction, } } @@ -326,6 +330,62 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error { return nil } +func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error { + if user == nil { + return ErrUnauthorized + } + now := s.clock.Now() + + tx, err := s.db.Beginx() + err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err) + if err != nil { + return err + } + defer func() { + _ = tx.Rollback() + }() + + recurringTransactions := make([]*types.TransactionRecurring, 0) + err = tx.Select(&recurringTransactions, ` + SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`, + user.Id, now) + err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err) + if err != nil { + return err + } + + for _, transactionRecurring := range recurringTransactions { + transaction := types.Transaction{ + Timestamp: *transactionRecurring.NextExecution, + Party: transactionRecurring.Party, + Description: transactionRecurring.Description, + + TreasureChestId: transactionRecurring.TreasureChestId, + Value: transactionRecurring.Value, + } + + _, err = s.transaction.Add(user, transaction) + if err != nil { + return err + } + + nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0) + r, err := tx.Exec(`UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`, + nextExecution, transactionRecurring.Id, user.Id) + err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err) + if err != nil { + return err + } + } + + err = tx.Commit() + err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err) + if err != nil { + return err + } + return nil +} + func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( tx *sqlx.Tx, oldTransactionRecurring *types.TransactionRecurring, diff --git a/types/transaction.go b/types/transaction.go index d7cd3b8..b53d0c2 100644 --- a/types/transaction.go +++ b/types/transaction.go @@ -34,16 +34,6 @@ type Transaction struct { UpdatedBy *uuid.UUID `db:"updated_by"` } -type TransactionInput struct { - Id string - AccountId string - TreasureChestId string - Value string - Timestamp string - Party string - Description string -} - type TransactionItemsFilter struct { AccountId string TreasureChestId string