diff --git a/handler/transaction.go b/handler/transaction.go index 3d5c1b8..22a8d59 100644 --- a/handler/transaction.go +++ b/handler/transaction.go @@ -200,7 +200,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc { var transaction *types.Transaction if idStr == "new" { - transaction, err = h.s.Add(user, input) + transaction, err = h.s.Add(nil, user, input) if err != nil { handleError(w, r, err) return diff --git a/service/transaction.go b/service/transaction.go index a17e9aa..ab0a130 100644 --- a/service/transaction.go +++ b/service/transaction.go @@ -25,7 +25,7 @@ var ( ) type Transaction interface { - Add(user *types.User, transaction types.Transaction) (*types.Transaction, error) + Add(tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error) Update(user *types.User, transaction types.Transaction) (*types.Transaction, error) Get(user *types.User, id string) (*types.Transaction, error) GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) @@ -48,21 +48,24 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction { } } -func (s TransactionImpl) Add(user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { +func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { transactionMetric.WithLabelValues("add").Inc() if user == nil { return nil, ErrUnauthorized } - tx, err := s.db.Beginx() - err = db.TransformAndLogDbError("transaction Add", nil, err) - if err != nil { - return nil, err + var err error + if tx == nil { + tx, err = s.db.Beginx() + err = db.TransformAndLogDbError("transaction Add", nil, err) + if err != nil { + return nil, err + } + defer func() { + _ = tx.Rollback() + }() } - defer func() { - _ = tx.Rollback() - }() transaction, err := s.validateAndEnrichTransaction(tx, nil, user.Id, transactionInput) if err != nil { diff --git a/service/transaction_recurring.go b/service/transaction_recurring.go index c1ebae6..f5a776c 100644 --- a/service/transaction_recurring.go +++ b/service/transaction_recurring.go @@ -364,7 +364,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error { Value: transactionRecurring.Value, } - _, err = s.transaction.Add(user, transaction) + _, err = s.transaction.Add(tx, user, transaction) if err != nil { return err }