feat(observabillity): #153 instrument sqlx
This commit was merged in pull request #160.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log/slog"
|
||||
@@ -13,23 +14,23 @@ import (
|
||||
)
|
||||
|
||||
type Auth interface {
|
||||
InsertUser(user *types.User) error
|
||||
UpdateUser(user *types.User) error
|
||||
GetUserByEmail(email string) (*types.User, error)
|
||||
GetUser(userId uuid.UUID) (*types.User, error)
|
||||
DeleteUser(userId uuid.UUID) error
|
||||
InsertUser(ctx context.Context, user *types.User) error
|
||||
UpdateUser(ctx context.Context, user *types.User) error
|
||||
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
|
||||
GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error)
|
||||
DeleteUser(ctx context.Context, userId uuid.UUID) error
|
||||
|
||||
InsertToken(token *types.Token) error
|
||||
GetToken(token string) (*types.Token, error)
|
||||
GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
|
||||
GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error)
|
||||
DeleteToken(token string) error
|
||||
InsertToken(ctx context.Context, token *types.Token) error
|
||||
GetToken(ctx context.Context, token string) (*types.Token, error)
|
||||
GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
|
||||
GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error)
|
||||
DeleteToken(ctx context.Context, token string) error
|
||||
|
||||
InsertSession(session *types.Session) error
|
||||
GetSession(sessionId string) (*types.Session, error)
|
||||
GetSessions(userId uuid.UUID) ([]*types.Session, error)
|
||||
DeleteSession(sessionId string) error
|
||||
DeleteOldSessions(userId uuid.UUID) error
|
||||
InsertSession(ctx context.Context, session *types.Session) error
|
||||
GetSession(ctx context.Context, sessionId string) (*types.Session, error)
|
||||
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
|
||||
DeleteSession(ctx context.Context, sessionId string) error
|
||||
DeleteOldSessions(ctx context.Context, userId uuid.UUID) error
|
||||
}
|
||||
|
||||
type AuthSqlite struct {
|
||||
@@ -40,8 +41,8 @@ func NewAuthSqlite(db *sqlx.DB) *AuthSqlite {
|
||||
return &AuthSqlite{db: db}
|
||||
}
|
||||
|
||||
func (db AuthSqlite) InsertUser(user *types.User) error {
|
||||
_, err := db.db.Exec(`
|
||||
func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error {
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt)
|
||||
@@ -58,8 +59,8 @@ func (db AuthSqlite) InsertUser(user *types.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) UpdateUser(user *types.User) error {
|
||||
_, err := db.db.Exec(`
|
||||
func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error {
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
UPDATE user
|
||||
SET email_verified = ?, email_verified_at = ?, password = ?
|
||||
WHERE user_id = ?`,
|
||||
@@ -73,7 +74,7 @@ func (db AuthSqlite) UpdateUser(user *types.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
|
||||
func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
||||
var (
|
||||
userId uuid.UUID
|
||||
emailVerified bool
|
||||
@@ -84,7 +85,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
err := db.db.QueryRow(`
|
||||
err := db.db.QueryRowContext(ctx, `
|
||||
SELECT user_id, email_verified, email_verified_at, password, salt, created_at
|
||||
FROM user
|
||||
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
||||
@@ -100,7 +101,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
|
||||
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
|
||||
func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) {
|
||||
var (
|
||||
email string
|
||||
emailVerified bool
|
||||
@@ -111,7 +112,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
err := db.db.QueryRow(`
|
||||
err := db.db.QueryRowContext(ctx, `
|
||||
SELECT email, email_verified, email_verified_at, password, salt, created_at
|
||||
FROM user
|
||||
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
||||
@@ -127,49 +128,49 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
|
||||
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
|
||||
tx, err := db.db.Begin()
|
||||
func (db AuthSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error {
|
||||
tx, err := db.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
slog.Error("Could not start transaction", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM account WHERE user_id = ?", userId)
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
slog.Error("Could not delete accounts", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM token WHERE user_id = ?", userId)
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
slog.Error("Could not delete user tokens", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM session WHERE user_id = ?", userId)
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
slog.Error("Could not delete sessions", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM user WHERE user_id = ?", userId)
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
slog.Error("Could not delete user", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM treasure_chest WHERE user_id = ?", userId)
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
slog.Error("Could not delete user", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
_, err = tx.Exec("DELETE FROM \"transaction\" WHERE user_id = ?", userId)
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
slog.Error("Could not delete user", "err", err)
|
||||
@@ -185,8 +186,8 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) InsertToken(token *types.Token) error {
|
||||
_, err := db.db.Exec(`
|
||||
func (db AuthSqlite) InsertToken(ctx context.Context, token *types.Token) error {
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
|
||||
|
||||
@@ -198,7 +199,7 @@ func (db AuthSqlite) InsertToken(token *types.Token) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
|
||||
func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, error) {
|
||||
var (
|
||||
userId uuid.UUID
|
||||
sessionId string
|
||||
@@ -209,7 +210,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
|
||||
expiresAt time.Time
|
||||
)
|
||||
|
||||
err := db.db.QueryRow(`
|
||||
err := db.db.QueryRowContext(ctx, `
|
||||
SELECT user_id, session_id, type, created_at, expires_at
|
||||
FROM token
|
||||
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
|
||||
@@ -239,8 +240,8 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
|
||||
return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
|
||||
query, err := db.db.Query(`
|
||||
func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
|
||||
query, err := db.db.QueryContext(ctx, `
|
||||
SELECT token, created_at, expires_at
|
||||
FROM token
|
||||
WHERE user_id = ?
|
||||
@@ -254,8 +255,8 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.
|
||||
return getTokensFromQuery(query, userId, "", tokenType)
|
||||
}
|
||||
|
||||
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
|
||||
query, err := db.db.Query(`
|
||||
func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
|
||||
query, err := db.db.QueryContext(ctx, `
|
||||
SELECT token, created_at, expires_at
|
||||
FROM token
|
||||
WHERE session_id = ?
|
||||
@@ -312,8 +313,8 @@ func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tok
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) DeleteToken(token string) error {
|
||||
_, err := db.db.Exec("DELETE FROM token WHERE token = ?", token)
|
||||
func (db AuthSqlite) DeleteToken(ctx context.Context, token string) error {
|
||||
_, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token)
|
||||
if err != nil {
|
||||
slog.Error("Could not delete token", "err", err)
|
||||
return types.ErrInternal
|
||||
@@ -321,8 +322,8 @@ func (db AuthSqlite) DeleteToken(token string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) InsertSession(session *types.Session) error {
|
||||
_, err := db.db.Exec(`
|
||||
func (db AuthSqlite) InsertSession(ctx context.Context, session *types.Session) error {
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
|
||||
|
||||
@@ -334,14 +335,14 @@ func (db AuthSqlite) InsertSession(session *types.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
||||
func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.Session, error) {
|
||||
var (
|
||||
userId uuid.UUID
|
||||
createdAt time.Time
|
||||
expiresAt time.Time
|
||||
)
|
||||
|
||||
err := db.db.QueryRow(`
|
||||
err := db.db.QueryRowContext(ctx, `
|
||||
SELECT user_id, created_at, expires_at
|
||||
FROM session
|
||||
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
|
||||
@@ -354,9 +355,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
||||
return types.NewSession(sessionId, userId, createdAt, expiresAt), nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
|
||||
func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) {
|
||||
var sessions []*types.Session
|
||||
err := db.db.Select(&sessions, `
|
||||
err := db.db.SelectContext(ctx, &sessions, `
|
||||
SELECT *
|
||||
FROM session
|
||||
WHERE user_id = ?`, userId)
|
||||
@@ -368,8 +369,8 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
|
||||
_, err := db.db.Exec(`
|
||||
func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error {
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
DELETE FROM session
|
||||
WHERE expires_at < datetime('now')
|
||||
AND user_id = ?`, userId)
|
||||
@@ -380,9 +381,9 @@ func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) DeleteSession(sessionId string) error {
|
||||
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
|
||||
if sessionId != "" {
|
||||
_, err := db.db.Exec("DELETE FROM session WHERE session_id = ?", sessionId)
|
||||
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
|
||||
if err != nil {
|
||||
slog.Error("Could not delete session", "err", err)
|
||||
return types.ErrInternal
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"spend-sparrow/internal/types"
|
||||
@@ -20,7 +21,7 @@ func (l migrationLogger) Verbose() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func RunMigrations(db *sqlx.DB, pathPrefix string) error {
|
||||
func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error {
|
||||
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
slog.Error("Could not create Migration instance", "err", err)
|
||||
|
||||
@@ -56,7 +56,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
|
||||
}
|
||||
|
||||
// init db
|
||||
err = db.RunMigrations(database, migrationsPrefix)
|
||||
err = db.RunMigrations(ctx, database, migrationsPrefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not run migrations: %w", err)
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func (h AccountImpl) handleAccountPage() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := h.s.GetAll(user)
|
||||
accounts, err := h.s.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -72,7 +72,7 @@ func (h AccountImpl) handleAccountItemComp() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.s.Get(user, id)
|
||||
account, err := h.s.Get(r.Context(), user, id)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -105,13 +105,13 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc {
|
||||
id := r.PathValue("id")
|
||||
name := r.FormValue("name")
|
||||
if id == "new" {
|
||||
account, err = h.s.Add(user, name)
|
||||
account, err = h.s.Add(r.Context(), user, name)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
account, err = h.s.UpdateName(user, id, name)
|
||||
account, err = h.s.UpdateName(r.Context(), user, id, name)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -135,7 +135,7 @@ func (h AccountImpl) handleDeleteAccount() http.HandlerFunc {
|
||||
|
||||
id := r.PathValue("id")
|
||||
|
||||
err := h.s.Delete(user, id)
|
||||
err := h.s.Delete(r.Context(), user, id)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
|
||||
@@ -85,7 +85,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
||||
email := r.FormValue("email")
|
||||
password := r.FormValue("password")
|
||||
|
||||
session, user, err := handler.service.SignIn(session, email, password)
|
||||
session, user, err := handler.service.SignIn(r.Context(), session, email, password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
go handler.service.SendVerificationMail(user.Id, user.Email)
|
||||
go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email)
|
||||
|
||||
_, err := w.Write([]byte("<p class=\"mt-8\">Verification email sent</p>"))
|
||||
if err != nil {
|
||||
@@ -178,7 +178,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
||||
|
||||
token := r.URL.Query().Get("token")
|
||||
|
||||
err := handler.service.VerifyUserEmail(token)
|
||||
err := handler.service.VerifyUserEmail(r.Context(), token)
|
||||
|
||||
isVerified := err == nil
|
||||
comp := auth.VerifyResponseComp(isVerified)
|
||||
@@ -203,13 +203,13 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
|
||||
|
||||
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
|
||||
slog.Info("signing up", "email", email)
|
||||
user, err := handler.service.SignUp(email, password)
|
||||
user, err := handler.service.SignUp(r.Context(), email, password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info("Sending verification email", "to", user.Email)
|
||||
go handler.service.SendVerificationMail(user.Id, user.Email)
|
||||
go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
@@ -239,7 +239,7 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc {
|
||||
session := middleware.GetSession(r)
|
||||
|
||||
if session != nil {
|
||||
err := handler.service.SignOut(session.Id)
|
||||
err := handler.service.SignOut(r.Context(), session.Id)
|
||||
if err != nil {
|
||||
http.Error(w, "An error occurred", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -288,7 +288,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
|
||||
|
||||
password := r.FormValue("password")
|
||||
|
||||
err := handler.service.DeleteAccount(user, password)
|
||||
err := handler.service.DeleteAccount(r.Context(), user, password)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrInvalidCredentials) {
|
||||
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
|
||||
@@ -334,7 +334,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
|
||||
currPass := r.FormValue("current-password")
|
||||
newPass := r.FormValue("new-password")
|
||||
|
||||
err := handler.service.ChangePassword(user, session.Id, currPass, newPass)
|
||||
err := handler.service.ChangePassword(r.Context(), user, session.Id, currPass, newPass)
|
||||
if err != nil {
|
||||
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
@@ -370,7 +370,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
|
||||
}
|
||||
|
||||
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
|
||||
err := handler.service.SendForgotPasswordMail(email)
|
||||
err := handler.service.SendForgotPasswordMail(r.Context(), email)
|
||||
return nil, err
|
||||
})
|
||||
|
||||
@@ -396,7 +396,7 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
|
||||
token := pageUrl.Query().Get("token")
|
||||
newPass := r.FormValue("new-password")
|
||||
|
||||
err = handler.service.ForgotPassword(token, newPass)
|
||||
err = handler.service.ForgotPassword(r.Context(), token, newPass)
|
||||
if err != nil {
|
||||
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
|
||||
} else {
|
||||
|
||||
@@ -17,13 +17,13 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sessionId := getSessionID(r)
|
||||
session, user, _ := service.SignInSession(sessionId)
|
||||
session, user, _ := service.SignInSession(r.Context(), sessionId)
|
||||
|
||||
var err error
|
||||
// Always sign in anonymous
|
||||
// This way, we can always generate csrf tokens
|
||||
if session == nil {
|
||||
session, err = service.SignInAnonymous()
|
||||
session, err = service.SignInAnonymous(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -4,30 +4,26 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"spend-sparrow/internal/service"
|
||||
"spend-sparrow/internal/types"
|
||||
"spend-sparrow/internal/utils"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type csrfResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
auth service.Auth
|
||||
session *types.Session
|
||||
csrfToken string
|
||||
}
|
||||
|
||||
func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter {
|
||||
func newCsrfResponseWriter(w http.ResponseWriter, csrfToken string) *csrfResponseWriter {
|
||||
return &csrfResponseWriter{
|
||||
ResponseWriter: w,
|
||||
auth: auth,
|
||||
session: session,
|
||||
csrfToken: csrfToken,
|
||||
}
|
||||
}
|
||||
|
||||
func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
|
||||
dataStr := string(data)
|
||||
csrfToken, err := rr.auth.GetCsrfToken(rr.session)
|
||||
if err == nil {
|
||||
dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken)
|
||||
if rr.csrfToken != "" {
|
||||
dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", rr.csrfToken)
|
||||
}
|
||||
|
||||
return rr.ResponseWriter.Write([]byte(dataStr))
|
||||
@@ -37,6 +33,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session := GetSession(r)
|
||||
ctx := r.Context()
|
||||
|
||||
if r.Method == http.MethodPost ||
|
||||
r.Method == http.MethodPut ||
|
||||
@@ -44,7 +41,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
||||
r.Method == http.MethodPatch {
|
||||
csrfToken := r.Header.Get("Csrf-Token")
|
||||
|
||||
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
|
||||
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(ctx, csrfToken, session.Id) {
|
||||
slog.Info("CSRF-Token not correct", "token", csrfToken)
|
||||
if r.Header.Get("Hx-Request") == "true" {
|
||||
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
|
||||
@@ -55,7 +52,17 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
||||
}
|
||||
}
|
||||
|
||||
responseWriter := newCsrfResponseWriter(w, auth, session)
|
||||
token, err := auth.GetCsrfToken(ctx, session)
|
||||
if err != nil {
|
||||
if r.Header.Get("Hx-Request") == "true" {
|
||||
utils.TriggerToastWithStatus(w, r, "error", "Could not generate CSRF Token", http.StatusBadRequest)
|
||||
} else {
|
||||
http.Error(w, "Could not generate CSRF Token", http.StatusBadRequest)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
responseWriter := newCsrfResponseWriter(w, token)
|
||||
next.ServeHTTP(responseWriter, r)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ func GenerateRecurringTransactions(transactionRecurring service.TransactionRecur
|
||||
return
|
||||
}
|
||||
|
||||
_ = transactionRecurring.GenerateTransactions(user)
|
||||
_ = transactionRecurring.GenerateTransactions(r.Context(), user)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
@@ -65,19 +65,19 @@ func (h TransactionImpl) handleTransactionPage() http.HandlerFunc {
|
||||
Error: r.URL.Query().Get("error"),
|
||||
}
|
||||
|
||||
transactions, err := h.s.GetAll(user, filter)
|
||||
transactions, err := h.s.GetAll(r.Context(), user, filter)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := h.account.GetAll(user)
|
||||
accounts, err := h.account.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
treasureChests, err := h.treasureChest.GetAll(user)
|
||||
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -105,13 +105,13 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
accounts, err := h.account.GetAll(user)
|
||||
accounts, err := h.account.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
treasureChests, err := h.treasureChest.GetAll(user)
|
||||
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -124,7 +124,7 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
transaction, err := h.s.Get(user, id)
|
||||
transaction, err := h.s.Get(r.Context(), user, id)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -212,26 +212,26 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
||||
|
||||
var transaction *types.Transaction
|
||||
if idStr == "new" {
|
||||
transaction, err = h.s.Add(nil, user, input)
|
||||
transaction, err = h.s.Add(r.Context(), nil, user, input)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
transaction, err = h.s.Update(user, input)
|
||||
transaction, err = h.s.Update(r.Context(), user, input)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
accounts, err := h.account.GetAll(user)
|
||||
accounts, err := h.account.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
treasureChests, err := h.treasureChest.GetAll(user)
|
||||
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -253,7 +253,7 @@ func (h TransactionImpl) handleRecalculate() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.s.RecalculateBalances(user)
|
||||
err := h.s.RecalculateBalances(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -275,7 +275,7 @@ func (h TransactionImpl) handleDeleteTransaction() http.HandlerFunc {
|
||||
|
||||
id := r.PathValue("id")
|
||||
|
||||
err := h.s.Delete(user, id)
|
||||
err := h.s.Delete(r.Context(), user, id)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
|
||||
@@ -70,13 +70,13 @@ func (h TransactionRecurringImpl) handleUpdateTransactionRecurring() http.Handle
|
||||
}
|
||||
|
||||
if input.Id == "new" {
|
||||
_, err := h.s.Add(user, input)
|
||||
_, err := h.s.Add(r.Context(), user, input)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_, err := h.s.Update(user, input)
|
||||
_, err := h.s.Update(r.Context(), user, input)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -101,7 +101,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
|
||||
accountId := r.URL.Query().Get("account-id")
|
||||
treasureChestId := r.URL.Query().Get("treasure-chest-id")
|
||||
|
||||
err := h.s.Delete(user, id)
|
||||
err := h.s.Delete(r.Context(), user, id)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -118,13 +118,13 @@ func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Req
|
||||
utils.TriggerToastWithStatus(w, r, "error", "Please select an account or treasure chest", http.StatusBadRequest)
|
||||
}
|
||||
if accountId != "" {
|
||||
transactionsRecurring, err = h.s.GetAllByAccount(user, accountId)
|
||||
transactionsRecurring, err = h.s.GetAllByAccount(r.Context(), user, accountId)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
transactionsRecurring, err = h.s.GetAllByTreasureChest(user, treasureChestId)
|
||||
transactionsRecurring, err = h.s.GetAllByTreasureChest(r.Context(), user, treasureChestId)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
|
||||
@@ -48,13 +48,13 @@ func (h TreasureChestImpl) handleTreasureChestPage() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
treasureChests, err := h.s.GetAll(user)
|
||||
treasureChests, err := h.s.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
transactionsRecurring, err := h.transactionRecurring.GetAll(user)
|
||||
transactionsRecurring, err := h.transactionRecurring.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -77,7 +77,7 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
treasureChests, err := h.s.GetAll(user)
|
||||
treasureChests, err := h.s.GetAll(r.Context(), user)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -90,13 +90,13 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
treasureChest, err := h.s.Get(user, id)
|
||||
treasureChest, err := h.s.Get(r.Context(), user, id)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String())
|
||||
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String())
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -132,20 +132,20 @@ func (h TreasureChestImpl) handleUpdateTreasureChest() http.HandlerFunc {
|
||||
parentId := r.FormValue("parent-id")
|
||||
name := r.FormValue("name")
|
||||
if id == "new" {
|
||||
treasureChest, err = h.s.Add(user, parentId, name)
|
||||
treasureChest, err = h.s.Add(r.Context(), user, parentId, name)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
treasureChest, err = h.s.Update(user, id, parentId, name)
|
||||
treasureChest, err = h.s.Update(r.Context(), user, id, parentId, name)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String())
|
||||
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String())
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
@@ -171,7 +171,7 @@ func (h TreasureChestImpl) handleDeleteTreasureChest() http.HandlerFunc {
|
||||
|
||||
id := r.PathValue("id")
|
||||
|
||||
err := h.s.Delete(user, id)
|
||||
err := h.s.Delete(r.Context(), user, id)
|
||||
if err != nil {
|
||||
handleError(w, r, err)
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -12,11 +13,11 @@ import (
|
||||
)
|
||||
|
||||
type Account interface {
|
||||
Add(user *types.User, name string) (*types.Account, error)
|
||||
UpdateName(user *types.User, id string, name string) (*types.Account, error)
|
||||
Get(user *types.User, id string) (*types.Account, error)
|
||||
GetAll(user *types.User) ([]*types.Account, error)
|
||||
Delete(user *types.User, id string) error
|
||||
Add(ctx context.Context, user *types.User, name string) (*types.Account, error)
|
||||
UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error)
|
||||
Get(ctx context.Context, user *types.User, id string) (*types.Account, error)
|
||||
GetAll(ctx context.Context, user *types.User) ([]*types.Account, error)
|
||||
Delete(ctx context.Context, user *types.User, id string) error
|
||||
}
|
||||
|
||||
type AccountImpl struct {
|
||||
@@ -33,7 +34,7 @@ func NewAccount(db *sqlx.DB, random Random, clock Clock) Account {
|
||||
}
|
||||
}
|
||||
|
||||
func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) {
|
||||
func (s AccountImpl) Add(ctx context.Context, user *types.User, name string) (*types.Account, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -64,7 +65,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
|
||||
UpdatedBy: nil,
|
||||
}
|
||||
|
||||
r, err := s.db.NamedExec(`
|
||||
r, err := s.db.NamedExecContext(ctx, `
|
||||
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
|
||||
VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account)
|
||||
err = db.TransformAndLogDbError("account Insert", r, err)
|
||||
@@ -75,7 +76,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) {
|
||||
func (s AccountImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -89,7 +90,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("account Update", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -99,7 +100,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
||||
}()
|
||||
|
||||
var account types.Account
|
||||
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("account Update", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
@@ -113,7 +114,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
||||
account.UpdatedAt = ×tamp
|
||||
account.UpdatedBy = &user.Id
|
||||
|
||||
r, err := tx.NamedExec(`
|
||||
r, err := tx.NamedExecContext(ctx, `
|
||||
UPDATE account
|
||||
SET
|
||||
name = :name,
|
||||
@@ -135,7 +136,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
||||
func (s AccountImpl) Get(ctx context.Context, user *types.User, id string) (*types.Account, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -146,7 +147,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
||||
}
|
||||
|
||||
var account types.Account
|
||||
err = s.db.Get(&account, `
|
||||
err = s.db.GetContext(ctx, &account, `
|
||||
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("account Get", nil, err)
|
||||
if err != nil {
|
||||
@@ -157,13 +158,13 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
||||
func (s AccountImpl) GetAll(ctx context.Context, user *types.User) ([]*types.Account, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
accounts := make([]*types.Account, 0)
|
||||
err := s.db.Select(&accounts, `
|
||||
err := s.db.SelectContext(ctx, &accounts, `
|
||||
SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id)
|
||||
err = db.TransformAndLogDbError("account GetAll", nil, err)
|
||||
if err != nil {
|
||||
@@ -173,7 +174,7 @@ func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func (s AccountImpl) Delete(user *types.User, id string) error {
|
||||
func (s AccountImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
@@ -183,7 +184,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
|
||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("account Delete", nil, err)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -193,7 +194,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
|
||||
}()
|
||||
|
||||
transactionsCount := 0
|
||||
err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid)
|
||||
err = tx.GetContext(ctx, &transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("account Delete", nil, err)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -202,7 +203,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
|
||||
return fmt.Errorf("account has transactions, cannot delete: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
res, err := tx.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
|
||||
res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
|
||||
err = db.TransformAndLogDbError("account Delete", res, err)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -26,24 +26,24 @@ var (
|
||||
)
|
||||
|
||||
type Auth interface {
|
||||
SignUp(email string, password string) (*types.User, error)
|
||||
SendVerificationMail(userId uuid.UUID, email string)
|
||||
VerifyUserEmail(token string) error
|
||||
SignUp(ctx context.Context, email string, password string) (*types.User, error)
|
||||
SendVerificationMail(ctx context.Context, userId uuid.UUID, email string)
|
||||
VerifyUserEmail(ctx context.Context, token string) error
|
||||
|
||||
SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error)
|
||||
SignInSession(sessionId string) (*types.Session, *types.User, error)
|
||||
SignInAnonymous() (*types.Session, error)
|
||||
SignOut(sessionId string) error
|
||||
SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error)
|
||||
SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error)
|
||||
SignInAnonymous(ctx context.Context) (*types.Session, error)
|
||||
SignOut(ctx context.Context, sessionId string) error
|
||||
|
||||
DeleteAccount(user *types.User, currPass string) error
|
||||
DeleteAccount(ctx context.Context, user *types.User, currPass string) error
|
||||
|
||||
ChangePassword(user *types.User, sessionId string, currPass, newPass string) error
|
||||
ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error
|
||||
|
||||
SendForgotPasswordMail(email string) error
|
||||
ForgotPassword(token string, newPass string) error
|
||||
SendForgotPasswordMail(ctx context.Context, email string) error
|
||||
ForgotPassword(ctx context.Context, token string, newPass string) error
|
||||
|
||||
IsCsrfTokenValid(tokenStr string, sessionId string) bool
|
||||
GetCsrfToken(session *types.Session) (string, error)
|
||||
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
|
||||
GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
|
||||
}
|
||||
|
||||
type AuthImpl struct {
|
||||
@@ -64,8 +64,8 @@ func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *
|
||||
}
|
||||
}
|
||||
|
||||
func (service AuthImpl) SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) {
|
||||
user, err := service.db.GetUserByEmail(email)
|
||||
func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) {
|
||||
user, err := service.db.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, nil, ErrInvalidCredentials
|
||||
@@ -80,12 +80,12 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
|
||||
return nil, nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
err = service.cleanUpSessionWithTokens(session)
|
||||
err = service.cleanUpSessionWithTokens(ctx, session)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
|
||||
session, err = service.createSession(user.Id)
|
||||
session, err = service.createSession(ctx, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
@@ -93,17 +93,17 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
|
||||
return session, user, nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
|
||||
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
|
||||
if sessionId == "" {
|
||||
return nil, nil, ErrSessionIdInvalid
|
||||
}
|
||||
|
||||
session, err := service.db.GetSession(sessionId)
|
||||
session, err := service.db.GetSession(ctx, sessionId)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
if session.ExpiresAt.Before(service.clock.Now()) {
|
||||
_ = service.db.DeleteSession(sessionId)
|
||||
_ = service.db.DeleteSession(ctx, sessionId)
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
|
||||
return session, nil, nil
|
||||
}
|
||||
|
||||
user, err := service.db.GetUser(session.UserId)
|
||||
user, err := service.db.GetUser(ctx, session.UserId)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
@@ -119,8 +119,8 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
|
||||
return session, user, nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
|
||||
session, err := service.createSession(uuid.Nil)
|
||||
func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, error) {
|
||||
session, err := service.createSession(ctx, uuid.Nil)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
@@ -130,7 +130,7 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
|
||||
func (service AuthImpl) SignUp(ctx context.Context, email string, password string) (*types.User, error) {
|
||||
_, err := mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidEmail
|
||||
@@ -154,7 +154,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
|
||||
|
||||
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
|
||||
|
||||
err = service.db.InsertUser(user)
|
||||
err = service.db.InsertUser(ctx, user)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrAlreadyExists) {
|
||||
return nil, ErrAccountExists
|
||||
@@ -166,8 +166,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
||||
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
|
||||
func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
|
||||
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify)
|
||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
||||
service.clock.Now(),
|
||||
service.clock.Now().Add(24*time.Hour))
|
||||
|
||||
err = service.db.InsertToken(token)
|
||||
err = service.db.InsertToken(ctx, token)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -208,17 +208,17 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
||||
service.mail.SendMail(email, "Welcome to spend-sparrow", w.String())
|
||||
}
|
||||
|
||||
func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
|
||||
func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
|
||||
if tokenStr == "" {
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
token, err := service.db.GetToken(tokenStr)
|
||||
token, err := service.db.GetToken(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
user, err := service.db.GetUser(token.UserId)
|
||||
user, err := service.db.GetUser(ctx, token.UserId)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
@@ -236,21 +236,21 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
|
||||
user.EmailVerified = true
|
||||
user.EmailVerifiedAt = &now
|
||||
|
||||
err = service.db.UpdateUser(user)
|
||||
err = service.db.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
_ = service.db.DeleteToken(token.Token)
|
||||
_ = service.db.DeleteToken(ctx, token.Token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) SignOut(sessionId string) error {
|
||||
return service.db.DeleteSession(sessionId)
|
||||
func (service AuthImpl) SignOut(ctx context.Context, sessionId string) error {
|
||||
return service.db.DeleteSession(ctx, sessionId)
|
||||
}
|
||||
|
||||
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
||||
userDb, err := service.db.GetUser(user.Id)
|
||||
func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, currPass string) error {
|
||||
userDb, err := service.db.GetUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
@@ -260,7 +260,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
||||
return ErrInvalidCredentials
|
||||
}
|
||||
|
||||
err = service.db.DeleteUser(user.Id)
|
||||
err = service.db.DeleteUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -270,7 +270,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
|
||||
func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error {
|
||||
if !isPasswordValid(newPass) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
@@ -288,18 +288,18 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP
|
||||
newHash := GetHashPassword(newPass, user.Salt)
|
||||
user.Password = newHash
|
||||
|
||||
err := service.db.UpdateUser(user)
|
||||
err := service.db.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessions, err := service.db.GetSessions(user.Id)
|
||||
sessions, err := service.db.GetSessions(ctx, user.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
for _, s := range sessions {
|
||||
if s.Id != sessionId {
|
||||
err = service.db.DeleteSession(s.Id)
|
||||
err = service.db.DeleteSession(ctx, s.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
@@ -309,13 +309,13 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) SendForgotPasswordMail(email string) error {
|
||||
func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
|
||||
tokenStr, err := service.random.String(32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := service.db.GetUserByEmail(email)
|
||||
user, err := service.db.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil
|
||||
@@ -332,7 +332,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
|
||||
service.clock.Now(),
|
||||
service.clock.Now().Add(15*time.Minute))
|
||||
|
||||
err = service.db.InsertToken(token)
|
||||
err = service.db.InsertToken(ctx, token)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
@@ -348,17 +348,17 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
||||
func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
|
||||
if !isPasswordValid(newPass) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
token, err := service.db.GetToken(tokenStr)
|
||||
token, err := service.db.GetToken(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return ErrTokenInvalid
|
||||
}
|
||||
|
||||
err = service.db.DeleteToken(tokenStr)
|
||||
err = service.db.DeleteToken(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -368,7 +368,7 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
||||
return ErrTokenInvalid
|
||||
}
|
||||
|
||||
user, err := service.db.GetUser(token.UserId)
|
||||
user, err := service.db.GetUser(ctx, token.UserId)
|
||||
if err != nil {
|
||||
slog.Error("Could not get user from token", "err", err)
|
||||
return types.ErrInternal
|
||||
@@ -377,18 +377,18 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
||||
passHash := GetHashPassword(newPass, user.Salt)
|
||||
|
||||
user.Password = passHash
|
||||
err = service.db.UpdateUser(user)
|
||||
err = service.db.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessions, err := service.db.GetSessions(user.Id)
|
||||
sessions, err := service.db.GetSessions(ctx, user.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
err = service.db.DeleteSession(session.Id)
|
||||
err = service.db.DeleteSession(ctx, session.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
@@ -397,8 +397,8 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool {
|
||||
token, err := service.db.GetToken(tokenStr)
|
||||
func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
|
||||
token, err := service.db.GetToken(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -412,12 +412,12 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
|
||||
return true
|
||||
}
|
||||
|
||||
func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
|
||||
func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session) (string, error) {
|
||||
if session == nil {
|
||||
return "", types.ErrInternal
|
||||
}
|
||||
|
||||
tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
|
||||
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||
|
||||
if len(tokens) > 0 {
|
||||
return tokens[0].Token, nil
|
||||
@@ -435,7 +435,7 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
|
||||
types.TokenTypeCsrf,
|
||||
service.clock.Now(),
|
||||
service.clock.Now().Add(8*time.Hour))
|
||||
err = service.db.InsertToken(token)
|
||||
err = service.db.InsertToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", types.ErrInternal
|
||||
}
|
||||
@@ -445,22 +445,22 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
|
||||
return tokenStr, nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
|
||||
func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error {
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := service.db.DeleteSession(session.Id)
|
||||
err := service.db.DeleteSession(ctx, session.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
|
||||
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
for _, token := range tokens {
|
||||
err = service.db.DeleteToken(token.Token)
|
||||
err = service.db.DeleteToken(ctx, token.Token)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
@@ -469,13 +469,13 @@ func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
|
||||
func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*types.Session, error) {
|
||||
sessionId, err := service.random.String(32)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
err = service.db.DeleteOldSessions(userId)
|
||||
err = service.db.DeleteOldSessions(ctx, userId)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
@@ -485,7 +485,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error)
|
||||
|
||||
session := types.NewSession(sessionId, userId, createAt, expiresAt)
|
||||
|
||||
err = service.db.InsertSession(session)
|
||||
err = service.db.InsertSession(ctx, session)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -13,13 +14,13 @@ import (
|
||||
)
|
||||
|
||||
type Transaction interface {
|
||||
Add(tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||
Update(user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||
Get(user *types.User, id string) (*types.Transaction, error)
|
||||
GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
|
||||
Delete(user *types.User, id string) error
|
||||
Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||
Update(ctx context.Context, user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||
Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error)
|
||||
GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
|
||||
Delete(ctx context.Context, user *types.User, id string) error
|
||||
|
||||
RecalculateBalances(user *types.User) error
|
||||
RecalculateBalances(ctx context.Context, user *types.User) error
|
||||
}
|
||||
|
||||
type TransactionImpl struct {
|
||||
@@ -36,7 +37,7 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
|
||||
}
|
||||
}
|
||||
|
||||
func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
|
||||
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -45,7 +46,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
||||
ownsTransaction := false
|
||||
if tx == nil {
|
||||
ownsTransaction = true
|
||||
tx, err = s.db.Beginx()
|
||||
tx, err = s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transaction Add", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -55,12 +56,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
||||
}()
|
||||
}
|
||||
|
||||
transaction, err := s.validateAndEnrichTransaction(tx, nil, user.Id, transactionInput)
|
||||
transaction, err := s.validateAndEnrichTransaction(ctx, tx, nil, user.Id, transactionInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := tx.NamedExec(`
|
||||
r, err := tx.NamedExecContext(ctx, `
|
||||
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
|
||||
party, description, error, created_at, created_by)
|
||||
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
|
||||
@@ -71,8 +72,8 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
||||
}
|
||||
|
||||
if transaction.Error == nil && transaction.AccountId != nil {
|
||||
r, err = tx.Exec(`
|
||||
UPDATE account
|
||||
r, err = tx.ExecContext(ctx, `
|
||||
UPDATE actx context.Context,ccount
|
||||
SET current_balance = current_balance + ?
|
||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||
err = db.TransformAndLogDbError("transaction Add", r, err)
|
||||
@@ -82,7 +83,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
||||
}
|
||||
|
||||
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
||||
r, err = tx.Exec(`
|
||||
r, err = tx.ExecContext(ctx, `
|
||||
UPDATE treasure_chest
|
||||
SET current_balance = current_balance + ?
|
||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||
@@ -103,12 +104,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
||||
return transaction, nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*types.Transaction, error) {
|
||||
func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transaction Update", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -118,7 +119,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
||||
}()
|
||||
|
||||
transaction := &types.Transaction{}
|
||||
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
|
||||
err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
|
||||
err = db.TransformAndLogDbError("transaction Update", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
@@ -128,7 +129,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
||||
}
|
||||
|
||||
if transaction.Error == nil && transaction.AccountId != nil {
|
||||
r, err := tx.Exec(`
|
||||
r, err := tx.ExecContext(ctx, `
|
||||
UPDATE account
|
||||
SET current_balance = current_balance - ?
|
||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||
@@ -138,7 +139,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
||||
}
|
||||
}
|
||||
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
||||
r, err := tx.Exec(`
|
||||
r, err := tx.ExecContext(ctx, `
|
||||
UPDATE treasure_chest
|
||||
SET current_balance = current_balance - ?
|
||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||
@@ -148,13 +149,13 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
||||
}
|
||||
}
|
||||
|
||||
transaction, err = s.validateAndEnrichTransaction(tx, transaction, user.Id, input)
|
||||
transaction, err = s.validateAndEnrichTransaction(ctx, tx, transaction, user.Id, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if transaction.Error == nil && transaction.AccountId != nil {
|
||||
r, err := tx.Exec(`
|
||||
r, err := tx.ExecContext(ctx, `
|
||||
UPDATE account
|
||||
SET current_balance = current_balance + ?
|
||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||
@@ -164,7 +165,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
||||
}
|
||||
}
|
||||
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
||||
r, err := tx.Exec(`
|
||||
r, err := tx.ExecContext(ctx, `
|
||||
UPDATE treasure_chest
|
||||
SET current_balance = current_balance + ?
|
||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||
@@ -174,7 +175,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
||||
}
|
||||
}
|
||||
|
||||
r, err := tx.NamedExec(`
|
||||
r, err := tx.NamedExecContext(ctx, `
|
||||
UPDATE "transaction"
|
||||
SET
|
||||
account_id = :account_id,
|
||||
@@ -202,7 +203,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
||||
return transaction, nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, error) {
|
||||
func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -213,7 +214,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
|
||||
}
|
||||
|
||||
var transaction types.Transaction
|
||||
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("transaction Get", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
@@ -225,13 +226,13 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
|
||||
return &transaction, nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
|
||||
func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
transactions := make([]*types.Transaction, 0)
|
||||
err := s.db.Select(&transactions, `
|
||||
err := s.db.SelectContext(ctx, &transactions, `
|
||||
SELECT *
|
||||
FROM "transaction"
|
||||
WHERE user_id = ?
|
||||
@@ -254,7 +255,7 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) Delete(user *types.User, id string) error {
|
||||
func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
@@ -264,7 +265,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transaction Delete", nil, err)
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -274,14 +275,14 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
||||
}()
|
||||
|
||||
var transaction types.Transaction
|
||||
err = tx.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = tx.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("transaction Delete", nil, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if transaction.Error == nil && transaction.AccountId != nil {
|
||||
r, err := tx.Exec(`
|
||||
r, err := tx.ExecContext(ctx, `
|
||||
UPDATE account
|
||||
SET current_balance = current_balance - ?
|
||||
WHERE id = ?
|
||||
@@ -293,7 +294,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
||||
}
|
||||
|
||||
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
||||
r, err := tx.Exec(`
|
||||
r, err := tx.ExecContext(ctx, `
|
||||
UPDATE treasure_chest
|
||||
SET current_balance = current_balance - ?
|
||||
WHERE id = ?
|
||||
@@ -304,7 +305,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
||||
}
|
||||
}
|
||||
|
||||
r, err := tx.Exec("DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id)
|
||||
r, err := tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id)
|
||||
err = db.TransformAndLogDbError("transaction Delete", r, err)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -319,12 +320,12 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
||||
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -333,7 +334,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
r, err := tx.Exec(`
|
||||
r, err := tx.ExecContext(ctx, `
|
||||
UPDATE account
|
||||
SET current_balance = 0
|
||||
WHERE user_id = ?`, user.Id)
|
||||
@@ -342,7 +343,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
r, err = tx.Exec(`
|
||||
r, err = tx.ExecContext(ctx, `
|
||||
UPDATE treasure_chest
|
||||
SET current_balance = 0
|
||||
WHERE user_id = ?`, user.Id)
|
||||
@@ -351,7 +352,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := tx.Queryx(`
|
||||
rows, err := tx.QueryxContext(ctx, `
|
||||
SELECT *
|
||||
FROM "transaction"
|
||||
WHERE user_id = ?`, user.Id)
|
||||
@@ -375,7 +376,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
||||
}
|
||||
|
||||
s.updateErrors(&transaction)
|
||||
r, err = tx.Exec(`
|
||||
r, err = tx.ExecContext(ctx, `
|
||||
UPDATE "transaction"
|
||||
SET error = ?
|
||||
WHERE user_id = ?
|
||||
@@ -390,7 +391,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
||||
}
|
||||
|
||||
if transaction.AccountId != nil {
|
||||
r, err = tx.Exec(`
|
||||
r, err = tx.ExecContext(ctx, `
|
||||
UPDATE account
|
||||
SET current_balance = current_balance + ?
|
||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||
@@ -400,7 +401,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
||||
}
|
||||
}
|
||||
if transaction.TreasureChestId != nil {
|
||||
r, err = tx.Exec(`
|
||||
r, err = tx.ExecContext(ctx, `
|
||||
UPDATE treasure_chest
|
||||
SET current_balance = current_balance + ?
|
||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||
@@ -420,7 +421,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) {
|
||||
func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) {
|
||||
var (
|
||||
id uuid.UUID
|
||||
createdAt time.Time
|
||||
@@ -449,7 +450,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
||||
}
|
||||
|
||||
if input.AccountId != nil {
|
||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId)
|
||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId)
|
||||
err = db.TransformAndLogDbError("transaction validate", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -462,7 +463,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
||||
|
||||
if input.TreasureChestId != nil {
|
||||
var treasureChest types.TreasureChest
|
||||
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
|
||||
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
|
||||
err = db.TransformAndLogDbError("transaction validate", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -15,14 +16,14 @@ import (
|
||||
)
|
||||
|
||||
type TransactionRecurring interface {
|
||||
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||
GetAll(user *types.User) ([]*types.TransactionRecurring, error)
|
||||
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error)
|
||||
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||
Delete(user *types.User, id string) error
|
||||
Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||
Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||
GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error)
|
||||
GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error)
|
||||
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||
Delete(ctx context.Context, user *types.User, id string) error
|
||||
|
||||
GenerateTransactions(user *types.User) error
|
||||
GenerateTransactions(ctx context.Context, user *types.User) error
|
||||
}
|
||||
|
||||
type TransactionRecurringImpl struct {
|
||||
@@ -41,7 +42,7 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio
|
||||
}
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) Add(
|
||||
func (s TransactionRecurringImpl) Add(ctx context.Context,
|
||||
user *types.User,
|
||||
transactionRecurringInput types.TransactionRecurringInput,
|
||||
) (*types.TransactionRecurring, error) {
|
||||
@@ -49,7 +50,7 @@ func (s TransactionRecurringImpl) Add(
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transactionRecurring Add", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -58,12 +59,12 @@ func (s TransactionRecurringImpl) Add(
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
|
||||
transactionRecurring, err := s.validateAndEnrichTransactionRecurring(tx, nil, user.Id, transactionRecurringInput)
|
||||
transactionRecurring, err := s.validateAndEnrichTransactionRecurring(ctx, tx, nil, user.Id, transactionRecurringInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := tx.NamedExec(`
|
||||
r, err := tx.NamedExecContext(ctx, `
|
||||
INSERT INTO "transaction_recurring" (id, user_id, interval_months,
|
||||
next_execution, party, description, account_id, treasure_chest_id, value, created_at, created_by)
|
||||
VALUES (:id, :user_id, :interval_months,
|
||||
@@ -83,7 +84,7 @@ func (s TransactionRecurringImpl) Add(
|
||||
return transactionRecurring, nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) Update(
|
||||
func (s TransactionRecurringImpl) Update(ctx context.Context,
|
||||
user *types.User,
|
||||
input types.TransactionRecurringInput,
|
||||
) (*types.TransactionRecurring, error) {
|
||||
@@ -96,7 +97,7 @@ func (s TransactionRecurringImpl) Update(
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -106,7 +107,7 @@ func (s TransactionRecurringImpl) Update(
|
||||
}()
|
||||
|
||||
transactionRecurring := &types.TransactionRecurring{}
|
||||
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
@@ -115,12 +116,12 @@ func (s TransactionRecurringImpl) Update(
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(tx, transactionRecurring, user.Id, input)
|
||||
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err := tx.NamedExec(`
|
||||
r, err := tx.NamedExecContext(ctx, `
|
||||
UPDATE transaction_recurring
|
||||
SET
|
||||
interval_months = :interval_months,
|
||||
@@ -148,13 +149,13 @@ func (s TransactionRecurringImpl) Update(
|
||||
return transactionRecurring, nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.TransactionRecurring, error) {
|
||||
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||
err := s.db.Select(&transactionRecurrings, `
|
||||
err := s.db.SelectContext(ctx, &transactionRecurrings, `
|
||||
SELECT *
|
||||
FROM transaction_recurring
|
||||
WHERE user_id = ?
|
||||
@@ -168,7 +169,7 @@ func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.Transaction
|
||||
return transactionRecurrings, nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
||||
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -179,7 +180,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
||||
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -189,7 +190,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
||||
}()
|
||||
|
||||
var rowCount int
|
||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
@@ -199,7 +200,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
||||
}
|
||||
|
||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||
err = tx.Select(&transactionRecurrings, `
|
||||
err = tx.SelectContext(ctx, &transactionRecurrings, `
|
||||
SELECT *
|
||||
FROM transaction_recurring
|
||||
WHERE user_id = ?
|
||||
@@ -220,7 +221,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
||||
return transactionRecurrings, nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
||||
func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
||||
user *types.User,
|
||||
treasureChestId string,
|
||||
) ([]*types.TransactionRecurring, error) {
|
||||
@@ -234,7 +235,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
||||
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -244,7 +245,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
||||
}()
|
||||
|
||||
var rowCount int
|
||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
@@ -254,7 +255,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
||||
}
|
||||
|
||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||
err = tx.Select(&transactionRecurrings, `
|
||||
err = tx.SelectContext(ctx, &transactionRecurrings, `
|
||||
SELECT *
|
||||
FROM transaction_recurring
|
||||
WHERE user_id = ?
|
||||
@@ -275,7 +276,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
||||
return transactionRecurrings, nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
|
||||
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
@@ -285,7 +286,7 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
|
||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -295,13 +296,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
|
||||
}()
|
||||
|
||||
var transactionRecurring types.TransactionRecurring
|
||||
err = tx.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = tx.GetContext(ctx, &transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r, err := tx.Exec("DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id)
|
||||
r, err := tx.ExecContext(ctx, "DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id)
|
||||
err = db.TransformAndLogDbError("transactionRecurring Delete", r, err)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -316,13 +317,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
|
||||
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
now := s.clock.Now()
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -332,7 +333,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
|
||||
}()
|
||||
|
||||
recurringTransactions := make([]*types.TransactionRecurring, 0)
|
||||
err = tx.Select(&recurringTransactions, `
|
||||
err = tx.SelectContext(ctx, &recurringTransactions, `
|
||||
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
|
||||
user.Id, now)
|
||||
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
||||
@@ -350,13 +351,13 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
|
||||
Value: transactionRecurring.Value,
|
||||
}
|
||||
|
||||
_, err = s.transaction.Add(tx, user, transaction)
|
||||
_, err = s.transaction.Add(ctx, tx, user, transaction)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0)
|
||||
r, err := tx.Exec(`UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`,
|
||||
r, err := tx.ExecContext(ctx, `UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`,
|
||||
nextExecution, transactionRecurring.Id, user.Id)
|
||||
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err)
|
||||
if err != nil {
|
||||
@@ -373,6 +374,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||
ctx context.Context,
|
||||
tx *sqlx.Tx,
|
||||
oldTransactionRecurring *types.TransactionRecurring,
|
||||
userId uuid.UUID,
|
||||
@@ -417,7 +419,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
||||
}
|
||||
accountUuid = &temp
|
||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
|
||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
|
||||
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -438,7 +440,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||
}
|
||||
treasureChestUuid = &temp
|
||||
var treasureChest types.TreasureChest
|
||||
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
||||
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
||||
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -13,11 +14,11 @@ import (
|
||||
)
|
||||
|
||||
type TreasureChest interface {
|
||||
Add(user *types.User, parentId, name string) (*types.TreasureChest, error)
|
||||
Update(user *types.User, id, parentId, name string) (*types.TreasureChest, error)
|
||||
Get(user *types.User, id string) (*types.TreasureChest, error)
|
||||
GetAll(user *types.User) ([]*types.TreasureChest, error)
|
||||
Delete(user *types.User, id string) error
|
||||
Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error)
|
||||
Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error)
|
||||
Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error)
|
||||
GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error)
|
||||
Delete(ctx context.Context, user *types.User, id string) error
|
||||
}
|
||||
|
||||
type TreasureChestImpl struct {
|
||||
@@ -34,7 +35,7 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
|
||||
}
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.TreasureChest, error) {
|
||||
func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -51,7 +52,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
|
||||
|
||||
var parentUuid *uuid.UUID
|
||||
if parentId != "" {
|
||||
parent, err := s.Get(user, parentId)
|
||||
parent, err := s.Get(ctx, user, parentId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -76,7 +77,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
|
||||
UpdatedBy: nil,
|
||||
}
|
||||
|
||||
r, err := s.db.NamedExec(`
|
||||
r, err := s.db.NamedExecContext(ctx, `
|
||||
INSERT INTO treasure_chest (id, parent_id, user_id, name, current_balance, created_at, created_by)
|
||||
VALUES (:id, :parent_id, :user_id, :name, :current_balance, :created_at, :created_by)`, treasureChest)
|
||||
err = db.TransformAndLogDbError("treasureChest Insert", r, err)
|
||||
@@ -87,7 +88,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
|
||||
return treasureChest, nil
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
|
||||
func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -101,7 +102,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -111,7 +112,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
||||
}()
|
||||
|
||||
treasureChest := &types.TreasureChest{}
|
||||
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
||||
err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
||||
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
@@ -122,12 +123,12 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
||||
|
||||
var parentUuid *uuid.UUID
|
||||
if parentId != "" {
|
||||
parent, err := s.Get(user, parentId)
|
||||
parent, err := s.Get(ctx, user, parentId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var childCount int
|
||||
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
|
||||
err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
|
||||
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -145,7 +146,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
||||
treasureChest.UpdatedAt = ×tamp
|
||||
treasureChest.UpdatedBy = &user.Id
|
||||
|
||||
r, err := tx.NamedExec(`
|
||||
r, err := tx.NamedExecContext(ctx, `
|
||||
UPDATE treasure_chest
|
||||
SET
|
||||
parent_id = :parent_id,
|
||||
@@ -169,7 +170,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
||||
return treasureChest, nil
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChest, error) {
|
||||
func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
@@ -180,7 +181,7 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
|
||||
}
|
||||
|
||||
var treasureChest types.TreasureChest
|
||||
err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError("treasureChest Get", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
@@ -192,13 +193,13 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
|
||||
return &treasureChest, nil
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) {
|
||||
func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
}
|
||||
|
||||
treasureChests := make([]*types.TreasureChest, 0)
|
||||
err := s.db.Select(&treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id)
|
||||
err := s.db.SelectContext(ctx, &treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id)
|
||||
err = db.TransformAndLogDbError("treasureChest GetAll", nil, err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -207,7 +208,7 @@ func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, err
|
||||
return sortTree(treasureChests), nil
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
||||
func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
@@ -217,7 +218,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.Beginx()
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -227,7 +228,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
||||
}()
|
||||
|
||||
childCount := 0
|
||||
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
|
||||
err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
|
||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -238,7 +239,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
||||
}
|
||||
|
||||
transactionsCount := 0
|
||||
err = tx.Get(&transactionsCount,
|
||||
err = tx.GetContext(ctx, &transactionsCount,
|
||||
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
|
||||
user.Id, id)
|
||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||
@@ -250,7 +251,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
||||
}
|
||||
|
||||
recurringCount := 0
|
||||
err = tx.Get(&recurringCount, `
|
||||
err = tx.GetContext(ctx, &recurringCount, `
|
||||
SELECT COUNT(*) FROM transaction_recurring WHERE user_id = ? AND treasure_chest_id = ?`,
|
||||
user.Id, id)
|
||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||
@@ -261,7 +262,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
||||
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest)
|
||||
}
|
||||
|
||||
r, err := tx.Exec(`DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
||||
r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
||||
err = db.TransformAndLogDbError("treasureChest Delete", r, err)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -29,7 +29,7 @@ func DoRedirect(w http.ResponseWriter, r *http.Request, url string) {
|
||||
}
|
||||
}
|
||||
|
||||
func WaitMinimumTime[T interface{}](waitTime time.Duration, f func() (T, error)) (T, error) {
|
||||
func WaitMinimumTime[T any](waitTime time.Duration, f func() (T, error)) (T, error) {
|
||||
start := time.Now()
|
||||
result, err := f()
|
||||
time.Sleep(waitTime - time.Since(start))
|
||||
|
||||
Reference in New Issue
Block a user