feat(observabillity): #153 instrument sqlx
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 2m29s
Build and Push Docker Image / Build-And-Push-Docker-Image (push) Successful in 2m49s

This commit was merged in pull request #160.
This commit is contained in:
2025-06-07 21:55:59 +02:00
parent c4aca2778f
commit 11f3bcc89f
25 changed files with 434 additions and 409 deletions

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"errors"
"log/slog"
@@ -13,23 +14,23 @@ import (
)
type Auth interface {
InsertUser(user *types.User) error
UpdateUser(user *types.User) error
GetUserByEmail(email string) (*types.User, error)
GetUser(userId uuid.UUID) (*types.User, error)
DeleteUser(userId uuid.UUID) error
InsertUser(ctx context.Context, user *types.User) error
UpdateUser(ctx context.Context, user *types.User) error
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error)
DeleteUser(ctx context.Context, userId uuid.UUID) error
InsertToken(token *types.Token) error
GetToken(token string) (*types.Token, error)
GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error)
DeleteToken(token string) error
InsertToken(ctx context.Context, token *types.Token) error
GetToken(ctx context.Context, token string) (*types.Token, error)
GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error)
DeleteToken(ctx context.Context, token string) error
InsertSession(session *types.Session) error
GetSession(sessionId string) (*types.Session, error)
GetSessions(userId uuid.UUID) ([]*types.Session, error)
DeleteSession(sessionId string) error
DeleteOldSessions(userId uuid.UUID) error
InsertSession(ctx context.Context, session *types.Session) error
GetSession(ctx context.Context, sessionId string) (*types.Session, error)
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
DeleteSession(ctx context.Context, sessionId string) error
DeleteOldSessions(ctx context.Context, userId uuid.UUID) error
}
type AuthSqlite struct {
@@ -40,8 +41,8 @@ func NewAuthSqlite(db *sqlx.DB) *AuthSqlite {
return &AuthSqlite{db: db}
}
func (db AuthSqlite) InsertUser(user *types.User) error {
_, err := db.db.Exec(`
func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error {
_, err := db.db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt)
@@ -58,8 +59,8 @@ func (db AuthSqlite) InsertUser(user *types.User) error {
return nil
}
func (db AuthSqlite) UpdateUser(user *types.User) error {
_, err := db.db.Exec(`
func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error {
_, err := db.db.ExecContext(ctx, `
UPDATE user
SET email_verified = ?, email_verified_at = ?, password = ?
WHERE user_id = ?`,
@@ -73,7 +74,7 @@ func (db AuthSqlite) UpdateUser(user *types.User) error {
return nil
}
func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
var (
userId uuid.UUID
emailVerified bool
@@ -84,7 +85,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
createdAt time.Time
)
err := db.db.QueryRow(`
err := db.db.QueryRowContext(ctx, `
SELECT user_id, email_verified, email_verified_at, password, salt, created_at
FROM user
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
@@ -100,7 +101,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
}
func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) {
var (
email string
emailVerified bool
@@ -111,7 +112,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
createdAt time.Time
)
err := db.db.QueryRow(`
err := db.db.QueryRowContext(ctx, `
SELECT email, email_verified, email_verified_at, password, salt, created_at
FROM user
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
@@ -127,49 +128,49 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
}
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
tx, err := db.db.Begin()
func (db AuthSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error {
tx, err := db.db.BeginTx(ctx, nil)
if err != nil {
slog.Error("Could not start transaction", "err", err)
return types.ErrInternal
}
_, err = tx.Exec("DELETE FROM account WHERE user_id = ?", userId)
_, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId)
if err != nil {
_ = tx.Rollback()
slog.Error("Could not delete accounts", "err", err)
return types.ErrInternal
}
_, err = tx.Exec("DELETE FROM token WHERE user_id = ?", userId)
_, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId)
if err != nil {
_ = tx.Rollback()
slog.Error("Could not delete user tokens", "err", err)
return types.ErrInternal
}
_, err = tx.Exec("DELETE FROM session WHERE user_id = ?", userId)
_, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId)
if err != nil {
_ = tx.Rollback()
slog.Error("Could not delete sessions", "err", err)
return types.ErrInternal
}
_, err = tx.Exec("DELETE FROM user WHERE user_id = ?", userId)
_, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId)
if err != nil {
_ = tx.Rollback()
slog.Error("Could not delete user", "err", err)
return types.ErrInternal
}
_, err = tx.Exec("DELETE FROM treasure_chest WHERE user_id = ?", userId)
_, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId)
if err != nil {
_ = tx.Rollback()
slog.Error("Could not delete user", "err", err)
return types.ErrInternal
}
_, err = tx.Exec("DELETE FROM \"transaction\" WHERE user_id = ?", userId)
_, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId)
if err != nil {
_ = tx.Rollback()
slog.Error("Could not delete user", "err", err)
@@ -185,8 +186,8 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
return nil
}
func (db AuthSqlite) InsertToken(token *types.Token) error {
_, err := db.db.Exec(`
func (db AuthSqlite) InsertToken(ctx context.Context, token *types.Token) error {
_, err := db.db.ExecContext(ctx, `
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
@@ -198,7 +199,7 @@ func (db AuthSqlite) InsertToken(token *types.Token) error {
return nil
}
func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, error) {
var (
userId uuid.UUID
sessionId string
@@ -209,7 +210,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
expiresAt time.Time
)
err := db.db.QueryRow(`
err := db.db.QueryRowContext(ctx, `
SELECT user_id, session_id, type, created_at, expires_at
FROM token
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
@@ -239,8 +240,8 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
}
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(`
func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.QueryContext(ctx, `
SELECT token, created_at, expires_at
FROM token
WHERE user_id = ?
@@ -254,8 +255,8 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.
return getTokensFromQuery(query, userId, "", tokenType)
}
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(`
func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.QueryContext(ctx, `
SELECT token, created_at, expires_at
FROM token
WHERE session_id = ?
@@ -312,8 +313,8 @@ func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tok
return tokens, nil
}
func (db AuthSqlite) DeleteToken(token string) error {
_, err := db.db.Exec("DELETE FROM token WHERE token = ?", token)
func (db AuthSqlite) DeleteToken(ctx context.Context, token string) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token)
if err != nil {
slog.Error("Could not delete token", "err", err)
return types.ErrInternal
@@ -321,8 +322,8 @@ func (db AuthSqlite) DeleteToken(token string) error {
return nil
}
func (db AuthSqlite) InsertSession(session *types.Session) error {
_, err := db.db.Exec(`
func (db AuthSqlite) InsertSession(ctx context.Context, session *types.Session) error {
_, err := db.db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
@@ -334,14 +335,14 @@ func (db AuthSqlite) InsertSession(session *types.Session) error {
return nil
}
func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.Session, error) {
var (
userId uuid.UUID
createdAt time.Time
expiresAt time.Time
)
err := db.db.QueryRow(`
err := db.db.QueryRowContext(ctx, `
SELECT user_id, created_at, expires_at
FROM session
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
@@ -354,9 +355,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
return types.NewSession(sessionId, userId, createdAt, expiresAt), nil
}
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) {
var sessions []*types.Session
err := db.db.Select(&sessions, `
err := db.db.SelectContext(ctx, &sessions, `
SELECT *
FROM session
WHERE user_id = ?`, userId)
@@ -368,8 +369,8 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
return sessions, nil
}
func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
_, err := db.db.Exec(`
func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM session
WHERE expires_at < datetime('now')
AND user_id = ?`, userId)
@@ -380,9 +381,9 @@ func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
return nil
}
func (db AuthSqlite) DeleteSession(sessionId string) error {
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
if sessionId != "" {
_, err := db.db.Exec("DELETE FROM session WHERE session_id = ?", sessionId)
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
if err != nil {
slog.Error("Could not delete session", "err", err)
return types.ErrInternal

View File

@@ -1,6 +1,7 @@
package db
import (
"context"
"errors"
"log/slog"
"spend-sparrow/internal/types"
@@ -20,7 +21,7 @@ func (l migrationLogger) Verbose() bool {
return false
}
func RunMigrations(db *sqlx.DB, pathPrefix string) error {
func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error {
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
if err != nil {
slog.Error("Could not create Migration instance", "err", err)

View File

@@ -56,7 +56,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
}
// init db
err = db.RunMigrations(database, migrationsPrefix)
err = db.RunMigrations(ctx, database, migrationsPrefix)
if err != nil {
return fmt.Errorf("could not run migrations: %w", err)
}

View File

@@ -44,7 +44,7 @@ func (h AccountImpl) handleAccountPage() http.HandlerFunc {
return
}
accounts, err := h.s.GetAll(user)
accounts, err := h.s.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
@@ -72,7 +72,7 @@ func (h AccountImpl) handleAccountItemComp() http.HandlerFunc {
return
}
account, err := h.s.Get(user, id)
account, err := h.s.Get(r.Context(), user, id)
if err != nil {
handleError(w, r, err)
return
@@ -105,13 +105,13 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc {
id := r.PathValue("id")
name := r.FormValue("name")
if id == "new" {
account, err = h.s.Add(user, name)
account, err = h.s.Add(r.Context(), user, name)
if err != nil {
handleError(w, r, err)
return
}
} else {
account, err = h.s.UpdateName(user, id, name)
account, err = h.s.UpdateName(r.Context(), user, id, name)
if err != nil {
handleError(w, r, err)
return
@@ -135,7 +135,7 @@ func (h AccountImpl) handleDeleteAccount() http.HandlerFunc {
id := r.PathValue("id")
err := h.s.Delete(user, id)
err := h.s.Delete(r.Context(), user, id)
if err != nil {
handleError(w, r, err)
return

View File

@@ -85,7 +85,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
email := r.FormValue("email")
password := r.FormValue("password")
session, user, err := handler.service.SignIn(session, email, password)
session, user, err := handler.service.SignIn(r.Context(), session, email, password)
if err != nil {
return nil, err
}
@@ -163,7 +163,7 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
return
}
go handler.service.SendVerificationMail(user.Id, user.Email)
go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email)
_, err := w.Write([]byte("<p class=\"mt-8\">Verification email sent</p>"))
if err != nil {
@@ -178,7 +178,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
token := r.URL.Query().Get("token")
err := handler.service.VerifyUserEmail(token)
err := handler.service.VerifyUserEmail(r.Context(), token)
isVerified := err == nil
comp := auth.VerifyResponseComp(isVerified)
@@ -203,13 +203,13 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
slog.Info("signing up", "email", email)
user, err := handler.service.SignUp(email, password)
user, err := handler.service.SignUp(r.Context(), email, password)
if err != nil {
return nil, err
}
slog.Info("Sending verification email", "to", user.Email)
go handler.service.SendVerificationMail(user.Id, user.Email)
go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email)
return nil, nil
})
@@ -239,7 +239,7 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc {
session := middleware.GetSession(r)
if session != nil {
err := handler.service.SignOut(session.Id)
err := handler.service.SignOut(r.Context(), session.Id)
if err != nil {
http.Error(w, "An error occurred", http.StatusInternalServerError)
return
@@ -288,7 +288,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
password := r.FormValue("password")
err := handler.service.DeleteAccount(user, password)
err := handler.service.DeleteAccount(r.Context(), user, password)
if err != nil {
if errors.Is(err, service.ErrInvalidCredentials) {
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
@@ -334,7 +334,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
currPass := r.FormValue("current-password")
newPass := r.FormValue("new-password")
err := handler.service.ChangePassword(user, session.Id, currPass, newPass)
err := handler.service.ChangePassword(r.Context(), user, session.Id, currPass, newPass)
if err != nil {
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
return
@@ -370,7 +370,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
}
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
err := handler.service.SendForgotPasswordMail(email)
err := handler.service.SendForgotPasswordMail(r.Context(), email)
return nil, err
})
@@ -396,7 +396,7 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
token := pageUrl.Query().Get("token")
newPass := r.FormValue("new-password")
err = handler.service.ForgotPassword(token, newPass)
err = handler.service.ForgotPassword(r.Context(), token, newPass)
if err != nil {
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
} else {

View File

@@ -17,13 +17,13 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionId := getSessionID(r)
session, user, _ := service.SignInSession(sessionId)
session, user, _ := service.SignInSession(r.Context(), sessionId)
var err error
// Always sign in anonymous
// This way, we can always generate csrf tokens
if session == nil {
session, err = service.SignInAnonymous()
session, err = service.SignInAnonymous(r.Context())
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return

View File

@@ -4,30 +4,26 @@ import (
"log/slog"
"net/http"
"spend-sparrow/internal/service"
"spend-sparrow/internal/types"
"spend-sparrow/internal/utils"
"strings"
)
type csrfResponseWriter struct {
http.ResponseWriter
auth service.Auth
session *types.Session
csrfToken string
}
func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter {
func newCsrfResponseWriter(w http.ResponseWriter, csrfToken string) *csrfResponseWriter {
return &csrfResponseWriter{
ResponseWriter: w,
auth: auth,
session: session,
csrfToken: csrfToken,
}
}
func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
dataStr := string(data)
csrfToken, err := rr.auth.GetCsrfToken(rr.session)
if err == nil {
dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken)
if rr.csrfToken != "" {
dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", rr.csrfToken)
}
return rr.ResponseWriter.Write([]byte(dataStr))
@@ -37,6 +33,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := GetSession(r)
ctx := r.Context()
if r.Method == http.MethodPost ||
r.Method == http.MethodPut ||
@@ -44,7 +41,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
r.Method == http.MethodPatch {
csrfToken := r.Header.Get("Csrf-Token")
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(ctx, csrfToken, session.Id) {
slog.Info("CSRF-Token not correct", "token", csrfToken)
if r.Header.Get("Hx-Request") == "true" {
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
@@ -55,7 +52,17 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
}
}
responseWriter := newCsrfResponseWriter(w, auth, session)
token, err := auth.GetCsrfToken(ctx, session)
if err != nil {
if r.Header.Get("Hx-Request") == "true" {
utils.TriggerToastWithStatus(w, r, "error", "Could not generate CSRF Token", http.StatusBadRequest)
} else {
http.Error(w, "Could not generate CSRF Token", http.StatusBadRequest)
}
return
}
responseWriter := newCsrfResponseWriter(w, token)
next.ServeHTTP(responseWriter, r)
})
}

View File

@@ -14,7 +14,7 @@ func GenerateRecurringTransactions(transactionRecurring service.TransactionRecur
return
}
_ = transactionRecurring.GenerateTransactions(user)
_ = transactionRecurring.GenerateTransactions(r.Context(), user)
next.ServeHTTP(w, r)
})

View File

@@ -65,19 +65,19 @@ func (h TransactionImpl) handleTransactionPage() http.HandlerFunc {
Error: r.URL.Query().Get("error"),
}
transactions, err := h.s.GetAll(user, filter)
transactions, err := h.s.GetAll(r.Context(), user, filter)
if err != nil {
handleError(w, r, err)
return
}
accounts, err := h.account.GetAll(user)
accounts, err := h.account.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
}
treasureChests, err := h.treasureChest.GetAll(user)
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
@@ -105,13 +105,13 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc {
return
}
accounts, err := h.account.GetAll(user)
accounts, err := h.account.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
}
treasureChests, err := h.treasureChest.GetAll(user)
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
@@ -124,7 +124,7 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc {
return
}
transaction, err := h.s.Get(user, id)
transaction, err := h.s.Get(r.Context(), user, id)
if err != nil {
handleError(w, r, err)
return
@@ -212,26 +212,26 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
var transaction *types.Transaction
if idStr == "new" {
transaction, err = h.s.Add(nil, user, input)
transaction, err = h.s.Add(r.Context(), nil, user, input)
if err != nil {
handleError(w, r, err)
return
}
} else {
transaction, err = h.s.Update(user, input)
transaction, err = h.s.Update(r.Context(), user, input)
if err != nil {
handleError(w, r, err)
return
}
}
accounts, err := h.account.GetAll(user)
accounts, err := h.account.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
}
treasureChests, err := h.treasureChest.GetAll(user)
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
@@ -253,7 +253,7 @@ func (h TransactionImpl) handleRecalculate() http.HandlerFunc {
return
}
err := h.s.RecalculateBalances(user)
err := h.s.RecalculateBalances(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
@@ -275,7 +275,7 @@ func (h TransactionImpl) handleDeleteTransaction() http.HandlerFunc {
id := r.PathValue("id")
err := h.s.Delete(user, id)
err := h.s.Delete(r.Context(), user, id)
if err != nil {
handleError(w, r, err)
return

View File

@@ -70,13 +70,13 @@ func (h TransactionRecurringImpl) handleUpdateTransactionRecurring() http.Handle
}
if input.Id == "new" {
_, err := h.s.Add(user, input)
_, err := h.s.Add(r.Context(), user, input)
if err != nil {
handleError(w, r, err)
return
}
} else {
_, err := h.s.Update(user, input)
_, err := h.s.Update(r.Context(), user, input)
if err != nil {
handleError(w, r, err)
return
@@ -101,7 +101,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
accountId := r.URL.Query().Get("account-id")
treasureChestId := r.URL.Query().Get("treasure-chest-id")
err := h.s.Delete(user, id)
err := h.s.Delete(r.Context(), user, id)
if err != nil {
handleError(w, r, err)
return
@@ -118,13 +118,13 @@ func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Req
utils.TriggerToastWithStatus(w, r, "error", "Please select an account or treasure chest", http.StatusBadRequest)
}
if accountId != "" {
transactionsRecurring, err = h.s.GetAllByAccount(user, accountId)
transactionsRecurring, err = h.s.GetAllByAccount(r.Context(), user, accountId)
if err != nil {
handleError(w, r, err)
return
}
} else {
transactionsRecurring, err = h.s.GetAllByTreasureChest(user, treasureChestId)
transactionsRecurring, err = h.s.GetAllByTreasureChest(r.Context(), user, treasureChestId)
if err != nil {
handleError(w, r, err)
return

View File

@@ -48,13 +48,13 @@ func (h TreasureChestImpl) handleTreasureChestPage() http.HandlerFunc {
return
}
treasureChests, err := h.s.GetAll(user)
treasureChests, err := h.s.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
}
transactionsRecurring, err := h.transactionRecurring.GetAll(user)
transactionsRecurring, err := h.transactionRecurring.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
@@ -77,7 +77,7 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc {
return
}
treasureChests, err := h.s.GetAll(user)
treasureChests, err := h.s.GetAll(r.Context(), user)
if err != nil {
handleError(w, r, err)
return
@@ -90,13 +90,13 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc {
return
}
treasureChest, err := h.s.Get(user, id)
treasureChest, err := h.s.Get(r.Context(), user, id)
if err != nil {
handleError(w, r, err)
return
}
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String())
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String())
if err != nil {
handleError(w, r, err)
return
@@ -132,20 +132,20 @@ func (h TreasureChestImpl) handleUpdateTreasureChest() http.HandlerFunc {
parentId := r.FormValue("parent-id")
name := r.FormValue("name")
if id == "new" {
treasureChest, err = h.s.Add(user, parentId, name)
treasureChest, err = h.s.Add(r.Context(), user, parentId, name)
if err != nil {
handleError(w, r, err)
return
}
} else {
treasureChest, err = h.s.Update(user, id, parentId, name)
treasureChest, err = h.s.Update(r.Context(), user, id, parentId, name)
if err != nil {
handleError(w, r, err)
return
}
}
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String())
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String())
if err != nil {
handleError(w, r, err)
return
@@ -171,7 +171,7 @@ func (h TreasureChestImpl) handleDeleteTreasureChest() http.HandlerFunc {
id := r.PathValue("id")
err := h.s.Delete(user, id)
err := h.s.Delete(r.Context(), user, id)
if err != nil {
handleError(w, r, err)
return

View File

@@ -1,6 +1,7 @@
package service
import (
"context"
"errors"
"fmt"
"log/slog"
@@ -12,11 +13,11 @@ import (
)
type Account interface {
Add(user *types.User, name string) (*types.Account, error)
UpdateName(user *types.User, id string, name string) (*types.Account, error)
Get(user *types.User, id string) (*types.Account, error)
GetAll(user *types.User) ([]*types.Account, error)
Delete(user *types.User, id string) error
Add(ctx context.Context, user *types.User, name string) (*types.Account, error)
UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error)
Get(ctx context.Context, user *types.User, id string) (*types.Account, error)
GetAll(ctx context.Context, user *types.User) ([]*types.Account, error)
Delete(ctx context.Context, user *types.User, id string) error
}
type AccountImpl struct {
@@ -33,7 +34,7 @@ func NewAccount(db *sqlx.DB, random Random, clock Clock) Account {
}
}
func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) {
func (s AccountImpl) Add(ctx context.Context, user *types.User, name string) (*types.Account, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -64,7 +65,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
UpdatedBy: nil,
}
r, err := s.db.NamedExec(`
r, err := s.db.NamedExecContext(ctx, `
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account)
err = db.TransformAndLogDbError("account Insert", r, err)
@@ -75,7 +76,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
return account, nil
}
func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) {
func (s AccountImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -89,7 +90,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil {
return nil, err
@@ -99,7 +100,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
}()
var account types.Account
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -113,7 +114,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
account.UpdatedAt = &timestamp
account.UpdatedBy = &user.Id
r, err := tx.NamedExec(`
r, err := tx.NamedExecContext(ctx, `
UPDATE account
SET
name = :name,
@@ -135,7 +136,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
return &account, nil
}
func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
func (s AccountImpl) Get(ctx context.Context, user *types.User, id string) (*types.Account, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -146,7 +147,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
}
var account types.Account
err = s.db.Get(&account, `
err = s.db.GetContext(ctx, &account, `
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Get", nil, err)
if err != nil {
@@ -157,13 +158,13 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
return &account, nil
}
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
func (s AccountImpl) GetAll(ctx context.Context, user *types.User) ([]*types.Account, error) {
if user == nil {
return nil, ErrUnauthorized
}
accounts := make([]*types.Account, 0)
err := s.db.Select(&accounts, `
err := s.db.SelectContext(ctx, &accounts, `
SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id)
err = db.TransformAndLogDbError("account GetAll", nil, err)
if err != nil {
@@ -173,7 +174,7 @@ func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
return accounts, nil
}
func (s AccountImpl) Delete(user *types.User, id string) error {
func (s AccountImpl) Delete(ctx context.Context, user *types.User, id string) error {
if user == nil {
return ErrUnauthorized
}
@@ -183,7 +184,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("account Delete", nil, err)
if err != nil {
return err
@@ -193,7 +194,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
}()
transactionsCount := 0
err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid)
err = tx.GetContext(ctx, &transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Delete", nil, err)
if err != nil {
return err
@@ -202,7 +203,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("account has transactions, cannot delete: %w", ErrBadRequest)
}
res, err := tx.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("account Delete", res, err)
if err != nil {
return err

View File

@@ -26,24 +26,24 @@ var (
)
type Auth interface {
SignUp(email string, password string) (*types.User, error)
SendVerificationMail(userId uuid.UUID, email string)
VerifyUserEmail(token string) error
SignUp(ctx context.Context, email string, password string) (*types.User, error)
SendVerificationMail(ctx context.Context, userId uuid.UUID, email string)
VerifyUserEmail(ctx context.Context, token string) error
SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error)
SignInSession(sessionId string) (*types.Session, *types.User, error)
SignInAnonymous() (*types.Session, error)
SignOut(sessionId string) error
SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error)
SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error)
SignInAnonymous(ctx context.Context) (*types.Session, error)
SignOut(ctx context.Context, sessionId string) error
DeleteAccount(user *types.User, currPass string) error
DeleteAccount(ctx context.Context, user *types.User, currPass string) error
ChangePassword(user *types.User, sessionId string, currPass, newPass string) error
ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error
SendForgotPasswordMail(email string) error
ForgotPassword(token string, newPass string) error
SendForgotPasswordMail(ctx context.Context, email string) error
ForgotPassword(ctx context.Context, token string, newPass string) error
IsCsrfTokenValid(tokenStr string, sessionId string) bool
GetCsrfToken(session *types.Session) (string, error)
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
}
type AuthImpl struct {
@@ -64,8 +64,8 @@ func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *
}
}
func (service AuthImpl) SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) {
user, err := service.db.GetUserByEmail(email)
func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) {
user, err := service.db.GetUserByEmail(ctx, email)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil, nil, ErrInvalidCredentials
@@ -80,12 +80,12 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
return nil, nil, ErrInvalidCredentials
}
err = service.cleanUpSessionWithTokens(session)
err = service.cleanUpSessionWithTokens(ctx, session)
if err != nil {
return nil, nil, types.ErrInternal
}
session, err = service.createSession(user.Id)
session, err = service.createSession(ctx, user.Id)
if err != nil {
return nil, nil, types.ErrInternal
}
@@ -93,17 +93,17 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
return session, user, nil
}
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
if sessionId == "" {
return nil, nil, ErrSessionIdInvalid
}
session, err := service.db.GetSession(sessionId)
session, err := service.db.GetSession(ctx, sessionId)
if err != nil {
return nil, nil, types.ErrInternal
}
if session.ExpiresAt.Before(service.clock.Now()) {
_ = service.db.DeleteSession(sessionId)
_ = service.db.DeleteSession(ctx, sessionId)
return nil, nil, nil
}
@@ -111,7 +111,7 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
return session, nil, nil
}
user, err := service.db.GetUser(session.UserId)
user, err := service.db.GetUser(ctx, session.UserId)
if err != nil {
return nil, nil, types.ErrInternal
}
@@ -119,8 +119,8 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
return session, user, nil
}
func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
session, err := service.createSession(uuid.Nil)
func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, error) {
session, err := service.createSession(ctx, uuid.Nil)
if err != nil {
return nil, types.ErrInternal
}
@@ -130,7 +130,7 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
return session, nil
}
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
func (service AuthImpl) SignUp(ctx context.Context, email string, password string) (*types.User, error) {
_, err := mail.ParseAddress(email)
if err != nil {
return nil, ErrInvalidEmail
@@ -154,7 +154,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
err = service.db.InsertUser(user)
err = service.db.InsertUser(ctx, user)
if err != nil {
if errors.Is(err, db.ErrAlreadyExists) {
return nil, ErrAccountExists
@@ -166,8 +166,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
return user, nil
}
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify)
if err != nil && !errors.Is(err, db.ErrNotFound) {
return
}
@@ -192,7 +192,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
service.clock.Now(),
service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token)
err = service.db.InsertToken(ctx, token)
if err != nil {
return
}
@@ -208,17 +208,17 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
service.mail.SendMail(email, "Welcome to spend-sparrow", w.String())
}
func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
if tokenStr == "" {
return types.ErrInternal
}
token, err := service.db.GetToken(tokenStr)
token, err := service.db.GetToken(ctx, tokenStr)
if err != nil {
return types.ErrInternal
}
user, err := service.db.GetUser(token.UserId)
user, err := service.db.GetUser(ctx, token.UserId)
if err != nil {
return types.ErrInternal
}
@@ -236,21 +236,21 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
user.EmailVerified = true
user.EmailVerifiedAt = &now
err = service.db.UpdateUser(user)
err = service.db.UpdateUser(ctx, user)
if err != nil {
return types.ErrInternal
}
_ = service.db.DeleteToken(token.Token)
_ = service.db.DeleteToken(ctx, token.Token)
return nil
}
func (service AuthImpl) SignOut(sessionId string) error {
return service.db.DeleteSession(sessionId)
func (service AuthImpl) SignOut(ctx context.Context, sessionId string) error {
return service.db.DeleteSession(ctx, sessionId)
}
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
userDb, err := service.db.GetUser(user.Id)
func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, currPass string) error {
userDb, err := service.db.GetUser(ctx, user.Id)
if err != nil {
return types.ErrInternal
}
@@ -260,7 +260,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
return ErrInvalidCredentials
}
err = service.db.DeleteUser(user.Id)
err = service.db.DeleteUser(ctx, user.Id)
if err != nil {
return err
}
@@ -270,7 +270,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
return nil
}
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error {
if !isPasswordValid(newPass) {
return ErrInvalidPassword
}
@@ -288,18 +288,18 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP
newHash := GetHashPassword(newPass, user.Salt)
user.Password = newHash
err := service.db.UpdateUser(user)
err := service.db.UpdateUser(ctx, user)
if err != nil {
return err
}
sessions, err := service.db.GetSessions(user.Id)
sessions, err := service.db.GetSessions(ctx, user.Id)
if err != nil {
return types.ErrInternal
}
for _, s := range sessions {
if s.Id != sessionId {
err = service.db.DeleteSession(s.Id)
err = service.db.DeleteSession(ctx, s.Id)
if err != nil {
return types.ErrInternal
}
@@ -309,13 +309,13 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP
return nil
}
func (service AuthImpl) SendForgotPasswordMail(email string) error {
func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
tokenStr, err := service.random.String(32)
if err != nil {
return err
}
user, err := service.db.GetUserByEmail(email)
user, err := service.db.GetUserByEmail(ctx, email)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
return nil
@@ -332,7 +332,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
service.clock.Now(),
service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(token)
err = service.db.InsertToken(ctx, token)
if err != nil {
return types.ErrInternal
}
@@ -348,17 +348,17 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
return nil
}
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
if !isPasswordValid(newPass) {
return ErrInvalidPassword
}
token, err := service.db.GetToken(tokenStr)
token, err := service.db.GetToken(ctx, tokenStr)
if err != nil {
return ErrTokenInvalid
}
err = service.db.DeleteToken(tokenStr)
err = service.db.DeleteToken(ctx, tokenStr)
if err != nil {
return err
}
@@ -368,7 +368,7 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
return ErrTokenInvalid
}
user, err := service.db.GetUser(token.UserId)
user, err := service.db.GetUser(ctx, token.UserId)
if err != nil {
slog.Error("Could not get user from token", "err", err)
return types.ErrInternal
@@ -377,18 +377,18 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
passHash := GetHashPassword(newPass, user.Salt)
user.Password = passHash
err = service.db.UpdateUser(user)
err = service.db.UpdateUser(ctx, user)
if err != nil {
return err
}
sessions, err := service.db.GetSessions(user.Id)
sessions, err := service.db.GetSessions(ctx, user.Id)
if err != nil {
return types.ErrInternal
}
for _, session := range sessions {
err = service.db.DeleteSession(session.Id)
err = service.db.DeleteSession(ctx, session.Id)
if err != nil {
return types.ErrInternal
}
@@ -397,8 +397,8 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
return nil
}
func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool {
token, err := service.db.GetToken(tokenStr)
func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
token, err := service.db.GetToken(ctx, tokenStr)
if err != nil {
return false
}
@@ -412,12 +412,12 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
return true
}
func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session) (string, error) {
if session == nil {
return "", types.ErrInternal
}
tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
if len(tokens) > 0 {
return tokens[0].Token, nil
@@ -435,7 +435,7 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
types.TokenTypeCsrf,
service.clock.Now(),
service.clock.Now().Add(8*time.Hour))
err = service.db.InsertToken(token)
err = service.db.InsertToken(ctx, token)
if err != nil {
return "", types.ErrInternal
}
@@ -445,22 +445,22 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
return tokenStr, nil
}
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error {
if session == nil {
return nil
}
err := service.db.DeleteSession(session.Id)
err := service.db.DeleteSession(ctx, session.Id)
if err != nil {
return types.ErrInternal
}
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
if err != nil {
return types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(token.Token)
err = service.db.DeleteToken(ctx, token.Token)
if err != nil {
return types.ErrInternal
}
@@ -469,13 +469,13 @@ func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
return nil
}
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*types.Session, error) {
sessionId, err := service.random.String(32)
if err != nil {
return nil, types.ErrInternal
}
err = service.db.DeleteOldSessions(userId)
err = service.db.DeleteOldSessions(ctx, userId)
if err != nil {
return nil, types.ErrInternal
}
@@ -485,7 +485,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error)
session := types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session)
err = service.db.InsertSession(ctx, session)
if err != nil {
return nil, types.ErrInternal
}

View File

@@ -1,6 +1,7 @@
package service
import (
"context"
"errors"
"fmt"
"log/slog"
@@ -13,13 +14,13 @@ import (
)
type Transaction interface {
Add(tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
Update(user *types.User, transaction types.Transaction) (*types.Transaction, error)
Get(user *types.User, id string) (*types.Transaction, error)
GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
Delete(user *types.User, id string) error
Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
Update(ctx context.Context, user *types.User, transaction types.Transaction) (*types.Transaction, error)
Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error)
GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
Delete(ctx context.Context, user *types.User, id string) error
RecalculateBalances(user *types.User) error
RecalculateBalances(ctx context.Context, user *types.User) error
}
type TransactionImpl struct {
@@ -36,7 +37,7 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
}
}
func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -45,7 +46,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
ownsTransaction := false
if tx == nil {
ownsTransaction = true
tx, err = s.db.Beginx()
tx, err = s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transaction Add", nil, err)
if err != nil {
return nil, err
@@ -55,12 +56,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
}()
}
transaction, err := s.validateAndEnrichTransaction(tx, nil, user.Id, transactionInput)
transaction, err := s.validateAndEnrichTransaction(ctx, tx, nil, user.Id, transactionInput)
if err != nil {
return nil, err
}
r, err := tx.NamedExec(`
r, err := tx.NamedExecContext(ctx, `
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
party, description, error, created_at, created_by)
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
@@ -71,8 +72,8 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
}
if transaction.Error == nil && transaction.AccountId != nil {
r, err = tx.Exec(`
UPDATE account
r, err = tx.ExecContext(ctx, `
UPDATE actx context.Context,ccount
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction Add", r, err)
@@ -82,7 +83,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
}
if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err = tx.Exec(`
r, err = tx.ExecContext(ctx, `
UPDATE treasure_chest
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
@@ -103,12 +104,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
return transaction, nil
}
func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*types.Transaction, error) {
func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) {
if user == nil {
return nil, ErrUnauthorized
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil {
return nil, err
@@ -118,7 +119,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
}()
transaction := &types.Transaction{}
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -128,7 +129,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
}
if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(`
r, err := tx.ExecContext(ctx, `
UPDATE account
SET current_balance = current_balance - ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
@@ -138,7 +139,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
}
}
if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(`
r, err := tx.ExecContext(ctx, `
UPDATE treasure_chest
SET current_balance = current_balance - ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
@@ -148,13 +149,13 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
}
}
transaction, err = s.validateAndEnrichTransaction(tx, transaction, user.Id, input)
transaction, err = s.validateAndEnrichTransaction(ctx, tx, transaction, user.Id, input)
if err != nil {
return nil, err
}
if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(`
r, err := tx.ExecContext(ctx, `
UPDATE account
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
@@ -164,7 +165,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
}
}
if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(`
r, err := tx.ExecContext(ctx, `
UPDATE treasure_chest
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
@@ -174,7 +175,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
}
}
r, err := tx.NamedExec(`
r, err := tx.NamedExecContext(ctx, `
UPDATE "transaction"
SET
account_id = :account_id,
@@ -202,7 +203,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
return transaction, nil
}
func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, error) {
func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -213,7 +214,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
}
var transaction types.Transaction
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Get", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -225,13 +226,13 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
return &transaction, nil
}
func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
if user == nil {
return nil, ErrUnauthorized
}
transactions := make([]*types.Transaction, 0)
err := s.db.Select(&transactions, `
err := s.db.SelectContext(ctx, &transactions, `
SELECT *
FROM "transaction"
WHERE user_id = ?
@@ -254,7 +255,7 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF
return transactions, nil
}
func (s TransactionImpl) Delete(user *types.User, id string) error {
func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error {
if user == nil {
return ErrUnauthorized
}
@@ -264,7 +265,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transaction Delete", nil, err)
if err != nil {
return nil
@@ -274,14 +275,14 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
}()
var transaction types.Transaction
err = tx.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = tx.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Delete", nil, err)
if err != nil {
return err
}
if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(`
r, err := tx.ExecContext(ctx, `
UPDATE account
SET current_balance = current_balance - ?
WHERE id = ?
@@ -293,7 +294,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
}
if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(`
r, err := tx.ExecContext(ctx, `
UPDATE treasure_chest
SET current_balance = current_balance - ?
WHERE id = ?
@@ -304,7 +305,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
}
}
r, err := tx.Exec("DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id)
r, err := tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil {
return err
@@ -319,12 +320,12 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
return nil
}
func (s TransactionImpl) RecalculateBalances(user *types.User) error {
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error {
if user == nil {
return ErrUnauthorized
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil {
return err
@@ -333,7 +334,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
_ = tx.Rollback()
}()
r, err := tx.Exec(`
r, err := tx.ExecContext(ctx, `
UPDATE account
SET current_balance = 0
WHERE user_id = ?`, user.Id)
@@ -342,7 +343,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
return err
}
r, err = tx.Exec(`
r, err = tx.ExecContext(ctx, `
UPDATE treasure_chest
SET current_balance = 0
WHERE user_id = ?`, user.Id)
@@ -351,7 +352,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
return err
}
rows, err := tx.Queryx(`
rows, err := tx.QueryxContext(ctx, `
SELECT *
FROM "transaction"
WHERE user_id = ?`, user.Id)
@@ -375,7 +376,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
}
s.updateErrors(&transaction)
r, err = tx.Exec(`
r, err = tx.ExecContext(ctx, `
UPDATE "transaction"
SET error = ?
WHERE user_id = ?
@@ -390,7 +391,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
}
if transaction.AccountId != nil {
r, err = tx.Exec(`
r, err = tx.ExecContext(ctx, `
UPDATE account
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
@@ -400,7 +401,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
}
}
if transaction.TreasureChestId != nil {
r, err = tx.Exec(`
r, err = tx.ExecContext(ctx, `
UPDATE treasure_chest
SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
@@ -420,7 +421,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
return nil
}
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) {
func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) {
var (
id uuid.UUID
createdAt time.Time
@@ -449,7 +450,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
}
if input.AccountId != nil {
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId)
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId)
err = db.TransformAndLogDbError("transaction validate", nil, err)
if err != nil {
return nil, err
@@ -462,7 +463,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
if input.TreasureChestId != nil {
var treasureChest types.TreasureChest
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
err = db.TransformAndLogDbError("transaction validate", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {

View File

@@ -1,6 +1,7 @@
package service
import (
"context"
"errors"
"fmt"
"log/slog"
@@ -15,14 +16,14 @@ import (
)
type TransactionRecurring interface {
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
GetAll(user *types.User) ([]*types.TransactionRecurring, error)
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error)
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(user *types.User, id string) error
Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error)
GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error)
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(ctx context.Context, user *types.User, id string) error
GenerateTransactions(user *types.User) error
GenerateTransactions(ctx context.Context, user *types.User) error
}
type TransactionRecurringImpl struct {
@@ -41,7 +42,7 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio
}
}
func (s TransactionRecurringImpl) Add(
func (s TransactionRecurringImpl) Add(ctx context.Context,
user *types.User,
transactionRecurringInput types.TransactionRecurringInput,
) (*types.TransactionRecurring, error) {
@@ -49,7 +50,7 @@ func (s TransactionRecurringImpl) Add(
return nil, ErrUnauthorized
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring Add", nil, err)
if err != nil {
return nil, err
@@ -58,12 +59,12 @@ func (s TransactionRecurringImpl) Add(
_ = tx.Rollback()
}()
transactionRecurring, err := s.validateAndEnrichTransactionRecurring(tx, nil, user.Id, transactionRecurringInput)
transactionRecurring, err := s.validateAndEnrichTransactionRecurring(ctx, tx, nil, user.Id, transactionRecurringInput)
if err != nil {
return nil, err
}
r, err := tx.NamedExec(`
r, err := tx.NamedExecContext(ctx, `
INSERT INTO "transaction_recurring" (id, user_id, interval_months,
next_execution, party, description, account_id, treasure_chest_id, value, created_at, created_by)
VALUES (:id, :user_id, :interval_months,
@@ -83,7 +84,7 @@ func (s TransactionRecurringImpl) Add(
return transactionRecurring, nil
}
func (s TransactionRecurringImpl) Update(
func (s TransactionRecurringImpl) Update(ctx context.Context,
user *types.User,
input types.TransactionRecurringInput,
) (*types.TransactionRecurring, error) {
@@ -96,7 +97,7 @@ func (s TransactionRecurringImpl) Update(
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil {
return nil, err
@@ -106,7 +107,7 @@ func (s TransactionRecurringImpl) Update(
}()
transactionRecurring := &types.TransactionRecurring{}
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -115,12 +116,12 @@ func (s TransactionRecurringImpl) Update(
return nil, types.ErrInternal
}
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(tx, transactionRecurring, user.Id, input)
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
if err != nil {
return nil, err
}
r, err := tx.NamedExec(`
r, err := tx.NamedExecContext(ctx, `
UPDATE transaction_recurring
SET
interval_months = :interval_months,
@@ -148,13 +149,13 @@ func (s TransactionRecurringImpl) Update(
return transactionRecurring, nil
}
func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.TransactionRecurring, error) {
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) {
if user == nil {
return nil, ErrUnauthorized
}
transactionRecurrings := make([]*types.TransactionRecurring, 0)
err := s.db.Select(&transactionRecurrings, `
err := s.db.SelectContext(ctx, &transactionRecurrings, `
SELECT *
FROM transaction_recurring
WHERE user_id = ?
@@ -168,7 +169,7 @@ func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.Transaction
return transactionRecurrings, nil
}
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -179,7 +180,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil {
return nil, err
@@ -189,7 +190,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
}()
var rowCount int
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -199,7 +200,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
}
transactionRecurrings := make([]*types.TransactionRecurring, 0)
err = tx.Select(&transactionRecurrings, `
err = tx.SelectContext(ctx, &transactionRecurrings, `
SELECT *
FROM transaction_recurring
WHERE user_id = ?
@@ -220,7 +221,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
return transactionRecurrings, nil
}
func (s TransactionRecurringImpl) GetAllByTreasureChest(
func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
user *types.User,
treasureChestId string,
) ([]*types.TransactionRecurring, error) {
@@ -234,7 +235,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil {
return nil, err
@@ -244,7 +245,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
}()
var rowCount int
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -254,7 +255,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
}
transactionRecurrings := make([]*types.TransactionRecurring, 0)
err = tx.Select(&transactionRecurrings, `
err = tx.SelectContext(ctx, &transactionRecurrings, `
SELECT *
FROM transaction_recurring
WHERE user_id = ?
@@ -275,7 +276,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
return transactionRecurrings, nil
}
func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error {
if user == nil {
return ErrUnauthorized
}
@@ -285,7 +286,7 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
if err != nil {
return nil
@@ -295,13 +296,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
}()
var transactionRecurring types.TransactionRecurring
err = tx.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = tx.GetContext(ctx, &transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
if err != nil {
return err
}
r, err := tx.Exec("DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id)
r, err := tx.ExecContext(ctx, "DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("transactionRecurring Delete", r, err)
if err != nil {
return err
@@ -316,13 +317,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
return nil
}
func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error {
if user == nil {
return ErrUnauthorized
}
now := s.clock.Now()
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
if err != nil {
return err
@@ -332,7 +333,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
}()
recurringTransactions := make([]*types.TransactionRecurring, 0)
err = tx.Select(&recurringTransactions, `
err = tx.SelectContext(ctx, &recurringTransactions, `
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
user.Id, now)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
@@ -350,13 +351,13 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
Value: transactionRecurring.Value,
}
_, err = s.transaction.Add(tx, user, transaction)
_, err = s.transaction.Add(ctx, tx, user, transaction)
if err != nil {
return err
}
nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0)
r, err := tx.Exec(`UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`,
r, err := tx.ExecContext(ctx, `UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`,
nextExecution, transactionRecurring.Id, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err)
if err != nil {
@@ -373,6 +374,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
}
func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
ctx context.Context,
tx *sqlx.Tx,
oldTransactionRecurring *types.TransactionRecurring,
userId uuid.UUID,
@@ -417,7 +419,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
}
accountUuid = &temp
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
if err != nil {
return nil, err
@@ -438,7 +440,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
}
treasureChestUuid = &temp
var treasureChest types.TreasureChest
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {

View File

@@ -1,6 +1,7 @@
package service
import (
"context"
"errors"
"fmt"
"log/slog"
@@ -13,11 +14,11 @@ import (
)
type TreasureChest interface {
Add(user *types.User, parentId, name string) (*types.TreasureChest, error)
Update(user *types.User, id, parentId, name string) (*types.TreasureChest, error)
Get(user *types.User, id string) (*types.TreasureChest, error)
GetAll(user *types.User) ([]*types.TreasureChest, error)
Delete(user *types.User, id string) error
Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error)
Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error)
Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error)
GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error)
Delete(ctx context.Context, user *types.User, id string) error
}
type TreasureChestImpl struct {
@@ -34,7 +35,7 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
}
}
func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.TreasureChest, error) {
func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -51,7 +52,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
var parentUuid *uuid.UUID
if parentId != "" {
parent, err := s.Get(user, parentId)
parent, err := s.Get(ctx, user, parentId)
if err != nil {
return nil, err
}
@@ -76,7 +77,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
UpdatedBy: nil,
}
r, err := s.db.NamedExec(`
r, err := s.db.NamedExecContext(ctx, `
INSERT INTO treasure_chest (id, parent_id, user_id, name, current_balance, created_at, created_by)
VALUES (:id, :parent_id, :user_id, :name, :current_balance, :created_at, :created_by)`, treasureChest)
err = db.TransformAndLogDbError("treasureChest Insert", r, err)
@@ -87,7 +88,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
return treasureChest, nil
}
func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -101,7 +102,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil {
return nil, err
@@ -111,7 +112,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
}()
treasureChest := &types.TreasureChest{}
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -122,12 +123,12 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
var parentUuid *uuid.UUID
if parentId != "" {
parent, err := s.Get(user, parentId)
parent, err := s.Get(ctx, user, parentId)
if err != nil {
return nil, err
}
var childCount int
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil {
return nil, err
@@ -145,7 +146,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
treasureChest.UpdatedAt = &timestamp
treasureChest.UpdatedBy = &user.Id
r, err := tx.NamedExec(`
r, err := tx.NamedExecContext(ctx, `
UPDATE treasure_chest
SET
parent_id = :parent_id,
@@ -169,7 +170,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
return treasureChest, nil
}
func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChest, error) {
func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) {
if user == nil {
return nil, ErrUnauthorized
}
@@ -180,7 +181,7 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
}
var treasureChest types.TreasureChest
err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("treasureChest Get", nil, err)
if err != nil {
if errors.Is(err, db.ErrNotFound) {
@@ -192,13 +193,13 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
return &treasureChest, nil
}
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) {
func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) {
if user == nil {
return nil, ErrUnauthorized
}
treasureChests := make([]*types.TreasureChest, 0)
err := s.db.Select(&treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id)
err := s.db.SelectContext(ctx, &treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("treasureChest GetAll", nil, err)
if err != nil {
return nil, err
@@ -207,7 +208,7 @@ func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, err
return sortTree(treasureChests), nil
}
func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error {
if user == nil {
return ErrUnauthorized
}
@@ -217,7 +218,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
tx, err := s.db.Beginx()
tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil {
return nil
@@ -227,7 +228,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
}()
childCount := 0
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil {
return err
@@ -238,7 +239,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
}
transactionsCount := 0
err = tx.Get(&transactionsCount,
err = tx.GetContext(ctx, &transactionsCount,
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
@@ -250,7 +251,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
}
recurringCount := 0
err = tx.Get(&recurringCount, `
err = tx.GetContext(ctx, &recurringCount, `
SELECT COUNT(*) FROM transaction_recurring WHERE user_id = ? AND treasure_chest_id = ?`,
user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
@@ -261,7 +262,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest)
}
r, err := tx.Exec(`DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
err = db.TransformAndLogDbError("treasureChest Delete", r, err)
if err != nil {
return err

View File

@@ -29,7 +29,7 @@ func DoRedirect(w http.ResponseWriter, r *http.Request, url string) {
}
}
func WaitMinimumTime[T interface{}](waitTime time.Duration, f func() (T, error)) (T, error) {
func WaitMinimumTime[T any](waitTime time.Duration, f func() (T, error)) (T, error) {
start := time.Now()
result, err := f()
time.Sleep(waitTime - time.Since(start))