diff --git a/go.mod b/go.mod index 2055b3d..2e70bea 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/mattn/go-sqlite3 v1.14.28 github.com/stretchr/testify v1.10.0 + github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2 go.opentelemetry.io/contrib/bridges/otelslog v0.11.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 go.opentelemetry.io/otel v1.36.0 @@ -38,6 +39,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 // indirect go.opentelemetry.io/otel/metric v1.36.0 // indirect diff --git a/go.sum b/go.sum index 64654f5..d345288 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,10 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2 h1:ZjUj9BLYf9PEqBn8W/OapxhPjVRdC6CsXTdULHsyk5c= +github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2/go.mod h1:O8bHQfyinKwTXKkiKNGmLQS7vRsqRxIQTFZpYpHK3IQ= +github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2 h1:zA9ZXfdtowo0EKt+t7uqXNlHxPeygrxuFSIroiBVgPU= +github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2/go.mod h1:ySXmuW9JLCm/TjsQksuMY/7MNiWqfHnhH2xeT34uOLU= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/bridges/otelslog v0.11.0 h1:EMIiYTms4Z4m3bBuKp1VmMNRLZcl6j4YbvOPL1IhlWo= diff --git a/internal/db/auth.go b/internal/db/auth.go index ac47439..e2d2715 100644 --- a/internal/db/auth.go +++ b/internal/db/auth.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "errors" "log/slog" @@ -13,23 +14,23 @@ import ( ) type Auth interface { - InsertUser(user *types.User) error - UpdateUser(user *types.User) error - GetUserByEmail(email string) (*types.User, error) - GetUser(userId uuid.UUID) (*types.User, error) - DeleteUser(userId uuid.UUID) error + InsertUser(ctx context.Context, user *types.User) error + UpdateUser(ctx context.Context, user *types.User) error + GetUserByEmail(ctx context.Context, email string) (*types.User, error) + GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) + DeleteUser(ctx context.Context, userId uuid.UUID) error - InsertToken(token *types.Token) error - GetToken(token string) (*types.Token, error) - GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) - GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) - DeleteToken(token string) error + InsertToken(ctx context.Context, token *types.Token) error + GetToken(ctx context.Context, token string) (*types.Token, error) + GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) + GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) + DeleteToken(ctx context.Context, token string) error - InsertSession(session *types.Session) error - GetSession(sessionId string) (*types.Session, error) - GetSessions(userId uuid.UUID) ([]*types.Session, error) - DeleteSession(sessionId string) error - DeleteOldSessions(userId uuid.UUID) error + InsertSession(ctx context.Context, session *types.Session) error + GetSession(ctx context.Context, sessionId string) (*types.Session, error) + GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) + DeleteSession(ctx context.Context, sessionId string) error + DeleteOldSessions(ctx context.Context, userId uuid.UUID) error } type AuthSqlite struct { @@ -40,8 +41,8 @@ func NewAuthSqlite(db *sqlx.DB) *AuthSqlite { return &AuthSqlite{db: db} } -func (db AuthSqlite) InsertUser(user *types.User) error { - _, err := db.db.Exec(` +func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error { + _, err := db.db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt) @@ -58,8 +59,8 @@ func (db AuthSqlite) InsertUser(user *types.User) error { return nil } -func (db AuthSqlite) UpdateUser(user *types.User) error { - _, err := db.db.Exec(` +func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error { + _, err := db.db.ExecContext(ctx, ` UPDATE user SET email_verified = ?, email_verified_at = ?, password = ? WHERE user_id = ?`, @@ -73,7 +74,7 @@ func (db AuthSqlite) UpdateUser(user *types.User) error { return nil } -func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) { +func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.User, error) { var ( userId uuid.UUID emailVerified bool @@ -84,7 +85,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) { createdAt time.Time ) - err := db.db.QueryRow(` + err := db.db.QueryRowContext(ctx, ` SELECT user_id, email_verified, email_verified_at, password, salt, created_at FROM user WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) @@ -100,7 +101,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) { return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil } -func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) { +func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) { var ( email string emailVerified bool @@ -111,7 +112,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) { createdAt time.Time ) - err := db.db.QueryRow(` + err := db.db.QueryRowContext(ctx, ` SELECT email, email_verified, email_verified_at, password, salt, created_at FROM user WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) @@ -127,49 +128,49 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) { return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil } -func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { - tx, err := db.db.Begin() +func (db AuthSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error { + tx, err := db.db.BeginTx(ctx, nil) if err != nil { slog.Error("Could not start transaction", "err", err) return types.ErrInternal } - _, err = tx.Exec("DELETE FROM account WHERE user_id = ?", userId) + _, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.Error("Could not delete accounts", "err", err) return types.ErrInternal } - _, err = tx.Exec("DELETE FROM token WHERE user_id = ?", userId) + _, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.Error("Could not delete user tokens", "err", err) return types.ErrInternal } - _, err = tx.Exec("DELETE FROM session WHERE user_id = ?", userId) + _, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.Error("Could not delete sessions", "err", err) return types.ErrInternal } - _, err = tx.Exec("DELETE FROM user WHERE user_id = ?", userId) + _, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.Error("Could not delete user", "err", err) return types.ErrInternal } - _, err = tx.Exec("DELETE FROM treasure_chest WHERE user_id = ?", userId) + _, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.Error("Could not delete user", "err", err) return types.ErrInternal } - _, err = tx.Exec("DELETE FROM \"transaction\" WHERE user_id = ?", userId) + _, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.Error("Could not delete user", "err", err) @@ -185,8 +186,8 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { return nil } -func (db AuthSqlite) InsertToken(token *types.Token) error { - _, err := db.db.Exec(` +func (db AuthSqlite) InsertToken(ctx context.Context, token *types.Token) error { + _, err := db.db.ExecContext(ctx, ` INSERT INTO token (user_id, session_id, type, token, created_at, expires_at) VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt) @@ -198,7 +199,7 @@ func (db AuthSqlite) InsertToken(token *types.Token) error { return nil } -func (db AuthSqlite) GetToken(token string) (*types.Token, error) { +func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, error) { var ( userId uuid.UUID sessionId string @@ -209,7 +210,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) { expiresAt time.Time ) - err := db.db.QueryRow(` + err := db.db.QueryRowContext(ctx, ` SELECT user_id, session_id, type, created_at, expires_at FROM token WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr) @@ -239,8 +240,8 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) { return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil } -func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) { - query, err := db.db.Query(` +func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) { + query, err := db.db.QueryContext(ctx, ` SELECT token, created_at, expires_at FROM token WHERE user_id = ? @@ -254,8 +255,8 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types. return getTokensFromQuery(query, userId, "", tokenType) } -func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) { - query, err := db.db.Query(` +func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) { + query, err := db.db.QueryContext(ctx, ` SELECT token, created_at, expires_at FROM token WHERE session_id = ? @@ -312,8 +313,8 @@ func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tok return tokens, nil } -func (db AuthSqlite) DeleteToken(token string) error { - _, err := db.db.Exec("DELETE FROM token WHERE token = ?", token) +func (db AuthSqlite) DeleteToken(ctx context.Context, token string) error { + _, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token) if err != nil { slog.Error("Could not delete token", "err", err) return types.ErrInternal @@ -321,8 +322,8 @@ func (db AuthSqlite) DeleteToken(token string) error { return nil } -func (db AuthSqlite) InsertSession(session *types.Session) error { - _, err := db.db.Exec(` +func (db AuthSqlite) InsertSession(ctx context.Context, session *types.Session) error { + _, err := db.db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt) @@ -334,14 +335,14 @@ func (db AuthSqlite) InsertSession(session *types.Session) error { return nil } -func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { +func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.Session, error) { var ( userId uuid.UUID createdAt time.Time expiresAt time.Time ) - err := db.db.QueryRow(` + err := db.db.QueryRowContext(ctx, ` SELECT user_id, created_at, expires_at FROM session WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt) @@ -354,9 +355,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { return types.NewSession(sessionId, userId, createdAt, expiresAt), nil } -func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) { +func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) { var sessions []*types.Session - err := db.db.Select(&sessions, ` + err := db.db.SelectContext(ctx, &sessions, ` SELECT * FROM session WHERE user_id = ?`, userId) @@ -368,8 +369,8 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) { return sessions, nil } -func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error { - _, err := db.db.Exec(` +func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error { + _, err := db.db.ExecContext(ctx, ` DELETE FROM session WHERE expires_at < datetime('now') AND user_id = ?`, userId) @@ -380,9 +381,9 @@ func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error { return nil } -func (db AuthSqlite) DeleteSession(sessionId string) error { +func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error { if sessionId != "" { - _, err := db.db.Exec("DELETE FROM session WHERE session_id = ?", sessionId) + _, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId) if err != nil { slog.Error("Could not delete session", "err", err) return types.ErrInternal diff --git a/internal/db/migration.go b/internal/db/migration.go index f1b670f..6638ada 100644 --- a/internal/db/migration.go +++ b/internal/db/migration.go @@ -1,6 +1,7 @@ package db import ( + "context" "errors" "log/slog" "spend-sparrow/internal/types" @@ -20,7 +21,7 @@ func (l migrationLogger) Verbose() bool { return false } -func RunMigrations(db *sqlx.DB, pathPrefix string) error { +func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error { driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) if err != nil { slog.Error("Could not create Migration instance", "err", err) diff --git a/internal/default.go b/internal/default.go index ef99297..13ed5db 100644 --- a/internal/default.go +++ b/internal/default.go @@ -56,7 +56,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu } // init db - err = db.RunMigrations(database, migrationsPrefix) + err = db.RunMigrations(ctx, database, migrationsPrefix) if err != nil { return fmt.Errorf("could not run migrations: %w", err) } diff --git a/internal/handler/account.go b/internal/handler/account.go index 3ef2b7f..09ecd88 100644 --- a/internal/handler/account.go +++ b/internal/handler/account.go @@ -44,7 +44,7 @@ func (h AccountImpl) handleAccountPage() http.HandlerFunc { return } - accounts, err := h.s.GetAll(user) + accounts, err := h.s.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return @@ -72,7 +72,7 @@ func (h AccountImpl) handleAccountItemComp() http.HandlerFunc { return } - account, err := h.s.Get(user, id) + account, err := h.s.Get(r.Context(), user, id) if err != nil { handleError(w, r, err) return @@ -105,13 +105,13 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc { id := r.PathValue("id") name := r.FormValue("name") if id == "new" { - account, err = h.s.Add(user, name) + account, err = h.s.Add(r.Context(), user, name) if err != nil { handleError(w, r, err) return } } else { - account, err = h.s.UpdateName(user, id, name) + account, err = h.s.UpdateName(r.Context(), user, id, name) if err != nil { handleError(w, r, err) return @@ -135,7 +135,7 @@ func (h AccountImpl) handleDeleteAccount() http.HandlerFunc { id := r.PathValue("id") - err := h.s.Delete(user, id) + err := h.s.Delete(r.Context(), user, id) if err != nil { handleError(w, r, err) return diff --git a/internal/handler/auth.go b/internal/handler/auth.go index 27db437..73729d9 100644 --- a/internal/handler/auth.go +++ b/internal/handler/auth.go @@ -85,7 +85,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc { email := r.FormValue("email") password := r.FormValue("password") - session, user, err := handler.service.SignIn(session, email, password) + session, user, err := handler.service.SignIn(r.Context(), session, email, password) if err != nil { return nil, err } @@ -163,7 +163,7 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc { return } - go handler.service.SendVerificationMail(user.Id, user.Email) + go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email) _, err := w.Write([]byte("

Verification email sent

")) if err != nil { @@ -178,7 +178,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc { token := r.URL.Query().Get("token") - err := handler.service.VerifyUserEmail(token) + err := handler.service.VerifyUserEmail(r.Context(), token) isVerified := err == nil comp := auth.VerifyResponseComp(isVerified) @@ -203,13 +203,13 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc { _, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) { slog.Info("signing up", "email", email) - user, err := handler.service.SignUp(email, password) + user, err := handler.service.SignUp(r.Context(), email, password) if err != nil { return nil, err } slog.Info("Sending verification email", "to", user.Email) - go handler.service.SendVerificationMail(user.Id, user.Email) + go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email) return nil, nil }) @@ -239,7 +239,7 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc { session := middleware.GetSession(r) if session != nil { - err := handler.service.SignOut(session.Id) + err := handler.service.SignOut(r.Context(), session.Id) if err != nil { http.Error(w, "An error occurred", http.StatusInternalServerError) return @@ -288,7 +288,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { password := r.FormValue("password") - err := handler.service.DeleteAccount(user, password) + err := handler.service.DeleteAccount(r.Context(), user, password) if err != nil { if errors.Is(err, service.ErrInvalidCredentials) { utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest) @@ -334,7 +334,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { currPass := r.FormValue("current-password") newPass := r.FormValue("new-password") - err := handler.service.ChangePassword(user, session.Id, currPass, newPass) + err := handler.service.ChangePassword(r.Context(), user, session.Id, currPass, newPass) if err != nil { utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest) return @@ -370,7 +370,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { } _, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) { - err := handler.service.SendForgotPasswordMail(email) + err := handler.service.SendForgotPasswordMail(r.Context(), email) return nil, err }) @@ -396,7 +396,7 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { token := pageUrl.Query().Get("token") newPass := r.FormValue("new-password") - err = handler.service.ForgotPassword(token, newPass) + err = handler.service.ForgotPassword(r.Context(), token, newPass) if err != nil { utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest) } else { diff --git a/internal/handler/middleware/authenticate.go b/internal/handler/middleware/authenticate.go index bc7928a..c81fca2 100644 --- a/internal/handler/middleware/authenticate.go +++ b/internal/handler/middleware/authenticate.go @@ -17,13 +17,13 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sessionId := getSessionID(r) - session, user, _ := service.SignInSession(sessionId) + session, user, _ := service.SignInSession(r.Context(), sessionId) var err error // Always sign in anonymous // This way, we can always generate csrf tokens if session == nil { - session, err = service.SignInAnonymous() + session, err = service.SignInAnonymous(r.Context()) if err != nil { http.Error(w, "Internal Server Error", http.StatusInternalServerError) return diff --git a/internal/handler/middleware/cross_site_request_forgery.go b/internal/handler/middleware/cross_site_request_forgery.go index 65b7373..be3c77e 100644 --- a/internal/handler/middleware/cross_site_request_forgery.go +++ b/internal/handler/middleware/cross_site_request_forgery.go @@ -4,30 +4,26 @@ import ( "log/slog" "net/http" "spend-sparrow/internal/service" - "spend-sparrow/internal/types" "spend-sparrow/internal/utils" "strings" ) type csrfResponseWriter struct { http.ResponseWriter - auth service.Auth - session *types.Session + csrfToken string } -func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter { +func newCsrfResponseWriter(w http.ResponseWriter, csrfToken string) *csrfResponseWriter { return &csrfResponseWriter{ ResponseWriter: w, - auth: auth, - session: session, + csrfToken: csrfToken, } } func (rr *csrfResponseWriter) Write(data []byte) (int, error) { dataStr := string(data) - csrfToken, err := rr.auth.GetCsrfToken(rr.session) - if err == nil { - dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken) + if rr.csrfToken != "" { + dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", rr.csrfToken) } return rr.ResponseWriter.Write([]byte(dataStr)) @@ -37,6 +33,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session := GetSession(r) + ctx := r.Context() if r.Method == http.MethodPost || r.Method == http.MethodPut || @@ -44,7 +41,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler r.Method == http.MethodPatch { csrfToken := r.Header.Get("Csrf-Token") - if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { + if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(ctx, csrfToken, session.Id) { slog.Info("CSRF-Token not correct", "token", csrfToken) if r.Header.Get("Hx-Request") == "true" { utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest) @@ -55,7 +52,17 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler } } - responseWriter := newCsrfResponseWriter(w, auth, session) + token, err := auth.GetCsrfToken(ctx, session) + if err != nil { + if r.Header.Get("Hx-Request") == "true" { + utils.TriggerToastWithStatus(w, r, "error", "Could not generate CSRF Token", http.StatusBadRequest) + } else { + http.Error(w, "Could not generate CSRF Token", http.StatusBadRequest) + } + return + } + + responseWriter := newCsrfResponseWriter(w, token) next.ServeHTTP(responseWriter, r) }) } diff --git a/internal/handler/middleware/generate_recurring_transactions.go b/internal/handler/middleware/generate_recurring_transactions.go index 5a5f1c9..16cf836 100644 --- a/internal/handler/middleware/generate_recurring_transactions.go +++ b/internal/handler/middleware/generate_recurring_transactions.go @@ -14,7 +14,7 @@ func GenerateRecurringTransactions(transactionRecurring service.TransactionRecur return } - _ = transactionRecurring.GenerateTransactions(user) + _ = transactionRecurring.GenerateTransactions(r.Context(), user) next.ServeHTTP(w, r) }) diff --git a/internal/handler/transaction.go b/internal/handler/transaction.go index d596360..9c28028 100644 --- a/internal/handler/transaction.go +++ b/internal/handler/transaction.go @@ -65,19 +65,19 @@ func (h TransactionImpl) handleTransactionPage() http.HandlerFunc { Error: r.URL.Query().Get("error"), } - transactions, err := h.s.GetAll(user, filter) + transactions, err := h.s.GetAll(r.Context(), user, filter) if err != nil { handleError(w, r, err) return } - accounts, err := h.account.GetAll(user) + accounts, err := h.account.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return } - treasureChests, err := h.treasureChest.GetAll(user) + treasureChests, err := h.treasureChest.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return @@ -105,13 +105,13 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc { return } - accounts, err := h.account.GetAll(user) + accounts, err := h.account.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return } - treasureChests, err := h.treasureChest.GetAll(user) + treasureChests, err := h.treasureChest.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return @@ -124,7 +124,7 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc { return } - transaction, err := h.s.Get(user, id) + transaction, err := h.s.Get(r.Context(), user, id) if err != nil { handleError(w, r, err) return @@ -212,26 +212,26 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc { var transaction *types.Transaction if idStr == "new" { - transaction, err = h.s.Add(nil, user, input) + transaction, err = h.s.Add(r.Context(), nil, user, input) if err != nil { handleError(w, r, err) return } } else { - transaction, err = h.s.Update(user, input) + transaction, err = h.s.Update(r.Context(), user, input) if err != nil { handleError(w, r, err) return } } - accounts, err := h.account.GetAll(user) + accounts, err := h.account.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return } - treasureChests, err := h.treasureChest.GetAll(user) + treasureChests, err := h.treasureChest.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return @@ -253,7 +253,7 @@ func (h TransactionImpl) handleRecalculate() http.HandlerFunc { return } - err := h.s.RecalculateBalances(user) + err := h.s.RecalculateBalances(r.Context(), user) if err != nil { handleError(w, r, err) return @@ -275,7 +275,7 @@ func (h TransactionImpl) handleDeleteTransaction() http.HandlerFunc { id := r.PathValue("id") - err := h.s.Delete(user, id) + err := h.s.Delete(r.Context(), user, id) if err != nil { handleError(w, r, err) return diff --git a/internal/handler/transaction_recurring.go b/internal/handler/transaction_recurring.go index 6ca659f..c084cef 100644 --- a/internal/handler/transaction_recurring.go +++ b/internal/handler/transaction_recurring.go @@ -70,13 +70,13 @@ func (h TransactionRecurringImpl) handleUpdateTransactionRecurring() http.Handle } if input.Id == "new" { - _, err := h.s.Add(user, input) + _, err := h.s.Add(r.Context(), user, input) if err != nil { handleError(w, r, err) return } } else { - _, err := h.s.Update(user, input) + _, err := h.s.Update(r.Context(), user, input) if err != nil { handleError(w, r, err) return @@ -101,7 +101,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle accountId := r.URL.Query().Get("account-id") treasureChestId := r.URL.Query().Get("treasure-chest-id") - err := h.s.Delete(user, id) + err := h.s.Delete(r.Context(), user, id) if err != nil { handleError(w, r, err) return @@ -118,13 +118,13 @@ func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Req utils.TriggerToastWithStatus(w, r, "error", "Please select an account or treasure chest", http.StatusBadRequest) } if accountId != "" { - transactionsRecurring, err = h.s.GetAllByAccount(user, accountId) + transactionsRecurring, err = h.s.GetAllByAccount(r.Context(), user, accountId) if err != nil { handleError(w, r, err) return } } else { - transactionsRecurring, err = h.s.GetAllByTreasureChest(user, treasureChestId) + transactionsRecurring, err = h.s.GetAllByTreasureChest(r.Context(), user, treasureChestId) if err != nil { handleError(w, r, err) return diff --git a/internal/handler/treasure_chest.go b/internal/handler/treasure_chest.go index 9423eff..a044610 100644 --- a/internal/handler/treasure_chest.go +++ b/internal/handler/treasure_chest.go @@ -48,13 +48,13 @@ func (h TreasureChestImpl) handleTreasureChestPage() http.HandlerFunc { return } - treasureChests, err := h.s.GetAll(user) + treasureChests, err := h.s.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return } - transactionsRecurring, err := h.transactionRecurring.GetAll(user) + transactionsRecurring, err := h.transactionRecurring.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return @@ -77,7 +77,7 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc { return } - treasureChests, err := h.s.GetAll(user) + treasureChests, err := h.s.GetAll(r.Context(), user) if err != nil { handleError(w, r, err) return @@ -90,13 +90,13 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc { return } - treasureChest, err := h.s.Get(user, id) + treasureChest, err := h.s.Get(r.Context(), user, id) if err != nil { handleError(w, r, err) return } - transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String()) + transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String()) if err != nil { handleError(w, r, err) return @@ -132,20 +132,20 @@ func (h TreasureChestImpl) handleUpdateTreasureChest() http.HandlerFunc { parentId := r.FormValue("parent-id") name := r.FormValue("name") if id == "new" { - treasureChest, err = h.s.Add(user, parentId, name) + treasureChest, err = h.s.Add(r.Context(), user, parentId, name) if err != nil { handleError(w, r, err) return } } else { - treasureChest, err = h.s.Update(user, id, parentId, name) + treasureChest, err = h.s.Update(r.Context(), user, id, parentId, name) if err != nil { handleError(w, r, err) return } } - transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String()) + transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String()) if err != nil { handleError(w, r, err) return @@ -171,7 +171,7 @@ func (h TreasureChestImpl) handleDeleteTreasureChest() http.HandlerFunc { id := r.PathValue("id") - err := h.s.Delete(user, id) + err := h.s.Delete(r.Context(), user, id) if err != nil { handleError(w, r, err) return diff --git a/internal/service/account.go b/internal/service/account.go index f05513b..4e1ff3a 100644 --- a/internal/service/account.go +++ b/internal/service/account.go @@ -1,6 +1,7 @@ package service import ( + "context" "errors" "fmt" "log/slog" @@ -12,11 +13,11 @@ import ( ) type Account interface { - Add(user *types.User, name string) (*types.Account, error) - UpdateName(user *types.User, id string, name string) (*types.Account, error) - Get(user *types.User, id string) (*types.Account, error) - GetAll(user *types.User) ([]*types.Account, error) - Delete(user *types.User, id string) error + Add(ctx context.Context, user *types.User, name string) (*types.Account, error) + UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error) + Get(ctx context.Context, user *types.User, id string) (*types.Account, error) + GetAll(ctx context.Context, user *types.User) ([]*types.Account, error) + Delete(ctx context.Context, user *types.User, id string) error } type AccountImpl struct { @@ -33,7 +34,7 @@ func NewAccount(db *sqlx.DB, random Random, clock Clock) Account { } } -func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) { +func (s AccountImpl) Add(ctx context.Context, user *types.User, name string) (*types.Account, error) { if user == nil { return nil, ErrUnauthorized } @@ -64,7 +65,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) UpdatedBy: nil, } - r, err := s.db.NamedExec(` + r, err := s.db.NamedExecContext(ctx, ` INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by) VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account) err = db.TransformAndLogDbError("account Insert", r, err) @@ -75,7 +76,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) return account, nil } -func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) { +func (s AccountImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error) { if user == nil { return nil, ErrUnauthorized } @@ -89,7 +90,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("account Update", nil, err) if err != nil { return nil, err @@ -99,7 +100,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type }() var account types.Account - err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("account Update", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -113,7 +114,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type account.UpdatedAt = ×tamp account.UpdatedBy = &user.Id - r, err := tx.NamedExec(` + r, err := tx.NamedExecContext(ctx, ` UPDATE account SET name = :name, @@ -135,7 +136,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type return &account, nil } -func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { +func (s AccountImpl) Get(ctx context.Context, user *types.User, id string) (*types.Account, error) { if user == nil { return nil, ErrUnauthorized } @@ -146,7 +147,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { } var account types.Account - err = s.db.Get(&account, ` + err = s.db.GetContext(ctx, &account, ` SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("account Get", nil, err) if err != nil { @@ -157,13 +158,13 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { return &account, nil } -func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { +func (s AccountImpl) GetAll(ctx context.Context, user *types.User) ([]*types.Account, error) { if user == nil { return nil, ErrUnauthorized } accounts := make([]*types.Account, 0) - err := s.db.Select(&accounts, ` + err := s.db.SelectContext(ctx, &accounts, ` SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id) err = db.TransformAndLogDbError("account GetAll", nil, err) if err != nil { @@ -173,7 +174,7 @@ func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { return accounts, nil } -func (s AccountImpl) Delete(user *types.User, id string) error { +func (s AccountImpl) Delete(ctx context.Context, user *types.User, id string) error { if user == nil { return ErrUnauthorized } @@ -183,7 +184,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error { return fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("account Delete", nil, err) if err != nil { return err @@ -193,7 +194,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error { }() transactionsCount := 0 - err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid) + err = tx.GetContext(ctx, &transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("account Delete", nil, err) if err != nil { return err @@ -202,7 +203,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error { return fmt.Errorf("account has transactions, cannot delete: %w", ErrBadRequest) } - res, err := tx.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id) + res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id) err = db.TransformAndLogDbError("account Delete", res, err) if err != nil { return err diff --git a/internal/service/auth.go b/internal/service/auth.go index 29f0928..7d68184 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -26,24 +26,24 @@ var ( ) type Auth interface { - SignUp(email string, password string) (*types.User, error) - SendVerificationMail(userId uuid.UUID, email string) - VerifyUserEmail(token string) error + SignUp(ctx context.Context, email string, password string) (*types.User, error) + SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) + VerifyUserEmail(ctx context.Context, token string) error - SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) - SignInSession(sessionId string) (*types.Session, *types.User, error) - SignInAnonymous() (*types.Session, error) - SignOut(sessionId string) error + SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) + SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) + SignInAnonymous(ctx context.Context) (*types.Session, error) + SignOut(ctx context.Context, sessionId string) error - DeleteAccount(user *types.User, currPass string) error + DeleteAccount(ctx context.Context, user *types.User, currPass string) error - ChangePassword(user *types.User, sessionId string, currPass, newPass string) error + ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error - SendForgotPasswordMail(email string) error - ForgotPassword(token string, newPass string) error + SendForgotPasswordMail(ctx context.Context, email string) error + ForgotPassword(ctx context.Context, token string, newPass string) error - IsCsrfTokenValid(tokenStr string, sessionId string) bool - GetCsrfToken(session *types.Session) (string, error) + IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool + GetCsrfToken(ctx context.Context, session *types.Session) (string, error) } type AuthImpl struct { @@ -64,8 +64,8 @@ func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings * } } -func (service AuthImpl) SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) { - user, err := service.db.GetUserByEmail(email) +func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) { + user, err := service.db.GetUserByEmail(ctx, email) if err != nil { if errors.Is(err, db.ErrNotFound) { return nil, nil, ErrInvalidCredentials @@ -80,12 +80,12 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st return nil, nil, ErrInvalidCredentials } - err = service.cleanUpSessionWithTokens(session) + err = service.cleanUpSessionWithTokens(ctx, session) if err != nil { return nil, nil, types.ErrInternal } - session, err = service.createSession(user.Id) + session, err = service.createSession(ctx, user.Id) if err != nil { return nil, nil, types.ErrInternal } @@ -93,17 +93,17 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st return session, user, nil } -func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) { +func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) { if sessionId == "" { return nil, nil, ErrSessionIdInvalid } - session, err := service.db.GetSession(sessionId) + session, err := service.db.GetSession(ctx, sessionId) if err != nil { return nil, nil, types.ErrInternal } if session.ExpiresAt.Before(service.clock.Now()) { - _ = service.db.DeleteSession(sessionId) + _ = service.db.DeleteSession(ctx, sessionId) return nil, nil, nil } @@ -111,7 +111,7 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types. return session, nil, nil } - user, err := service.db.GetUser(session.UserId) + user, err := service.db.GetUser(ctx, session.UserId) if err != nil { return nil, nil, types.ErrInternal } @@ -119,8 +119,8 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types. return session, user, nil } -func (service AuthImpl) SignInAnonymous() (*types.Session, error) { - session, err := service.createSession(uuid.Nil) +func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, error) { + session, err := service.createSession(ctx, uuid.Nil) if err != nil { return nil, types.ErrInternal } @@ -130,7 +130,7 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) { return session, nil } -func (service AuthImpl) SignUp(email string, password string) (*types.User, error) { +func (service AuthImpl) SignUp(ctx context.Context, email string, password string) (*types.User, error) { _, err := mail.ParseAddress(email) if err != nil { return nil, ErrInvalidEmail @@ -154,7 +154,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) - err = service.db.InsertUser(user) + err = service.db.InsertUser(ctx, user) if err != nil { if errors.Is(err, db.ErrAlreadyExists) { return nil, ErrAccountExists @@ -166,8 +166,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro return user, nil } -func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { - tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify) +func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) { + tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify) if err != nil && !errors.Is(err, db.ErrNotFound) { return } @@ -192,7 +192,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { service.clock.Now(), service.clock.Now().Add(24*time.Hour)) - err = service.db.InsertToken(token) + err = service.db.InsertToken(ctx, token) if err != nil { return } @@ -208,17 +208,17 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { service.mail.SendMail(email, "Welcome to spend-sparrow", w.String()) } -func (service AuthImpl) VerifyUserEmail(tokenStr string) error { +func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error { if tokenStr == "" { return types.ErrInternal } - token, err := service.db.GetToken(tokenStr) + token, err := service.db.GetToken(ctx, tokenStr) if err != nil { return types.ErrInternal } - user, err := service.db.GetUser(token.UserId) + user, err := service.db.GetUser(ctx, token.UserId) if err != nil { return types.ErrInternal } @@ -236,21 +236,21 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error { user.EmailVerified = true user.EmailVerifiedAt = &now - err = service.db.UpdateUser(user) + err = service.db.UpdateUser(ctx, user) if err != nil { return types.ErrInternal } - _ = service.db.DeleteToken(token.Token) + _ = service.db.DeleteToken(ctx, token.Token) return nil } -func (service AuthImpl) SignOut(sessionId string) error { - return service.db.DeleteSession(sessionId) +func (service AuthImpl) SignOut(ctx context.Context, sessionId string) error { + return service.db.DeleteSession(ctx, sessionId) } -func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error { - userDb, err := service.db.GetUser(user.Id) +func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, currPass string) error { + userDb, err := service.db.GetUser(ctx, user.Id) if err != nil { return types.ErrInternal } @@ -260,7 +260,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error { return ErrInvalidCredentials } - err = service.db.DeleteUser(user.Id) + err = service.db.DeleteUser(ctx, user.Id) if err != nil { return err } @@ -270,7 +270,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error { return nil } -func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error { +func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error { if !isPasswordValid(newPass) { return ErrInvalidPassword } @@ -288,18 +288,18 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP newHash := GetHashPassword(newPass, user.Salt) user.Password = newHash - err := service.db.UpdateUser(user) + err := service.db.UpdateUser(ctx, user) if err != nil { return err } - sessions, err := service.db.GetSessions(user.Id) + sessions, err := service.db.GetSessions(ctx, user.Id) if err != nil { return types.ErrInternal } for _, s := range sessions { if s.Id != sessionId { - err = service.db.DeleteSession(s.Id) + err = service.db.DeleteSession(ctx, s.Id) if err != nil { return types.ErrInternal } @@ -309,13 +309,13 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP return nil } -func (service AuthImpl) SendForgotPasswordMail(email string) error { +func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string) error { tokenStr, err := service.random.String(32) if err != nil { return err } - user, err := service.db.GetUserByEmail(email) + user, err := service.db.GetUserByEmail(ctx, email) if err != nil { if errors.Is(err, db.ErrNotFound) { return nil @@ -332,7 +332,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error { service.clock.Now(), service.clock.Now().Add(15*time.Minute)) - err = service.db.InsertToken(token) + err = service.db.InsertToken(ctx, token) if err != nil { return types.ErrInternal } @@ -348,17 +348,17 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error { return nil } -func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { +func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error { if !isPasswordValid(newPass) { return ErrInvalidPassword } - token, err := service.db.GetToken(tokenStr) + token, err := service.db.GetToken(ctx, tokenStr) if err != nil { return ErrTokenInvalid } - err = service.db.DeleteToken(tokenStr) + err = service.db.DeleteToken(ctx, tokenStr) if err != nil { return err } @@ -368,7 +368,7 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { return ErrTokenInvalid } - user, err := service.db.GetUser(token.UserId) + user, err := service.db.GetUser(ctx, token.UserId) if err != nil { slog.Error("Could not get user from token", "err", err) return types.ErrInternal @@ -377,18 +377,18 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { passHash := GetHashPassword(newPass, user.Salt) user.Password = passHash - err = service.db.UpdateUser(user) + err = service.db.UpdateUser(ctx, user) if err != nil { return err } - sessions, err := service.db.GetSessions(user.Id) + sessions, err := service.db.GetSessions(ctx, user.Id) if err != nil { return types.ErrInternal } for _, session := range sessions { - err = service.db.DeleteSession(session.Id) + err = service.db.DeleteSession(ctx, session.Id) if err != nil { return types.ErrInternal } @@ -397,8 +397,8 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { return nil } -func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool { - token, err := service.db.GetToken(tokenStr) +func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool { + token, err := service.db.GetToken(ctx, tokenStr) if err != nil { return false } @@ -412,12 +412,12 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool return true } -func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) { +func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session) (string, error) { if session == nil { return "", types.ErrInternal } - tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf) + tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) if len(tokens) > 0 { return tokens[0].Token, nil @@ -435,7 +435,7 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) { types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(8*time.Hour)) - err = service.db.InsertToken(token) + err = service.db.InsertToken(ctx, token) if err != nil { return "", types.ErrInternal } @@ -445,22 +445,22 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) { return tokenStr, nil } -func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error { +func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error { if session == nil { return nil } - err := service.db.DeleteSession(session.Id) + err := service.db.DeleteSession(ctx, session.Id) if err != nil { return types.ErrInternal } - tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf) + tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) if err != nil { return types.ErrInternal } for _, token := range tokens { - err = service.db.DeleteToken(token.Token) + err = service.db.DeleteToken(ctx, token.Token) if err != nil { return types.ErrInternal } @@ -469,13 +469,13 @@ func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error { return nil } -func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) { +func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*types.Session, error) { sessionId, err := service.random.String(32) if err != nil { return nil, types.ErrInternal } - err = service.db.DeleteOldSessions(userId) + err = service.db.DeleteOldSessions(ctx, userId) if err != nil { return nil, types.ErrInternal } @@ -485,7 +485,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) session := types.NewSession(sessionId, userId, createAt, expiresAt) - err = service.db.InsertSession(session) + err = service.db.InsertSession(ctx, session) if err != nil { return nil, types.ErrInternal } diff --git a/internal/service/transaction.go b/internal/service/transaction.go index 3754a76..84cde90 100644 --- a/internal/service/transaction.go +++ b/internal/service/transaction.go @@ -1,6 +1,7 @@ package service import ( + "context" "errors" "fmt" "log/slog" @@ -13,13 +14,13 @@ import ( ) type Transaction interface { - Add(tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error) - Update(user *types.User, transaction types.Transaction) (*types.Transaction, error) - Get(user *types.User, id string) (*types.Transaction, error) - GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) - Delete(user *types.User, id string) error + Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error) + Update(ctx context.Context, user *types.User, transaction types.Transaction) (*types.Transaction, error) + Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) + GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) + Delete(ctx context.Context, user *types.User, id string) error - RecalculateBalances(user *types.User) error + RecalculateBalances(ctx context.Context, user *types.User) error } type TransactionImpl struct { @@ -36,7 +37,7 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction { } } -func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { +func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { if user == nil { return nil, ErrUnauthorized } @@ -45,7 +46,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ ownsTransaction := false if tx == nil { ownsTransaction = true - tx, err = s.db.Beginx() + tx, err = s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transaction Add", nil, err) if err != nil { return nil, err @@ -55,12 +56,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ }() } - transaction, err := s.validateAndEnrichTransaction(tx, nil, user.Id, transactionInput) + transaction, err := s.validateAndEnrichTransaction(ctx, tx, nil, user.Id, transactionInput) if err != nil { return nil, err } - r, err := tx.NamedExec(` + r, err := tx.NamedExecContext(ctx, ` INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp, party, description, error, created_at, created_by) VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp, @@ -71,8 +72,8 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ } if transaction.Error == nil && transaction.AccountId != nil { - r, err = tx.Exec(` - UPDATE account + r, err = tx.ExecContext(ctx, ` + UPDATE actx context.Context,ccount SET current_balance = current_balance + ? WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) err = db.TransformAndLogDbError("transaction Add", r, err) @@ -82,7 +83,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ } if transaction.Error == nil && transaction.TreasureChestId != nil { - r, err = tx.Exec(` + r, err = tx.ExecContext(ctx, ` UPDATE treasure_chest SET current_balance = current_balance + ? WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) @@ -103,12 +104,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ return transaction, nil } -func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*types.Transaction, error) { +func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) { if user == nil { return nil, ErrUnauthorized } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transaction Update", nil, err) if err != nil { return nil, err @@ -118,7 +119,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ }() transaction := &types.Transaction{} - err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id) + err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id) err = db.TransformAndLogDbError("transaction Update", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -128,7 +129,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ } if transaction.Error == nil && transaction.AccountId != nil { - r, err := tx.Exec(` + r, err := tx.ExecContext(ctx, ` UPDATE account SET current_balance = current_balance - ? WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) @@ -138,7 +139,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ } } if transaction.Error == nil && transaction.TreasureChestId != nil { - r, err := tx.Exec(` + r, err := tx.ExecContext(ctx, ` UPDATE treasure_chest SET current_balance = current_balance - ? WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) @@ -148,13 +149,13 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ } } - transaction, err = s.validateAndEnrichTransaction(tx, transaction, user.Id, input) + transaction, err = s.validateAndEnrichTransaction(ctx, tx, transaction, user.Id, input) if err != nil { return nil, err } if transaction.Error == nil && transaction.AccountId != nil { - r, err := tx.Exec(` + r, err := tx.ExecContext(ctx, ` UPDATE account SET current_balance = current_balance + ? WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) @@ -164,7 +165,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ } } if transaction.Error == nil && transaction.TreasureChestId != nil { - r, err := tx.Exec(` + r, err := tx.ExecContext(ctx, ` UPDATE treasure_chest SET current_balance = current_balance + ? WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) @@ -174,7 +175,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ } } - r, err := tx.NamedExec(` + r, err := tx.NamedExecContext(ctx, ` UPDATE "transaction" SET account_id = :account_id, @@ -202,7 +203,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ return transaction, nil } -func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, error) { +func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) { if user == nil { return nil, ErrUnauthorized } @@ -213,7 +214,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e } var transaction types.Transaction - err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("transaction Get", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -225,13 +226,13 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e return &transaction, nil } -func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) { +func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) { if user == nil { return nil, ErrUnauthorized } transactions := make([]*types.Transaction, 0) - err := s.db.Select(&transactions, ` + err := s.db.SelectContext(ctx, &transactions, ` SELECT * FROM "transaction" WHERE user_id = ? @@ -254,7 +255,7 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF return transactions, nil } -func (s TransactionImpl) Delete(user *types.User, id string) error { +func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error { if user == nil { return ErrUnauthorized } @@ -264,7 +265,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error { return fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transaction Delete", nil, err) if err != nil { return nil @@ -274,14 +275,14 @@ func (s TransactionImpl) Delete(user *types.User, id string) error { }() var transaction types.Transaction - err = tx.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = tx.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("transaction Delete", nil, err) if err != nil { return err } if transaction.Error == nil && transaction.AccountId != nil { - r, err := tx.Exec(` + r, err := tx.ExecContext(ctx, ` UPDATE account SET current_balance = current_balance - ? WHERE id = ? @@ -293,7 +294,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error { } if transaction.Error == nil && transaction.TreasureChestId != nil { - r, err := tx.Exec(` + r, err := tx.ExecContext(ctx, ` UPDATE treasure_chest SET current_balance = current_balance - ? WHERE id = ? @@ -304,7 +305,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error { } } - r, err := tx.Exec("DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id) + r, err := tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id) err = db.TransformAndLogDbError("transaction Delete", r, err) if err != nil { return err @@ -319,12 +320,12 @@ func (s TransactionImpl) Delete(user *types.User, id string) error { return nil } -func (s TransactionImpl) RecalculateBalances(user *types.User) error { +func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error { if user == nil { return ErrUnauthorized } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err) if err != nil { return err @@ -333,7 +334,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { _ = tx.Rollback() }() - r, err := tx.Exec(` + r, err := tx.ExecContext(ctx, ` UPDATE account SET current_balance = 0 WHERE user_id = ?`, user.Id) @@ -342,7 +343,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { return err } - r, err = tx.Exec(` + r, err = tx.ExecContext(ctx, ` UPDATE treasure_chest SET current_balance = 0 WHERE user_id = ?`, user.Id) @@ -351,7 +352,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { return err } - rows, err := tx.Queryx(` + rows, err := tx.QueryxContext(ctx, ` SELECT * FROM "transaction" WHERE user_id = ?`, user.Id) @@ -375,7 +376,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { } s.updateErrors(&transaction) - r, err = tx.Exec(` + r, err = tx.ExecContext(ctx, ` UPDATE "transaction" SET error = ? WHERE user_id = ? @@ -390,7 +391,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { } if transaction.AccountId != nil { - r, err = tx.Exec(` + r, err = tx.ExecContext(ctx, ` UPDATE account SET current_balance = current_balance + ? WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) @@ -400,7 +401,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { } } if transaction.TreasureChestId != nil { - r, err = tx.Exec(` + r, err = tx.ExecContext(ctx, ` UPDATE treasure_chest SET current_balance = current_balance + ? WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) @@ -420,7 +421,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { return nil } -func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) { +func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) { var ( id uuid.UUID createdAt time.Time @@ -449,7 +450,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio } if input.AccountId != nil { - err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId) + err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId) err = db.TransformAndLogDbError("transaction validate", nil, err) if err != nil { return nil, err @@ -462,7 +463,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio if input.TreasureChestId != nil { var treasureChest types.TreasureChest - err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId) + err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId) err = db.TransformAndLogDbError("transaction validate", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { diff --git a/internal/service/transaction_recurring.go b/internal/service/transaction_recurring.go index c4825fe..1de5dc2 100644 --- a/internal/service/transaction_recurring.go +++ b/internal/service/transaction_recurring.go @@ -1,6 +1,7 @@ package service import ( + "context" "errors" "fmt" "log/slog" @@ -15,14 +16,14 @@ import ( ) type TransactionRecurring interface { - Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) - Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) - GetAll(user *types.User) ([]*types.TransactionRecurring, error) - GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) - GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) - Delete(user *types.User, id string) error + Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) + Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) + GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) + GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) + GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) + Delete(ctx context.Context, user *types.User, id string) error - GenerateTransactions(user *types.User) error + GenerateTransactions(ctx context.Context, user *types.User) error } type TransactionRecurringImpl struct { @@ -41,7 +42,7 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio } } -func (s TransactionRecurringImpl) Add( +func (s TransactionRecurringImpl) Add(ctx context.Context, user *types.User, transactionRecurringInput types.TransactionRecurringInput, ) (*types.TransactionRecurring, error) { @@ -49,7 +50,7 @@ func (s TransactionRecurringImpl) Add( return nil, ErrUnauthorized } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transactionRecurring Add", nil, err) if err != nil { return nil, err @@ -58,12 +59,12 @@ func (s TransactionRecurringImpl) Add( _ = tx.Rollback() }() - transactionRecurring, err := s.validateAndEnrichTransactionRecurring(tx, nil, user.Id, transactionRecurringInput) + transactionRecurring, err := s.validateAndEnrichTransactionRecurring(ctx, tx, nil, user.Id, transactionRecurringInput) if err != nil { return nil, err } - r, err := tx.NamedExec(` + r, err := tx.NamedExecContext(ctx, ` INSERT INTO "transaction_recurring" (id, user_id, interval_months, next_execution, party, description, account_id, treasure_chest_id, value, created_at, created_by) VALUES (:id, :user_id, :interval_months, @@ -83,7 +84,7 @@ func (s TransactionRecurringImpl) Add( return transactionRecurring, nil } -func (s TransactionRecurringImpl) Update( +func (s TransactionRecurringImpl) Update(ctx context.Context, user *types.User, input types.TransactionRecurringInput, ) (*types.TransactionRecurring, error) { @@ -96,7 +97,7 @@ func (s TransactionRecurringImpl) Update( return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transactionRecurring Update", nil, err) if err != nil { return nil, err @@ -106,7 +107,7 @@ func (s TransactionRecurringImpl) Update( }() transactionRecurring := &types.TransactionRecurring{} - err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("transactionRecurring Update", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -115,12 +116,12 @@ func (s TransactionRecurringImpl) Update( return nil, types.ErrInternal } - transactionRecurring, err = s.validateAndEnrichTransactionRecurring(tx, transactionRecurring, user.Id, input) + transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input) if err != nil { return nil, err } - r, err := tx.NamedExec(` + r, err := tx.NamedExecContext(ctx, ` UPDATE transaction_recurring SET interval_months = :interval_months, @@ -148,13 +149,13 @@ func (s TransactionRecurringImpl) Update( return transactionRecurring, nil } -func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.TransactionRecurring, error) { +func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) { if user == nil { return nil, ErrUnauthorized } transactionRecurrings := make([]*types.TransactionRecurring, 0) - err := s.db.Select(&transactionRecurrings, ` + err := s.db.SelectContext(ctx, &transactionRecurrings, ` SELECT * FROM transaction_recurring WHERE user_id = ? @@ -168,7 +169,7 @@ func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.Transaction return transactionRecurrings, nil } -func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) { +func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) { if user == nil { return nil, ErrUnauthorized } @@ -179,7 +180,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err) if err != nil { return nil, err @@ -189,7 +190,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st }() var rowCount int - err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id) + err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id) err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -199,7 +200,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st } transactionRecurrings := make([]*types.TransactionRecurring, 0) - err = tx.Select(&transactionRecurrings, ` + err = tx.SelectContext(ctx, &transactionRecurrings, ` SELECT * FROM transaction_recurring WHERE user_id = ? @@ -220,7 +221,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st return transactionRecurrings, nil } -func (s TransactionRecurringImpl) GetAllByTreasureChest( +func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string, ) ([]*types.TransactionRecurring, error) { @@ -234,7 +235,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest( return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err) if err != nil { return nil, err @@ -244,7 +245,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest( }() var rowCount int - err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id) + err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id) err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -254,7 +255,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest( } transactionRecurrings := make([]*types.TransactionRecurring, 0) - err = tx.Select(&transactionRecurrings, ` + err = tx.SelectContext(ctx, &transactionRecurrings, ` SELECT * FROM transaction_recurring WHERE user_id = ? @@ -275,7 +276,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest( return transactionRecurrings, nil } -func (s TransactionRecurringImpl) Delete(user *types.User, id string) error { +func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error { if user == nil { return ErrUnauthorized } @@ -285,7 +286,7 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error { return fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err) if err != nil { return nil @@ -295,13 +296,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error { }() var transactionRecurring types.TransactionRecurring - err = tx.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = tx.GetContext(ctx, &transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err) if err != nil { return err } - r, err := tx.Exec("DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id) + r, err := tx.ExecContext(ctx, "DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id) err = db.TransformAndLogDbError("transactionRecurring Delete", r, err) if err != nil { return err @@ -316,13 +317,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error { return nil } -func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error { +func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error { if user == nil { return ErrUnauthorized } now := s.clock.Now() - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err) if err != nil { return err @@ -332,7 +333,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error { }() recurringTransactions := make([]*types.TransactionRecurring, 0) - err = tx.Select(&recurringTransactions, ` + err = tx.SelectContext(ctx, &recurringTransactions, ` SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`, user.Id, now) err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err) @@ -350,13 +351,13 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error { Value: transactionRecurring.Value, } - _, err = s.transaction.Add(tx, user, transaction) + _, err = s.transaction.Add(ctx, tx, user, transaction) if err != nil { return err } nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0) - r, err := tx.Exec(`UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`, + r, err := tx.ExecContext(ctx, `UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`, nextExecution, transactionRecurring.Id, user.Id) err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err) if err != nil { @@ -373,6 +374,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error { } func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( + ctx context.Context, tx *sqlx.Tx, oldTransactionRecurring *types.TransactionRecurring, userId uuid.UUID, @@ -417,7 +419,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) } accountUuid = &temp - err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId) + err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId) err = db.TransformAndLogDbError("transactionRecurring validate", nil, err) if err != nil { return nil, err @@ -438,7 +440,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( } treasureChestUuid = &temp var treasureChest types.TreasureChest - err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) + err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = db.TransformAndLogDbError("transactionRecurring validate", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { diff --git a/internal/service/treasure_chest.go b/internal/service/treasure_chest.go index 17570e0..521bb98 100644 --- a/internal/service/treasure_chest.go +++ b/internal/service/treasure_chest.go @@ -1,6 +1,7 @@ package service import ( + "context" "errors" "fmt" "log/slog" @@ -13,11 +14,11 @@ import ( ) type TreasureChest interface { - Add(user *types.User, parentId, name string) (*types.TreasureChest, error) - Update(user *types.User, id, parentId, name string) (*types.TreasureChest, error) - Get(user *types.User, id string) (*types.TreasureChest, error) - GetAll(user *types.User) ([]*types.TreasureChest, error) - Delete(user *types.User, id string) error + Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) + Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error) + Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) + GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) + Delete(ctx context.Context, user *types.User, id string) error } type TreasureChestImpl struct { @@ -34,7 +35,7 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest { } } -func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.TreasureChest, error) { +func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) { if user == nil { return nil, ErrUnauthorized } @@ -51,7 +52,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types. var parentUuid *uuid.UUID if parentId != "" { - parent, err := s.Get(user, parentId) + parent, err := s.Get(ctx, user, parentId) if err != nil { return nil, err } @@ -76,7 +77,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types. UpdatedBy: nil, } - r, err := s.db.NamedExec(` + r, err := s.db.NamedExecContext(ctx, ` INSERT INTO treasure_chest (id, parent_id, user_id, name, current_balance, created_at, created_by) VALUES (:id, :parent_id, :user_id, :name, :current_balance, :created_at, :created_by)`, treasureChest) err = db.TransformAndLogDbError("treasureChest Insert", r, err) @@ -87,7 +88,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types. return treasureChest, nil } -func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) { +func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) { if user == nil { return nil, ErrUnauthorized } @@ -101,7 +102,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("treasureChest Update", nil, err) if err != nil { return nil, err @@ -111,7 +112,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string }() treasureChest := &types.TreasureChest{} - err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id) + err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id) err = db.TransformAndLogDbError("treasureChest Update", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -122,12 +123,12 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string var parentUuid *uuid.UUID if parentId != "" { - parent, err := s.Get(user, parentId) + parent, err := s.Get(ctx, user, parentId) if err != nil { return nil, err } var childCount int - err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id) + err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id) err = db.TransformAndLogDbError("treasureChest Update", nil, err) if err != nil { return nil, err @@ -145,7 +146,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string treasureChest.UpdatedAt = ×tamp treasureChest.UpdatedBy = &user.Id - r, err := tx.NamedExec(` + r, err := tx.NamedExecContext(ctx, ` UPDATE treasure_chest SET parent_id = :parent_id, @@ -169,7 +170,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string return treasureChest, nil } -func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChest, error) { +func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) { if user == nil { return nil, ErrUnauthorized } @@ -180,7 +181,7 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes } var treasureChest types.TreasureChest - err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) + err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("treasureChest Get", nil, err) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -192,13 +193,13 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes return &treasureChest, nil } -func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) { +func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) { if user == nil { return nil, ErrUnauthorized } treasureChests := make([]*types.TreasureChest, 0) - err := s.db.Select(&treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id) + err := s.db.SelectContext(ctx, &treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id) err = db.TransformAndLogDbError("treasureChest GetAll", nil, err) if err != nil { return nil, err @@ -207,7 +208,7 @@ func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, err return sortTree(treasureChests), nil } -func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { +func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error { if user == nil { return ErrUnauthorized } @@ -217,7 +218,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { return fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - tx, err := s.db.Beginx() + tx, err := s.db.BeginTxx(ctx, nil) err = db.TransformAndLogDbError("treasureChest Delete", nil, err) if err != nil { return nil @@ -227,7 +228,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { }() childCount := 0 - err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id) + err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id) err = db.TransformAndLogDbError("treasureChest Delete", nil, err) if err != nil { return err @@ -238,7 +239,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { } transactionsCount := 0 - err = tx.Get(&transactionsCount, + err = tx.GetContext(ctx, &transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`, user.Id, id) err = db.TransformAndLogDbError("treasureChest Delete", nil, err) @@ -250,7 +251,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { } recurringCount := 0 - err = tx.Get(&recurringCount, ` + err = tx.GetContext(ctx, &recurringCount, ` SELECT COUNT(*) FROM transaction_recurring WHERE user_id = ? AND treasure_chest_id = ?`, user.Id, id) err = db.TransformAndLogDbError("treasureChest Delete", nil, err) @@ -261,7 +262,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest) } - r, err := tx.Exec(`DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id) + r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id) err = db.TransformAndLogDbError("treasureChest Delete", r, err) if err != nil { return err diff --git a/internal/utils/http.go b/internal/utils/http.go index a170dc0..49633de 100644 --- a/internal/utils/http.go +++ b/internal/utils/http.go @@ -29,7 +29,7 @@ func DoRedirect(w http.ResponseWriter, r *http.Request, url string) { } } -func WaitMinimumTime[T interface{}](waitTime time.Duration, f func() (T, error)) (T, error) { +func WaitMinimumTime[T any](waitTime time.Duration, f func() (T, error)) (T, error) { start := time.Now() result, err := f() time.Sleep(waitTime - time.Since(start)) diff --git a/main.go b/main.go index abd9bf8..fc36829 100644 --- a/main.go +++ b/main.go @@ -6,9 +6,11 @@ import ( "os" "spend-sparrow/internal" - "github.com/jmoiron/sqlx" "github.com/joho/godotenv" _ "github.com/mattn/go-sqlite3" + "github.com/uptrace/opentelemetry-go-extra/otelsql" + "github.com/uptrace/opentelemetry-go-extra/otelsqlx" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" ) func main() { @@ -18,7 +20,8 @@ func main() { return } - db, err := sqlx.Open("sqlite3", "./data/spend-sparrow.db") + db, err := otelsqlx.Open("sqlite3", "./data/spend-sparrow.db", + otelsql.WithAttributes(semconv.DBSystemSqlite)) if err != nil { slog.Error("Could not open Database data.db", "err", err) return diff --git a/test/auth_it_test.go b/test/auth_it_test.go index f411f7e..44be9d1 100644 --- a/test/auth_it_test.go +++ b/test/auth_it_test.go @@ -1,6 +1,7 @@ package test_test import ( + "context" "spend-sparrow/internal/db" "spend-sparrow/internal/types" "testing" @@ -26,7 +27,7 @@ func setupDb(t *testing.T) *sqlx.DB { } }) - err = db.RunMigrations(d, "../") + err = db.RunMigrations(context.Background(), d, "../") if err != nil { t.Fatalf("Error running migrations: %v", err) } @@ -47,14 +48,14 @@ func TestUser(t *testing.T) { createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) - err := underTest.InsertUser(expected) + err := underTest.InsertUser(context.Background(), expected) require.NoError(t, err) - actual, err := underTest.GetUser(expected.Id) + actual, err := underTest.GetUser(context.Background(), expected.Id) require.NoError(t, err) assert.Equal(t, expected, actual) - actual, err = underTest.GetUserByEmail(expected.Email) + actual, err = underTest.GetUserByEmail(context.Background(), expected.Email) require.NoError(t, err) assert.Equal(t, expected, actual) }) @@ -64,7 +65,7 @@ func TestUser(t *testing.T) { underTest := db.NewAuthSqlite(d) - _, err := underTest.GetUserByEmail("nonExistentEmail") + _, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail") assert.Equal(t, db.ErrNotFound, err) }) t.Run("should return ErrUserExist", func(t *testing.T) { @@ -77,10 +78,10 @@ func TestUser(t *testing.T) { createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) - err := underTest.InsertUser(user) + err := underTest.InsertUser(context.Background(), user) require.NoError(t, err) - err = underTest.InsertUser(user) + err = underTest.InsertUser(context.Background(), user) assert.Equal(t, db.ErrAlreadyExists, err) }) t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { @@ -92,7 +93,7 @@ func TestUser(t *testing.T) { createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) - err := underTest.InsertUser(user) + err := underTest.InsertUser(context.Background(), user) assert.Equal(t, types.ErrInternal, err) }) } @@ -110,21 +111,21 @@ func TestToken(t *testing.T) { expiresAt := createAt.Add(24 * time.Hour) expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt) - err := underTest.InsertToken(expected) + err := underTest.InsertToken(context.Background(), expected) require.NoError(t, err) - actual, err := underTest.GetToken(expected.Token) + actual, err := underTest.GetToken(context.Background(), expected.Token) require.NoError(t, err) assert.Equal(t, expected, actual) expected.SessionId = "" - actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type) + actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type) require.NoError(t, err) assert.Equal(t, []*types.Token{expected}, actuals) expected.SessionId = "sessionId" expected.UserId = uuid.Nil - actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type) + actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type) require.NoError(t, err) assert.Equal(t, []*types.Token{expected}, actuals) }) @@ -140,14 +141,14 @@ func TestToken(t *testing.T) { expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt) expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt) - err := underTest.InsertToken(expected1) + err := underTest.InsertToken(context.Background(), expected1) require.NoError(t, err) - err = underTest.InsertToken(expected2) + err = underTest.InsertToken(context.Background(), expected2) require.NoError(t, err) expected1.UserId = uuid.Nil expected2.UserId = uuid.Nil - actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type) + actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type) require.NoError(t, err) assert.Equal(t, []*types.Token{expected1, expected2}, actuals) @@ -155,7 +156,7 @@ func TestToken(t *testing.T) { expected2.SessionId = "" expected1.UserId = userId expected2.UserId = userId - actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type) + actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type) require.NoError(t, err) assert.Equal(t, []*types.Token{expected1, expected2}, actuals) }) @@ -165,13 +166,13 @@ func TestToken(t *testing.T) { underTest := db.NewAuthSqlite(d) - _, err := underTest.GetToken("nonExistent") + _, err := underTest.GetToken(context.Background(), "nonExistent") assert.Equal(t, db.ErrNotFound, err) - _, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify) + _, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), types.TokenTypeEmailVerify) assert.Equal(t, db.ErrNotFound, err) - _, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify) + _, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", types.TokenTypeEmailVerify) assert.Equal(t, db.ErrNotFound, err) }) t.Run("should return ErrAlreadyExists", func(t *testing.T) { @@ -184,10 +185,10 @@ func TestToken(t *testing.T) { createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) - err := underTest.InsertUser(user) + err := underTest.InsertUser(context.Background(), user) require.NoError(t, err) - err = underTest.InsertUser(user) + err = underTest.InsertUser(context.Background(), user) assert.Equal(t, db.ErrAlreadyExists, err) }) t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { @@ -199,7 +200,7 @@ func TestToken(t *testing.T) { createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) - err := underTest.InsertUser(user) + err := underTest.InsertUser(context.Background(), user) assert.Equal(t, types.ErrInternal, err) }) } diff --git a/test/auth_test.go b/test/auth_test.go index e14e539..b4f8807 100644 --- a/test/auth_test.go +++ b/test/auth_test.go @@ -1,6 +1,7 @@ package test_test import ( + "context" "spend-sparrow/internal/db" "spend-sparrow/internal/service" "spend-sparrow/internal/types" @@ -36,7 +37,7 @@ func TestSignUp(t *testing.T) { underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) - _, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!") + _, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!") assert.Equal(t, service.ErrInvalidEmail, err) }) @@ -58,7 +59,7 @@ func TestSignUp(t *testing.T) { } for _, password := range weakPasswords { - _, err := underTest.SignUp("some@valid.email", password) + _, err := underTest.SignUp(context.Background(), "some@valid.email", password) assert.Equal(t, service.ErrInvalidPassword, err) } }) @@ -81,10 +82,10 @@ func TestSignUp(t *testing.T) { mockRandom.EXPECT().UUID().Return(userId, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil) mockClock.EXPECT().Now().Return(createTime) - mockAuthDb.EXPECT().InsertUser(expected).Return(nil) + mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) - actual, err := underTest.SignUp(email, password) + actual, err := underTest.SignUp(context.Background(), email, password) require.NoError(t, err) @@ -109,11 +110,11 @@ func TestSignUp(t *testing.T) { mockRandom.EXPECT().Bytes(16).Return(salt, nil) mockClock.EXPECT().Now().Return(createTime) - mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists) + mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(db.ErrAlreadyExists) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) - _, err := underTest.SignUp(user.Email, password) + _, err := underTest.SignUp(context.Background(), user.Email, password) assert.Equal(t, service.ErrAccountExists, err) }) } @@ -140,7 +141,7 @@ func TestSendVerificationMail(t *testing.T) { mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) - mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify).Return(tokens, nil) + mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, types.TokenTypeEmailVerify).Return(tokens, nil) mockMail.EXPECT().SendMail(email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool { return strings.Contains(message, token.Token) @@ -148,6 +149,6 @@ func TestSendVerificationMail(t *testing.T) { underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) - underTest.SendVerificationMail(userId, email) + underTest.SendVerificationMail(context.Background(), userId, email) }) } diff --git a/test/it_test.go b/test/it_test.go index db1ed22..8358f59 100644 --- a/test/it_test.go +++ b/test/it_test.go @@ -182,16 +182,16 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s csrfToken := "my-verifying-token" + add email := add + "mail@mail.de" - _, err := db.Exec(` + _, err := db.ExecContext(context.Background(), ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt")) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(context.Background(), ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(context.Background(), ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf) require.NoError(t, err) diff --git a/test/main_it_test.go b/test/main_it_test.go index 8b81f5e..1de0b6d 100644 --- a/test/main_it_test.go +++ b/test/main_it_test.go @@ -112,11 +112,11 @@ func TestIntegrationAuth(t *testing.T) { sessionId := "session-id" pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) @@ -138,7 +138,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) @@ -165,7 +165,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) @@ -208,7 +208,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) @@ -248,7 +248,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) require.NoError(t, err) @@ -296,7 +296,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) require.NoError(t, err) @@ -415,7 +415,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) require.NoError(t, err) @@ -451,10 +451,10 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) - err = db.QueryRow("SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) }) @@ -469,11 +469,11 @@ func TestIntegrationAuth(t *testing.T) { sessionId := "session-id" pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) @@ -548,7 +548,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt")) require.NoError(t, err) @@ -627,11 +627,11 @@ func TestIntegrationAuth(t *testing.T) { assert.Contains(t, resp.Header.Get("Hx-Trigger"), "An activation link has been send to your email") var rows int - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE email = ? AND email_verified = FALSE", "mail@mail.de").Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE email = ? AND email_verified = FALSE", "mail@mail.de").Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) var token string - err = db.QueryRow("SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token) + err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token) require.NoError(t, err) assert.NotEmpty(t, token) }) @@ -644,7 +644,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) require.NoError(t, err) @@ -658,7 +658,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -670,11 +670,11 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() token := "my-outdated-verifying-token" - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO token (token, user_id, type, created_at, expires_at) VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify) require.NoError(t, err) @@ -688,7 +688,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -700,11 +700,11 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() token := "my-verifying-token" - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify) require.NoError(t, err) @@ -718,7 +718,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = TRUE", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = TRUE", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -747,11 +747,11 @@ func TestIntegrationAuth(t *testing.T) { sessionId := "session-id" pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) @@ -765,7 +765,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var csrfToken string - err = db.QueryRow("SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken) + err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken) require.NoError(t, err) req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil) @@ -785,7 +785,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, -1, cookie.MaxAge) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) }) @@ -825,13 +825,13 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) @@ -871,13 +871,13 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) @@ -964,22 +964,22 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) - err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ?", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ?", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ?", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ?", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) - err = db.QueryRow("SELECT COUNT(*) FROM account WHERE user_id = ?", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM account WHERE user_id = ?", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) - err = db.QueryRow("SELECT COUNT(*) FROM treasure_chest WHERE user_id = ?", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM treasure_chest WHERE user_id = ?", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) - err = db.QueryRow("SELECT COUNT(*) FROM \"transaction\" WHERE user_id = ?", userId).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM \"transaction\" WHERE user_id = ?", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) }) @@ -1040,13 +1040,13 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) @@ -1069,7 +1069,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -1080,13 +1080,13 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) @@ -1119,7 +1119,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -1130,13 +1130,13 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) @@ -1169,7 +1169,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -1181,21 +1181,21 @@ func TestIntegrationAuth(t *testing.T) { userIdOther := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES ("second", ?, datetime(), datetime("now", "+1 day"))`, userId) require.NoError(t, err) - _, err = db.Exec(` + _, err = db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES ("other", ?, datetime(), datetime("now", "+1 day"))`, userIdOther) require.NoError(t, err) @@ -1232,12 +1232,12 @@ func TestIntegrationAuth(t *testing.T) { pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt")) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) var sessionIds []string - sessions, err := db.Query(`SELECT session_id FROM session WHERE NOT user_id = ? ORDER BY session_id`, uuid.Nil) + sessions, err := db.QueryContext(ctx, `SELECT session_id FROM session WHERE NOT user_id = ? ORDER BY session_id`, uuid.Nil) require.NoError(t, err) for sessions.Next() { var sessionId string @@ -1260,13 +1260,13 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := d.Exec(` + _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) sessionId := "session-id" - _, err = d.Exec(` + _, err = d.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId) require.NoError(t, err) @@ -1288,7 +1288,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := d.Exec(` + _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) @@ -1317,7 +1317,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = d.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) + err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) }) @@ -1363,7 +1363,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := db.Exec(` + _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) @@ -1399,7 +1399,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg) var rows int - err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -1413,7 +1413,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := d.Exec(` + _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) @@ -1445,7 +1445,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) + err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -1456,7 +1456,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := d.Exec(` + _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) @@ -1473,7 +1473,7 @@ func TestIntegrationAuth(t *testing.T) { assert.NotEmpty(t, anonymousCsrfToken) token := "password-reset-token" - _, err = d.Exec(` + _, err = d.ExecContext(ctx, ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset) require.NoError(t, err) @@ -1494,7 +1494,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) + err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -1505,7 +1505,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := d.Exec(` + _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) @@ -1522,7 +1522,7 @@ func TestIntegrationAuth(t *testing.T) { assert.NotEmpty(t, anonymousCsrfToken) token := "password-reset-token" - _, err = d.Exec(` + _, err = d.ExecContext(ctx, ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset) require.NoError(t, err) @@ -1543,7 +1543,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) + err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -1554,12 +1554,12 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() pass := service.GetHashPassword("password", []byte("salt")) - _, err := d.Exec(` + _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) require.NoError(t, err) - _, err = d.Exec(` + _, err = d.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId) require.NoError(t, err) @@ -1590,7 +1590,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var token string - err = d.QueryRow("SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token) + err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token) require.NoError(t, err) formData = url.Values{ @@ -1608,7 +1608,7 @@ func TestIntegrationAuth(t *testing.T) { _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) - sessions, err := d.Query("SELECT session_id FROM session WHERE user_id = ?", userId) + sessions, err := d.QueryContext(ctx, "SELECT session_id FROM session WHERE user_id = ?", userId) require.NoError(t, err) assert.False(t, sessions.Next()) }) @@ -1623,11 +1623,11 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() sessionId := "session-id" - _, err := d.Exec(` + _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) require.NoError(t, err) - _, err = d.Exec(` + _, err = d.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId) require.NoError(t, err) @@ -1643,7 +1643,7 @@ func TestIntegrationAuth(t *testing.T) { assert.NotEqual(t, sessionId, newSession.Value) var rows int - err = d.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) + err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) }) @@ -1670,11 +1670,11 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() sessionId := "session-id" - _, err := d.Exec(` + _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) require.NoError(t, err) - _, err = d.Exec(` + _, err = d.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId) require.NoError(t, err) @@ -1769,7 +1769,7 @@ func TestIntegrationAccount(t *testing.T) { _ = resp.Body.Close() var id uuid.UUID - err = db.Get(&id, "SELECT id FROM account") + err = db.GetContext(ctx, &id, "SELECT id FROM account") require.NoError(t, err) // Update diff --git a/test/treasure_chest_it_test.go b/test/treasure_chest_it_test.go index ebdd0f4..c1e817d 100644 --- a/test/treasure_chest_it_test.go +++ b/test/treasure_chest_it_test.go @@ -22,7 +22,7 @@ func TestTreasureChestShouldNotDeleteIfTransactionRecurringExists(t *testing.T) assert.Equal(t, http.StatusOK, resp.StatusCode) var parentId string - err := db.Get(&parentId, "SELECT id FROM treasure_chest") + err := db.GetContext(ctx, &parentId, "SELECT id FROM treasure_chest") require.NoError(t, err) formData = url.Values{ @@ -33,7 +33,7 @@ func TestTreasureChestShouldNotDeleteIfTransactionRecurringExists(t *testing.T) assert.Equal(t, http.StatusOK, resp.StatusCode) var childId string - err = db.Get(&childId, "SELECT id FROM treasure_chest WHERE parent_id = ?", parentId) + err = db.GetContext(ctx, &childId, "SELECT id FROM treasure_chest WHERE parent_id = ?", parentId) require.NoError(t, err) formData = url.Values{