feat(observabillity): #153 instrument sqlx
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 8m11s

This commit is contained in:
2025-06-07 21:55:59 +02:00
parent c4aca2778f
commit 6291700a3b
26 changed files with 425 additions and 402 deletions

View File

@@ -24,6 +24,7 @@ linters:
- cyclop - cyclop
- contextcheck - contextcheck
- bodyclose # i don't care in the tests, the implementation itself doesn't do http requests - bodyclose # i don't care in the tests, the implementation itself doesn't do http requests
- containedctx
settings: settings:
nestif: nestif:
min-complexity: 6 min-complexity: 6

2
go.mod
View File

@@ -12,6 +12,7 @@ require (
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/mattn/go-sqlite3 v1.14.28 github.com/mattn/go-sqlite3 v1.14.28
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2
go.opentelemetry.io/contrib/bridges/otelslog v0.11.0 go.opentelemetry.io/contrib/bridges/otelslog v0.11.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0
go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel v1.36.0
@@ -38,6 +39,7 @@ require (
github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect github.com/stretchr/objx v0.5.2 // indirect
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 // indirect
go.opentelemetry.io/otel/metric v1.36.0 // indirect go.opentelemetry.io/otel/metric v1.36.0 // indirect

4
go.sum
View File

@@ -51,6 +51,10 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2 h1:ZjUj9BLYf9PEqBn8W/OapxhPjVRdC6CsXTdULHsyk5c=
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2/go.mod h1:O8bHQfyinKwTXKkiKNGmLQS7vRsqRxIQTFZpYpHK3IQ=
github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2 h1:zA9ZXfdtowo0EKt+t7uqXNlHxPeygrxuFSIroiBVgPU=
github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2/go.mod h1:ySXmuW9JLCm/TjsQksuMY/7MNiWqfHnhH2xeT34uOLU=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/bridges/otelslog v0.11.0 h1:EMIiYTms4Z4m3bBuKp1VmMNRLZcl6j4YbvOPL1IhlWo= go.opentelemetry.io/contrib/bridges/otelslog v0.11.0 h1:EMIiYTms4Z4m3bBuKp1VmMNRLZcl6j4YbvOPL1IhlWo=

View File

@@ -1,6 +1,7 @@
package db package db
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"log/slog" "log/slog"
@@ -13,23 +14,23 @@ import (
) )
type Auth interface { type Auth interface {
InsertUser(user *types.User) error InsertUser(ctx context.Context, user *types.User) error
UpdateUser(user *types.User) error UpdateUser(ctx context.Context, user *types.User) error
GetUserByEmail(email string) (*types.User, error) GetUserByEmail(ctx context.Context, email string) (*types.User, error)
GetUser(userId uuid.UUID) (*types.User, error) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error)
DeleteUser(userId uuid.UUID) error DeleteUser(ctx context.Context, userId uuid.UUID) error
InsertToken(token *types.Token) error InsertToken(ctx context.Context, token *types.Token) error
GetToken(token string) (*types.Token, error) GetToken(ctx context.Context, token string) (*types.Token, error)
GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error)
DeleteToken(token string) error DeleteToken(ctx context.Context, token string) error
InsertSession(session *types.Session) error InsertSession(ctx context.Context, session *types.Session) error
GetSession(sessionId string) (*types.Session, error) GetSession(ctx context.Context, sessionId string) (*types.Session, error)
GetSessions(userId uuid.UUID) ([]*types.Session, error) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
DeleteSession(sessionId string) error DeleteSession(ctx context.Context, sessionId string) error
DeleteOldSessions(userId uuid.UUID) error DeleteOldSessions(ctx context.Context, userId uuid.UUID) error
} }
type AuthSqlite struct { type AuthSqlite struct {
@@ -40,8 +41,8 @@ func NewAuthSqlite(db *sqlx.DB) *AuthSqlite {
return &AuthSqlite{db: db} return &AuthSqlite{db: db}
} }
func (db AuthSqlite) InsertUser(user *types.User) error { func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error {
_, err := db.db.Exec(` _, err := db.db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt) user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt)
@@ -58,8 +59,8 @@ func (db AuthSqlite) InsertUser(user *types.User) error {
return nil return nil
} }
func (db AuthSqlite) UpdateUser(user *types.User) error { func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error {
_, err := db.db.Exec(` _, err := db.db.ExecContext(ctx, `
UPDATE user UPDATE user
SET email_verified = ?, email_verified_at = ?, password = ? SET email_verified = ?, email_verified_at = ?, password = ?
WHERE user_id = ?`, WHERE user_id = ?`,
@@ -73,7 +74,7 @@ func (db AuthSqlite) UpdateUser(user *types.User) error {
return nil return nil
} }
func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) { func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
emailVerified bool emailVerified bool
@@ -84,7 +85,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
createdAt time.Time createdAt time.Time
) )
err := db.db.QueryRow(` err := db.db.QueryRowContext(ctx, `
SELECT user_id, email_verified, email_verified_at, password, salt, created_at SELECT user_id, email_verified, email_verified_at, password, salt, created_at
FROM user FROM user
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
@@ -100,7 +101,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
} }
func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) { func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) {
var ( var (
email string email string
emailVerified bool emailVerified bool
@@ -111,7 +112,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
createdAt time.Time createdAt time.Time
) )
err := db.db.QueryRow(` err := db.db.QueryRowContext(ctx, `
SELECT email, email_verified, email_verified_at, password, salt, created_at SELECT email, email_verified, email_verified_at, password, salt, created_at
FROM user FROM user
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
@@ -127,49 +128,49 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
} }
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { func (db AuthSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error {
tx, err := db.db.Begin() tx, err := db.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
slog.Error("Could not start transaction", "err", err) slog.Error("Could not start transaction", "err", err)
return types.ErrInternal return types.ErrInternal
} }
_, err = tx.Exec("DELETE FROM account WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.Error("Could not delete accounts", "err", err) slog.Error("Could not delete accounts", "err", err)
return types.ErrInternal return types.ErrInternal
} }
_, err = tx.Exec("DELETE FROM token WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.Error("Could not delete user tokens", "err", err) slog.Error("Could not delete user tokens", "err", err)
return types.ErrInternal return types.ErrInternal
} }
_, err = tx.Exec("DELETE FROM session WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.Error("Could not delete sessions", "err", err) slog.Error("Could not delete sessions", "err", err)
return types.ErrInternal return types.ErrInternal
} }
_, err = tx.Exec("DELETE FROM user WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.Error("Could not delete user", "err", err) slog.Error("Could not delete user", "err", err)
return types.ErrInternal return types.ErrInternal
} }
_, err = tx.Exec("DELETE FROM treasure_chest WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.Error("Could not delete user", "err", err) slog.Error("Could not delete user", "err", err)
return types.ErrInternal return types.ErrInternal
} }
_, err = tx.Exec("DELETE FROM \"transaction\" WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.Error("Could not delete user", "err", err) slog.Error("Could not delete user", "err", err)
@@ -185,8 +186,8 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
return nil return nil
} }
func (db AuthSqlite) InsertToken(token *types.Token) error { func (db AuthSqlite) InsertToken(ctx context.Context, token *types.Token) error {
_, err := db.db.Exec(` _, err := db.db.ExecContext(ctx, `
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at) INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt) VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
@@ -198,7 +199,7 @@ func (db AuthSqlite) InsertToken(token *types.Token) error {
return nil return nil
} }
func (db AuthSqlite) GetToken(token string) (*types.Token, error) { func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
sessionId string sessionId string
@@ -209,7 +210,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
expiresAt time.Time expiresAt time.Time
) )
err := db.db.QueryRow(` err := db.db.QueryRowContext(ctx, `
SELECT user_id, session_id, type, created_at, expires_at SELECT user_id, session_id, type, created_at, expires_at
FROM token FROM token
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr) WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
@@ -239,8 +240,8 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
} }
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) { func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(` query, err := db.db.QueryContext(ctx, `
SELECT token, created_at, expires_at SELECT token, created_at, expires_at
FROM token FROM token
WHERE user_id = ? WHERE user_id = ?
@@ -254,8 +255,8 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.
return getTokensFromQuery(query, userId, "", tokenType) return getTokensFromQuery(query, userId, "", tokenType)
} }
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) { func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(` query, err := db.db.QueryContext(ctx, `
SELECT token, created_at, expires_at SELECT token, created_at, expires_at
FROM token FROM token
WHERE session_id = ? WHERE session_id = ?
@@ -312,8 +313,8 @@ func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tok
return tokens, nil return tokens, nil
} }
func (db AuthSqlite) DeleteToken(token string) error { func (db AuthSqlite) DeleteToken(ctx context.Context, token string) error {
_, err := db.db.Exec("DELETE FROM token WHERE token = ?", token) _, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token)
if err != nil { if err != nil {
slog.Error("Could not delete token", "err", err) slog.Error("Could not delete token", "err", err)
return types.ErrInternal return types.ErrInternal
@@ -321,8 +322,8 @@ func (db AuthSqlite) DeleteToken(token string) error {
return nil return nil
} }
func (db AuthSqlite) InsertSession(session *types.Session) error { func (db AuthSqlite) InsertSession(ctx context.Context, session *types.Session) error {
_, err := db.db.Exec(` _, err := db.db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt) VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
@@ -334,14 +335,14 @@ func (db AuthSqlite) InsertSession(session *types.Session) error {
return nil return nil
} }
func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.Session, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
createdAt time.Time createdAt time.Time
expiresAt time.Time expiresAt time.Time
) )
err := db.db.QueryRow(` err := db.db.QueryRowContext(ctx, `
SELECT user_id, created_at, expires_at SELECT user_id, created_at, expires_at
FROM session FROM session
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt) WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
@@ -354,9 +355,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
return types.NewSession(sessionId, userId, createdAt, expiresAt), nil return types.NewSession(sessionId, userId, createdAt, expiresAt), nil
} }
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) { func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) {
var sessions []*types.Session var sessions []*types.Session
err := db.db.Select(&sessions, ` err := db.db.SelectContext(ctx, &sessions, `
SELECT * SELECT *
FROM session FROM session
WHERE user_id = ?`, userId) WHERE user_id = ?`, userId)
@@ -368,8 +369,8 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
return sessions, nil return sessions, nil
} }
func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error { func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error {
_, err := db.db.Exec(` _, err := db.db.ExecContext(ctx, `
DELETE FROM session DELETE FROM session
WHERE expires_at < datetime('now') WHERE expires_at < datetime('now')
AND user_id = ?`, userId) AND user_id = ?`, userId)
@@ -380,9 +381,9 @@ func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
return nil return nil
} }
func (db AuthSqlite) DeleteSession(sessionId string) error { func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
if sessionId != "" { if sessionId != "" {
_, err := db.db.Exec("DELETE FROM session WHERE session_id = ?", sessionId) _, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
if err != nil { if err != nil {
slog.Error("Could not delete session", "err", err) slog.Error("Could not delete session", "err", err)
return types.ErrInternal return types.ErrInternal

View File

@@ -1,6 +1,7 @@
package db package db
import ( import (
"context"
"errors" "errors"
"log/slog" "log/slog"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
@@ -20,7 +21,7 @@ func (l migrationLogger) Verbose() bool {
return false return false
} }
func RunMigrations(db *sqlx.DB, pathPrefix string) error { func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error {
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
if err != nil { if err != nil {
slog.Error("Could not create Migration instance", "err", err) slog.Error("Could not create Migration instance", "err", err)

View File

@@ -56,7 +56,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
} }
// init db // init db
err = db.RunMigrations(database, migrationsPrefix) err = db.RunMigrations(ctx, database, migrationsPrefix)
if err != nil { if err != nil {
return fmt.Errorf("could not run migrations: %w", err) return fmt.Errorf("could not run migrations: %w", err)
} }

View File

@@ -44,7 +44,7 @@ func (h AccountImpl) handleAccountPage() http.HandlerFunc {
return return
} }
accounts, err := h.s.GetAll(user) accounts, err := h.s.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -72,7 +72,7 @@ func (h AccountImpl) handleAccountItemComp() http.HandlerFunc {
return return
} }
account, err := h.s.Get(user, id) account, err := h.s.Get(r.Context(), user, id)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -105,13 +105,13 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc {
id := r.PathValue("id") id := r.PathValue("id")
name := r.FormValue("name") name := r.FormValue("name")
if id == "new" { if id == "new" {
account, err = h.s.Add(user, name) account, err = h.s.Add(r.Context(), user, name)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
} else { } else {
account, err = h.s.UpdateName(user, id, name) account, err = h.s.UpdateName(r.Context(), user, id, name)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -135,7 +135,7 @@ func (h AccountImpl) handleDeleteAccount() http.HandlerFunc {
id := r.PathValue("id") id := r.PathValue("id")
err := h.s.Delete(user, id) err := h.s.Delete(r.Context(), user, id)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return

View File

@@ -85,7 +85,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
email := r.FormValue("email") email := r.FormValue("email")
password := r.FormValue("password") password := r.FormValue("password")
session, user, err := handler.service.SignIn(session, email, password) session, user, err := handler.service.SignIn(r.Context(), session, email, password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -163,7 +163,7 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
return return
} }
go handler.service.SendVerificationMail(user.Id, user.Email) go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email)
_, err := w.Write([]byte("<p class=\"mt-8\">Verification email sent</p>")) _, err := w.Write([]byte("<p class=\"mt-8\">Verification email sent</p>"))
if err != nil { if err != nil {
@@ -178,7 +178,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
token := r.URL.Query().Get("token") token := r.URL.Query().Get("token")
err := handler.service.VerifyUserEmail(token) err := handler.service.VerifyUserEmail(r.Context(), token)
isVerified := err == nil isVerified := err == nil
comp := auth.VerifyResponseComp(isVerified) comp := auth.VerifyResponseComp(isVerified)
@@ -203,13 +203,13 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) { _, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
slog.Info("signing up", "email", email) slog.Info("signing up", "email", email)
user, err := handler.service.SignUp(email, password) user, err := handler.service.SignUp(r.Context(), email, password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
slog.Info("Sending verification email", "to", user.Email) slog.Info("Sending verification email", "to", user.Email)
go handler.service.SendVerificationMail(user.Id, user.Email) go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email)
return nil, nil return nil, nil
}) })
@@ -239,7 +239,7 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc {
session := middleware.GetSession(r) session := middleware.GetSession(r)
if session != nil { if session != nil {
err := handler.service.SignOut(session.Id) err := handler.service.SignOut(r.Context(), session.Id)
if err != nil { if err != nil {
http.Error(w, "An error occurred", http.StatusInternalServerError) http.Error(w, "An error occurred", http.StatusInternalServerError)
return return
@@ -288,7 +288,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
password := r.FormValue("password") password := r.FormValue("password")
err := handler.service.DeleteAccount(user, password) err := handler.service.DeleteAccount(r.Context(), user, password)
if err != nil { if err != nil {
if errors.Is(err, service.ErrInvalidCredentials) { if errors.Is(err, service.ErrInvalidCredentials) {
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
@@ -334,7 +334,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
currPass := r.FormValue("current-password") currPass := r.FormValue("current-password")
newPass := r.FormValue("new-password") newPass := r.FormValue("new-password")
err := handler.service.ChangePassword(user, session.Id, currPass, newPass) err := handler.service.ChangePassword(r.Context(), user, session.Id, currPass, newPass)
if err != nil { if err != nil {
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
return return
@@ -370,7 +370,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
} }
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) { _, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
err := handler.service.SendForgotPasswordMail(email) err := handler.service.SendForgotPasswordMail(r.Context(), email)
return nil, err return nil, err
}) })
@@ -396,7 +396,7 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
token := pageUrl.Query().Get("token") token := pageUrl.Query().Get("token")
newPass := r.FormValue("new-password") newPass := r.FormValue("new-password")
err = handler.service.ForgotPassword(token, newPass) err = handler.service.ForgotPassword(r.Context(), token, newPass)
if err != nil { if err != nil {
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
} else { } else {

View File

@@ -17,13 +17,13 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionId := getSessionID(r) sessionId := getSessionID(r)
session, user, _ := service.SignInSession(sessionId) session, user, _ := service.SignInSession(r.Context(), sessionId)
var err error var err error
// Always sign in anonymous // Always sign in anonymous
// This way, we can always generate csrf tokens // This way, we can always generate csrf tokens
if session == nil { if session == nil {
session, err = service.SignInAnonymous() session, err = service.SignInAnonymous(r.Context())
if err != nil { if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return return

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"log/slog" "log/slog"
"net/http" "net/http"
"spend-sparrow/internal/service" "spend-sparrow/internal/service"
@@ -11,13 +12,15 @@ import (
type csrfResponseWriter struct { type csrfResponseWriter struct {
http.ResponseWriter http.ResponseWriter
ctx context.Context
auth service.Auth auth service.Auth
session *types.Session session *types.Session
} }
func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter { func newCsrfResponseWriter(ctx context.Context, w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter {
return &csrfResponseWriter{ return &csrfResponseWriter{
ResponseWriter: w, ResponseWriter: w,
ctx: ctx,
auth: auth, auth: auth,
session: session, session: session,
} }
@@ -25,7 +28,7 @@ func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *ty
func (rr *csrfResponseWriter) Write(data []byte) (int, error) { func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
dataStr := string(data) dataStr := string(data)
csrfToken, err := rr.auth.GetCsrfToken(rr.session) csrfToken, err := rr.auth.GetCsrfToken(rr.ctx, rr.session)
if err == nil { if err == nil {
dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken) dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken)
} }
@@ -37,6 +40,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := GetSession(r) session := GetSession(r)
ctx := r.Context()
if r.Method == http.MethodPost || if r.Method == http.MethodPost ||
r.Method == http.MethodPut || r.Method == http.MethodPut ||
@@ -44,7 +48,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
r.Method == http.MethodPatch { r.Method == http.MethodPatch {
csrfToken := r.Header.Get("Csrf-Token") csrfToken := r.Header.Get("Csrf-Token")
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(ctx, csrfToken, session.Id) {
slog.Info("CSRF-Token not correct", "token", csrfToken) slog.Info("CSRF-Token not correct", "token", csrfToken)
if r.Header.Get("Hx-Request") == "true" { if r.Header.Get("Hx-Request") == "true" {
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
@@ -55,7 +59,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
} }
} }
responseWriter := newCsrfResponseWriter(w, auth, session) responseWriter := newCsrfResponseWriter(ctx, w, auth, session)
next.ServeHTTP(responseWriter, r) next.ServeHTTP(responseWriter, r)
}) })
} }

View File

@@ -14,7 +14,7 @@ func GenerateRecurringTransactions(transactionRecurring service.TransactionRecur
return return
} }
_ = transactionRecurring.GenerateTransactions(user) _ = transactionRecurring.GenerateTransactions(r.Context(), user)
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })

View File

@@ -65,19 +65,19 @@ func (h TransactionImpl) handleTransactionPage() http.HandlerFunc {
Error: r.URL.Query().Get("error"), Error: r.URL.Query().Get("error"),
} }
transactions, err := h.s.GetAll(user, filter) transactions, err := h.s.GetAll(r.Context(), user, filter)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
accounts, err := h.account.GetAll(user) accounts, err := h.account.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
treasureChests, err := h.treasureChest.GetAll(user) treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -105,13 +105,13 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc {
return return
} }
accounts, err := h.account.GetAll(user) accounts, err := h.account.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
treasureChests, err := h.treasureChest.GetAll(user) treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -124,7 +124,7 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc {
return return
} }
transaction, err := h.s.Get(user, id) transaction, err := h.s.Get(r.Context(), user, id)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -212,26 +212,26 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
var transaction *types.Transaction var transaction *types.Transaction
if idStr == "new" { if idStr == "new" {
transaction, err = h.s.Add(nil, user, input) transaction, err = h.s.Add(r.Context(), nil, user, input)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
} else { } else {
transaction, err = h.s.Update(user, input) transaction, err = h.s.Update(r.Context(), user, input)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
} }
accounts, err := h.account.GetAll(user) accounts, err := h.account.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
treasureChests, err := h.treasureChest.GetAll(user) treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -253,7 +253,7 @@ func (h TransactionImpl) handleRecalculate() http.HandlerFunc {
return return
} }
err := h.s.RecalculateBalances(user) err := h.s.RecalculateBalances(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -275,7 +275,7 @@ func (h TransactionImpl) handleDeleteTransaction() http.HandlerFunc {
id := r.PathValue("id") id := r.PathValue("id")
err := h.s.Delete(user, id) err := h.s.Delete(r.Context(), user, id)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return

View File

@@ -70,13 +70,13 @@ func (h TransactionRecurringImpl) handleUpdateTransactionRecurring() http.Handle
} }
if input.Id == "new" { if input.Id == "new" {
_, err := h.s.Add(user, input) _, err := h.s.Add(r.Context(), user, input)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
} else { } else {
_, err := h.s.Update(user, input) _, err := h.s.Update(r.Context(), user, input)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -101,7 +101,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
accountId := r.URL.Query().Get("account-id") accountId := r.URL.Query().Get("account-id")
treasureChestId := r.URL.Query().Get("treasure-chest-id") treasureChestId := r.URL.Query().Get("treasure-chest-id")
err := h.s.Delete(user, id) err := h.s.Delete(r.Context(), user, id)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -118,13 +118,13 @@ func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Req
utils.TriggerToastWithStatus(w, r, "error", "Please select an account or treasure chest", http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", "Please select an account or treasure chest", http.StatusBadRequest)
} }
if accountId != "" { if accountId != "" {
transactionsRecurring, err = h.s.GetAllByAccount(user, accountId) transactionsRecurring, err = h.s.GetAllByAccount(r.Context(), user, accountId)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
} else { } else {
transactionsRecurring, err = h.s.GetAllByTreasureChest(user, treasureChestId) transactionsRecurring, err = h.s.GetAllByTreasureChest(r.Context(), user, treasureChestId)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return

View File

@@ -48,13 +48,13 @@ func (h TreasureChestImpl) handleTreasureChestPage() http.HandlerFunc {
return return
} }
treasureChests, err := h.s.GetAll(user) treasureChests, err := h.s.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
transactionsRecurring, err := h.transactionRecurring.GetAll(user) transactionsRecurring, err := h.transactionRecurring.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -77,7 +77,7 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc {
return return
} }
treasureChests, err := h.s.GetAll(user) treasureChests, err := h.s.GetAll(r.Context(), user)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -90,13 +90,13 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc {
return return
} }
treasureChest, err := h.s.Get(user, id) treasureChest, err := h.s.Get(r.Context(), user, id)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String()) transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String())
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -132,20 +132,20 @@ func (h TreasureChestImpl) handleUpdateTreasureChest() http.HandlerFunc {
parentId := r.FormValue("parent-id") parentId := r.FormValue("parent-id")
name := r.FormValue("name") name := r.FormValue("name")
if id == "new" { if id == "new" {
treasureChest, err = h.s.Add(user, parentId, name) treasureChest, err = h.s.Add(r.Context(), user, parentId, name)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
} else { } else {
treasureChest, err = h.s.Update(user, id, parentId, name) treasureChest, err = h.s.Update(r.Context(), user, id, parentId, name)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
} }
} }
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String()) transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String())
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return
@@ -171,7 +171,7 @@ func (h TreasureChestImpl) handleDeleteTreasureChest() http.HandlerFunc {
id := r.PathValue("id") id := r.PathValue("id")
err := h.s.Delete(user, id) err := h.s.Delete(r.Context(), user, id)
if err != nil { if err != nil {
handleError(w, r, err) handleError(w, r, err)
return return

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
@@ -12,11 +13,11 @@ import (
) )
type Account interface { type Account interface {
Add(user *types.User, name string) (*types.Account, error) Add(ctx context.Context, user *types.User, name string) (*types.Account, error)
UpdateName(user *types.User, id string, name string) (*types.Account, error) UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error)
Get(user *types.User, id string) (*types.Account, error) Get(ctx context.Context, user *types.User, id string) (*types.Account, error)
GetAll(user *types.User) ([]*types.Account, error) GetAll(ctx context.Context, user *types.User) ([]*types.Account, error)
Delete(user *types.User, id string) error Delete(ctx context.Context, user *types.User, id string) error
} }
type AccountImpl struct { type AccountImpl struct {
@@ -33,7 +34,7 @@ func NewAccount(db *sqlx.DB, random Random, clock Clock) Account {
} }
} }
func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) { func (s AccountImpl) Add(ctx context.Context, user *types.User, name string) (*types.Account, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -64,7 +65,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
UpdatedBy: nil, UpdatedBy: nil,
} }
r, err := s.db.NamedExec(` r, err := s.db.NamedExecContext(ctx, `
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by) INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account) VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account)
err = db.TransformAndLogDbError("account Insert", r, err) err = db.TransformAndLogDbError("account Insert", r, err)
@@ -75,7 +76,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
return account, nil return account, nil
} }
func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) { func (s AccountImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -89,7 +90,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("account Update", nil, err) err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -99,7 +100,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
}() }()
var account types.Account var account types.Account
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Update", nil, err) err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
@@ -113,7 +114,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
account.UpdatedAt = &timestamp account.UpdatedAt = &timestamp
account.UpdatedBy = &user.Id account.UpdatedBy = &user.Id
r, err := tx.NamedExec(` r, err := tx.NamedExecContext(ctx, `
UPDATE account UPDATE account
SET SET
name = :name, name = :name,
@@ -135,7 +136,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
return &account, nil return &account, nil
} }
func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { func (s AccountImpl) Get(ctx context.Context, user *types.User, id string) (*types.Account, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -146,7 +147,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
} }
var account types.Account var account types.Account
err = s.db.Get(&account, ` err = s.db.GetContext(ctx, &account, `
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Get", nil, err) err = db.TransformAndLogDbError("account Get", nil, err)
if err != nil { if err != nil {
@@ -157,13 +158,13 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
return &account, nil return &account, nil
} }
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { func (s AccountImpl) GetAll(ctx context.Context, user *types.User) ([]*types.Account, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
accounts := make([]*types.Account, 0) accounts := make([]*types.Account, 0)
err := s.db.Select(&accounts, ` err := s.db.SelectContext(ctx, &accounts, `
SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id) SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id)
err = db.TransformAndLogDbError("account GetAll", nil, err) err = db.TransformAndLogDbError("account GetAll", nil, err)
if err != nil { if err != nil {
@@ -173,7 +174,7 @@ func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
return accounts, nil return accounts, nil
} }
func (s AccountImpl) Delete(user *types.User, id string) error { func (s AccountImpl) Delete(ctx context.Context, user *types.User, id string) error {
if user == nil { if user == nil {
return ErrUnauthorized return ErrUnauthorized
} }
@@ -183,7 +184,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("account Delete", nil, err) err = db.TransformAndLogDbError("account Delete", nil, err)
if err != nil { if err != nil {
return err return err
@@ -193,7 +194,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
}() }()
transactionsCount := 0 transactionsCount := 0
err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid) err = tx.GetContext(ctx, &transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Delete", nil, err) err = db.TransformAndLogDbError("account Delete", nil, err)
if err != nil { if err != nil {
return err return err
@@ -202,7 +203,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("account has transactions, cannot delete: %w", ErrBadRequest) return fmt.Errorf("account has transactions, cannot delete: %w", ErrBadRequest)
} }
res, err := tx.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id) res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("account Delete", res, err) err = db.TransformAndLogDbError("account Delete", res, err)
if err != nil { if err != nil {
return err return err

View File

@@ -26,24 +26,24 @@ var (
) )
type Auth interface { type Auth interface {
SignUp(email string, password string) (*types.User, error) SignUp(ctx context.Context, email string, password string) (*types.User, error)
SendVerificationMail(userId uuid.UUID, email string) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string)
VerifyUserEmail(token string) error VerifyUserEmail(ctx context.Context, token string) error
SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error)
SignInSession(sessionId string) (*types.Session, *types.User, error) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error)
SignInAnonymous() (*types.Session, error) SignInAnonymous(ctx context.Context) (*types.Session, error)
SignOut(sessionId string) error SignOut(ctx context.Context, sessionId string) error
DeleteAccount(user *types.User, currPass string) error DeleteAccount(ctx context.Context, user *types.User, currPass string) error
ChangePassword(user *types.User, sessionId string, currPass, newPass string) error ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error
SendForgotPasswordMail(email string) error SendForgotPasswordMail(ctx context.Context, email string) error
ForgotPassword(token string, newPass string) error ForgotPassword(ctx context.Context, token string, newPass string) error
IsCsrfTokenValid(tokenStr string, sessionId string) bool IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
GetCsrfToken(session *types.Session) (string, error) GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
} }
type AuthImpl struct { type AuthImpl struct {
@@ -64,8 +64,8 @@ func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *
} }
} }
func (service AuthImpl) SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) { func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) {
user, err := service.db.GetUserByEmail(email) user, err := service.db.GetUserByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
return nil, nil, ErrInvalidCredentials return nil, nil, ErrInvalidCredentials
@@ -80,12 +80,12 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
return nil, nil, ErrInvalidCredentials return nil, nil, ErrInvalidCredentials
} }
err = service.cleanUpSessionWithTokens(session) err = service.cleanUpSessionWithTokens(ctx, session)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, types.ErrInternal
} }
session, err = service.createSession(user.Id) session, err = service.createSession(ctx, user.Id)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, types.ErrInternal
} }
@@ -93,17 +93,17 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
return session, user, nil return session, user, nil
} }
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) { func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
if sessionId == "" { if sessionId == "" {
return nil, nil, ErrSessionIdInvalid return nil, nil, ErrSessionIdInvalid
} }
session, err := service.db.GetSession(sessionId) session, err := service.db.GetSession(ctx, sessionId)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, types.ErrInternal
} }
if session.ExpiresAt.Before(service.clock.Now()) { if session.ExpiresAt.Before(service.clock.Now()) {
_ = service.db.DeleteSession(sessionId) _ = service.db.DeleteSession(ctx, sessionId)
return nil, nil, nil return nil, nil, nil
} }
@@ -111,7 +111,7 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
return session, nil, nil return session, nil, nil
} }
user, err := service.db.GetUser(session.UserId) user, err := service.db.GetUser(ctx, session.UserId)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, types.ErrInternal
} }
@@ -119,8 +119,8 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
return session, user, nil return session, user, nil
} }
func (service AuthImpl) SignInAnonymous() (*types.Session, error) { func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, error) {
session, err := service.createSession(uuid.Nil) session, err := service.createSession(ctx, uuid.Nil)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, types.ErrInternal
} }
@@ -130,7 +130,7 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
return session, nil return session, nil
} }
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) { func (service AuthImpl) SignUp(ctx context.Context, email string, password string) (*types.User, error) {
_, err := mail.ParseAddress(email) _, err := mail.ParseAddress(email)
if err != nil { if err != nil {
return nil, ErrInvalidEmail return nil, ErrInvalidEmail
@@ -154,7 +154,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
err = service.db.InsertUser(user) err = service.db.InsertUser(ctx, user)
if err != nil { if err != nil {
if errors.Is(err, db.ErrAlreadyExists) { if errors.Is(err, db.ErrAlreadyExists) {
return nil, ErrAccountExists return nil, ErrAccountExists
@@ -166,8 +166,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
return user, nil return user, nil
} }
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify) tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, db.ErrNotFound) {
return return
} }
@@ -192,7 +192,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(24*time.Hour)) service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token) err = service.db.InsertToken(ctx, token)
if err != nil { if err != nil {
return return
} }
@@ -208,17 +208,17 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
service.mail.SendMail(email, "Welcome to spend-sparrow", w.String()) service.mail.SendMail(email, "Welcome to spend-sparrow", w.String())
} }
func (service AuthImpl) VerifyUserEmail(tokenStr string) error { func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
if tokenStr == "" { if tokenStr == "" {
return types.ErrInternal return types.ErrInternal
} }
token, err := service.db.GetToken(tokenStr) token, err := service.db.GetToken(ctx, tokenStr)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
user, err := service.db.GetUser(token.UserId) user, err := service.db.GetUser(ctx, token.UserId)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
@@ -236,21 +236,21 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
user.EmailVerified = true user.EmailVerified = true
user.EmailVerifiedAt = &now user.EmailVerifiedAt = &now
err = service.db.UpdateUser(user) err = service.db.UpdateUser(ctx, user)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
_ = service.db.DeleteToken(token.Token) _ = service.db.DeleteToken(ctx, token.Token)
return nil return nil
} }
func (service AuthImpl) SignOut(sessionId string) error { func (service AuthImpl) SignOut(ctx context.Context, sessionId string) error {
return service.db.DeleteSession(sessionId) return service.db.DeleteSession(ctx, sessionId)
} }
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error { func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, currPass string) error {
userDb, err := service.db.GetUser(user.Id) userDb, err := service.db.GetUser(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
@@ -260,7 +260,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
return ErrInvalidCredentials return ErrInvalidCredentials
} }
err = service.db.DeleteUser(user.Id) err = service.db.DeleteUser(ctx, user.Id)
if err != nil { if err != nil {
return err return err
} }
@@ -270,7 +270,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
return nil return nil
} }
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error { func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
} }
@@ -288,18 +288,18 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP
newHash := GetHashPassword(newPass, user.Salt) newHash := GetHashPassword(newPass, user.Salt)
user.Password = newHash user.Password = newHash
err := service.db.UpdateUser(user) err := service.db.UpdateUser(ctx, user)
if err != nil { if err != nil {
return err return err
} }
sessions, err := service.db.GetSessions(user.Id) sessions, err := service.db.GetSessions(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
for _, s := range sessions { for _, s := range sessions {
if s.Id != sessionId { if s.Id != sessionId {
err = service.db.DeleteSession(s.Id) err = service.db.DeleteSession(ctx, s.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
@@ -309,13 +309,13 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP
return nil return nil
} }
func (service AuthImpl) SendForgotPasswordMail(email string) error { func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
tokenStr, err := service.random.String(32) tokenStr, err := service.random.String(32)
if err != nil { if err != nil {
return err return err
} }
user, err := service.db.GetUserByEmail(email) user, err := service.db.GetUserByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
return nil return nil
@@ -332,7 +332,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(15*time.Minute)) service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(token) err = service.db.InsertToken(ctx, token)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
@@ -348,17 +348,17 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
return nil return nil
} }
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
} }
token, err := service.db.GetToken(tokenStr) token, err := service.db.GetToken(ctx, tokenStr)
if err != nil { if err != nil {
return ErrTokenInvalid return ErrTokenInvalid
} }
err = service.db.DeleteToken(tokenStr) err = service.db.DeleteToken(ctx, tokenStr)
if err != nil { if err != nil {
return err return err
} }
@@ -368,7 +368,7 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
return ErrTokenInvalid return ErrTokenInvalid
} }
user, err := service.db.GetUser(token.UserId) user, err := service.db.GetUser(ctx, token.UserId)
if err != nil { if err != nil {
slog.Error("Could not get user from token", "err", err) slog.Error("Could not get user from token", "err", err)
return types.ErrInternal return types.ErrInternal
@@ -377,18 +377,18 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
passHash := GetHashPassword(newPass, user.Salt) passHash := GetHashPassword(newPass, user.Salt)
user.Password = passHash user.Password = passHash
err = service.db.UpdateUser(user) err = service.db.UpdateUser(ctx, user)
if err != nil { if err != nil {
return err return err
} }
sessions, err := service.db.GetSessions(user.Id) sessions, err := service.db.GetSessions(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
for _, session := range sessions { for _, session := range sessions {
err = service.db.DeleteSession(session.Id) err = service.db.DeleteSession(ctx, session.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
@@ -397,8 +397,8 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
return nil return nil
} }
func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool { func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
token, err := service.db.GetToken(tokenStr) token, err := service.db.GetToken(ctx, tokenStr)
if err != nil { if err != nil {
return false return false
} }
@@ -412,12 +412,12 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
return true return true
} }
func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) { func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session) (string, error) {
if session == nil { if session == nil {
return "", types.ErrInternal return "", types.ErrInternal
} }
tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf) tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
if len(tokens) > 0 { if len(tokens) > 0 {
return tokens[0].Token, nil return tokens[0].Token, nil
@@ -435,7 +435,7 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
types.TokenTypeCsrf, types.TokenTypeCsrf,
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(8*time.Hour)) service.clock.Now().Add(8*time.Hour))
err = service.db.InsertToken(token) err = service.db.InsertToken(ctx, token)
if err != nil { if err != nil {
return "", types.ErrInternal return "", types.ErrInternal
} }
@@ -445,22 +445,22 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
return tokenStr, nil return tokenStr, nil
} }
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error { func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error {
if session == nil { if session == nil {
return nil return nil
} }
err := service.db.DeleteSession(session.Id) err := service.db.DeleteSession(ctx, session.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf) tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
for _, token := range tokens { for _, token := range tokens {
err = service.db.DeleteToken(token.Token) err = service.db.DeleteToken(ctx, token.Token)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
@@ -469,13 +469,13 @@ func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
return nil return nil
} }
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) { func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*types.Session, error) {
sessionId, err := service.random.String(32) sessionId, err := service.random.String(32)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, types.ErrInternal
} }
err = service.db.DeleteOldSessions(userId) err = service.db.DeleteOldSessions(ctx, userId)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, types.ErrInternal
} }
@@ -485,7 +485,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error)
session := types.NewSession(sessionId, userId, createAt, expiresAt) session := types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session) err = service.db.InsertSession(ctx, session)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, types.ErrInternal
} }

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
@@ -13,13 +14,13 @@ import (
) )
type Transaction interface { type Transaction interface {
Add(tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
Update(user *types.User, transaction types.Transaction) (*types.Transaction, error) Update(ctx context.Context, user *types.User, transaction types.Transaction) (*types.Transaction, error)
Get(user *types.User, id string) (*types.Transaction, error) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error)
GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
Delete(user *types.User, id string) error Delete(ctx context.Context, user *types.User, id string) error
RecalculateBalances(user *types.User) error RecalculateBalances(ctx context.Context, user *types.User) error
} }
type TransactionImpl struct { type TransactionImpl struct {
@@ -36,7 +37,7 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
} }
} }
func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -45,7 +46,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
ownsTransaction := false ownsTransaction := false
if tx == nil { if tx == nil {
ownsTransaction = true ownsTransaction = true
tx, err = s.db.Beginx() tx, err = s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transaction Add", nil, err) err = db.TransformAndLogDbError("transaction Add", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -55,12 +56,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
}() }()
} }
transaction, err := s.validateAndEnrichTransaction(tx, nil, user.Id, transactionInput) transaction, err := s.validateAndEnrichTransaction(ctx, tx, nil, user.Id, transactionInput)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r, err := tx.NamedExec(` r, err := tx.NamedExecContext(ctx, `
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp, INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
party, description, error, created_at, created_by) party, description, error, created_at, created_by)
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp, VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
@@ -71,8 +72,8 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
} }
if transaction.Error == nil && transaction.AccountId != nil { if transaction.Error == nil && transaction.AccountId != nil {
r, err = tx.Exec(` r, err = tx.ExecContext(ctx, `
UPDATE account UPDATE actx context.Context,ccount
SET current_balance = current_balance + ? SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction Add", r, err) err = db.TransformAndLogDbError("transaction Add", r, err)
@@ -82,7 +83,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
} }
if transaction.Error == nil && transaction.TreasureChestId != nil { if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err = tx.Exec(` r, err = tx.ExecContext(ctx, `
UPDATE treasure_chest UPDATE treasure_chest
SET current_balance = current_balance + ? SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
@@ -103,12 +104,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
return transaction, nil return transaction, nil
} }
func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*types.Transaction, error) { func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transaction Update", nil, err) err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -118,7 +119,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
}() }()
transaction := &types.Transaction{} transaction := &types.Transaction{}
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id) err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
err = db.TransformAndLogDbError("transaction Update", nil, err) err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
@@ -128,7 +129,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
} }
if transaction.Error == nil && transaction.AccountId != nil { if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(` r, err := tx.ExecContext(ctx, `
UPDATE account UPDATE account
SET current_balance = current_balance - ? SET current_balance = current_balance - ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
@@ -138,7 +139,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
} }
} }
if transaction.Error == nil && transaction.TreasureChestId != nil { if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(` r, err := tx.ExecContext(ctx, `
UPDATE treasure_chest UPDATE treasure_chest
SET current_balance = current_balance - ? SET current_balance = current_balance - ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
@@ -148,13 +149,13 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
} }
} }
transaction, err = s.validateAndEnrichTransaction(tx, transaction, user.Id, input) transaction, err = s.validateAndEnrichTransaction(ctx, tx, transaction, user.Id, input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if transaction.Error == nil && transaction.AccountId != nil { if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(` r, err := tx.ExecContext(ctx, `
UPDATE account UPDATE account
SET current_balance = current_balance + ? SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
@@ -164,7 +165,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
} }
} }
if transaction.Error == nil && transaction.TreasureChestId != nil { if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(` r, err := tx.ExecContext(ctx, `
UPDATE treasure_chest UPDATE treasure_chest
SET current_balance = current_balance + ? SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
@@ -174,7 +175,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
} }
} }
r, err := tx.NamedExec(` r, err := tx.NamedExecContext(ctx, `
UPDATE "transaction" UPDATE "transaction"
SET SET
account_id = :account_id, account_id = :account_id,
@@ -202,7 +203,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
return transaction, nil return transaction, nil
} }
func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, error) { func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -213,7 +214,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
} }
var transaction types.Transaction var transaction types.Transaction
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Get", nil, err) err = db.TransformAndLogDbError("transaction Get", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
@@ -225,13 +226,13 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
return &transaction, nil return &transaction, nil
} }
func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) { func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
transactions := make([]*types.Transaction, 0) transactions := make([]*types.Transaction, 0)
err := s.db.Select(&transactions, ` err := s.db.SelectContext(ctx, &transactions, `
SELECT * SELECT *
FROM "transaction" FROM "transaction"
WHERE user_id = ? WHERE user_id = ?
@@ -254,7 +255,7 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF
return transactions, nil return transactions, nil
} }
func (s TransactionImpl) Delete(user *types.User, id string) error { func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error {
if user == nil { if user == nil {
return ErrUnauthorized return ErrUnauthorized
} }
@@ -264,7 +265,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transaction Delete", nil, err) err = db.TransformAndLogDbError("transaction Delete", nil, err)
if err != nil { if err != nil {
return nil return nil
@@ -274,14 +275,14 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
}() }()
var transaction types.Transaction var transaction types.Transaction
err = tx.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Delete", nil, err) err = db.TransformAndLogDbError("transaction Delete", nil, err)
if err != nil { if err != nil {
return err return err
} }
if transaction.Error == nil && transaction.AccountId != nil { if transaction.Error == nil && transaction.AccountId != nil {
r, err := tx.Exec(` r, err := tx.ExecContext(ctx, `
UPDATE account UPDATE account
SET current_balance = current_balance - ? SET current_balance = current_balance - ?
WHERE id = ? WHERE id = ?
@@ -293,7 +294,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
} }
if transaction.Error == nil && transaction.TreasureChestId != nil { if transaction.Error == nil && transaction.TreasureChestId != nil {
r, err := tx.Exec(` r, err := tx.ExecContext(ctx, `
UPDATE treasure_chest UPDATE treasure_chest
SET current_balance = current_balance - ? SET current_balance = current_balance - ?
WHERE id = ? WHERE id = ?
@@ -304,7 +305,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
} }
} }
r, err := tx.Exec("DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id) r, err := tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err) err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil { if err != nil {
return err return err
@@ -319,12 +320,12 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
return nil return nil
} }
func (s TransactionImpl) RecalculateBalances(user *types.User) error { func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error {
if user == nil { if user == nil {
return ErrUnauthorized return ErrUnauthorized
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err) err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil { if err != nil {
return err return err
@@ -333,7 +334,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
_ = tx.Rollback() _ = tx.Rollback()
}() }()
r, err := tx.Exec(` r, err := tx.ExecContext(ctx, `
UPDATE account UPDATE account
SET current_balance = 0 SET current_balance = 0
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
@@ -342,7 +343,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
return err return err
} }
r, err = tx.Exec(` r, err = tx.ExecContext(ctx, `
UPDATE treasure_chest UPDATE treasure_chest
SET current_balance = 0 SET current_balance = 0
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
@@ -351,7 +352,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
return err return err
} }
rows, err := tx.Queryx(` rows, err := tx.QueryxContext(ctx, `
SELECT * SELECT *
FROM "transaction" FROM "transaction"
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
@@ -375,7 +376,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
} }
s.updateErrors(&transaction) s.updateErrors(&transaction)
r, err = tx.Exec(` r, err = tx.ExecContext(ctx, `
UPDATE "transaction" UPDATE "transaction"
SET error = ? SET error = ?
WHERE user_id = ? WHERE user_id = ?
@@ -390,7 +391,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
} }
if transaction.AccountId != nil { if transaction.AccountId != nil {
r, err = tx.Exec(` r, err = tx.ExecContext(ctx, `
UPDATE account UPDATE account
SET current_balance = current_balance + ? SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
@@ -400,7 +401,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
} }
} }
if transaction.TreasureChestId != nil { if transaction.TreasureChestId != nil {
r, err = tx.Exec(` r, err = tx.ExecContext(ctx, `
UPDATE treasure_chest UPDATE treasure_chest
SET current_balance = current_balance + ? SET current_balance = current_balance + ?
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
@@ -420,7 +421,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
return nil return nil
} }
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) { func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) {
var ( var (
id uuid.UUID id uuid.UUID
createdAt time.Time createdAt time.Time
@@ -449,7 +450,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
} }
if input.AccountId != nil { if input.AccountId != nil {
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId)
err = db.TransformAndLogDbError("transaction validate", nil, err) err = db.TransformAndLogDbError("transaction validate", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -462,7 +463,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
if input.TreasureChestId != nil { if input.TreasureChestId != nil {
var treasureChest types.TreasureChest var treasureChest types.TreasureChest
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId) err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
err = db.TransformAndLogDbError("transaction validate", nil, err) err = db.TransformAndLogDbError("transaction validate", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
@@ -15,14 +16,14 @@ import (
) )
type TransactionRecurring interface { type TransactionRecurring interface {
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
GetAll(user *types.User) ([]*types.TransactionRecurring, error) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error)
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error)
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(user *types.User, id string) error Delete(ctx context.Context, user *types.User, id string) error
GenerateTransactions(user *types.User) error GenerateTransactions(ctx context.Context, user *types.User) error
} }
type TransactionRecurringImpl struct { type TransactionRecurringImpl struct {
@@ -41,7 +42,7 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio
} }
} }
func (s TransactionRecurringImpl) Add( func (s TransactionRecurringImpl) Add(ctx context.Context,
user *types.User, user *types.User,
transactionRecurringInput types.TransactionRecurringInput, transactionRecurringInput types.TransactionRecurringInput,
) (*types.TransactionRecurring, error) { ) (*types.TransactionRecurring, error) {
@@ -49,7 +50,7 @@ func (s TransactionRecurringImpl) Add(
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring Add", nil, err) err = db.TransformAndLogDbError("transactionRecurring Add", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -58,12 +59,12 @@ func (s TransactionRecurringImpl) Add(
_ = tx.Rollback() _ = tx.Rollback()
}() }()
transactionRecurring, err := s.validateAndEnrichTransactionRecurring(tx, nil, user.Id, transactionRecurringInput) transactionRecurring, err := s.validateAndEnrichTransactionRecurring(ctx, tx, nil, user.Id, transactionRecurringInput)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r, err := tx.NamedExec(` r, err := tx.NamedExecContext(ctx, `
INSERT INTO "transaction_recurring" (id, user_id, interval_months, INSERT INTO "transaction_recurring" (id, user_id, interval_months,
next_execution, party, description, account_id, treasure_chest_id, value, created_at, created_by) next_execution, party, description, account_id, treasure_chest_id, value, created_at, created_by)
VALUES (:id, :user_id, :interval_months, VALUES (:id, :user_id, :interval_months,
@@ -83,7 +84,7 @@ func (s TransactionRecurringImpl) Add(
return transactionRecurring, nil return transactionRecurring, nil
} }
func (s TransactionRecurringImpl) Update( func (s TransactionRecurringImpl) Update(ctx context.Context,
user *types.User, user *types.User,
input types.TransactionRecurringInput, input types.TransactionRecurringInput,
) (*types.TransactionRecurring, error) { ) (*types.TransactionRecurring, error) {
@@ -96,7 +97,7 @@ func (s TransactionRecurringImpl) Update(
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err) err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -106,7 +107,7 @@ func (s TransactionRecurringImpl) Update(
}() }()
transactionRecurring := &types.TransactionRecurring{} transactionRecurring := &types.TransactionRecurring{}
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err) err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
@@ -115,12 +116,12 @@ func (s TransactionRecurringImpl) Update(
return nil, types.ErrInternal return nil, types.ErrInternal
} }
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(tx, transactionRecurring, user.Id, input) transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r, err := tx.NamedExec(` r, err := tx.NamedExecContext(ctx, `
UPDATE transaction_recurring UPDATE transaction_recurring
SET SET
interval_months = :interval_months, interval_months = :interval_months,
@@ -148,13 +149,13 @@ func (s TransactionRecurringImpl) Update(
return transactionRecurring, nil return transactionRecurring, nil
} }
func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
err := s.db.Select(&transactionRecurrings, ` err := s.db.SelectContext(ctx, &transactionRecurrings, `
SELECT * SELECT *
FROM transaction_recurring FROM transaction_recurring
WHERE user_id = ? WHERE user_id = ?
@@ -168,7 +169,7 @@ func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.Transaction
return transactionRecurrings, nil return transactionRecurrings, nil
} }
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -179,7 +180,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err) err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -189,7 +190,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
}() }()
var rowCount int var rowCount int
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err) err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
@@ -199,7 +200,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
err = tx.Select(&transactionRecurrings, ` err = tx.SelectContext(ctx, &transactionRecurrings, `
SELECT * SELECT *
FROM transaction_recurring FROM transaction_recurring
WHERE user_id = ? WHERE user_id = ?
@@ -220,7 +221,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
return transactionRecurrings, nil return transactionRecurrings, nil
} }
func (s TransactionRecurringImpl) GetAllByTreasureChest( func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
user *types.User, user *types.User,
treasureChestId string, treasureChestId string,
) ([]*types.TransactionRecurring, error) { ) ([]*types.TransactionRecurring, error) {
@@ -234,7 +235,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err) err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -244,7 +245,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
}() }()
var rowCount int var rowCount int
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err) err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
@@ -254,7 +255,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
err = tx.Select(&transactionRecurrings, ` err = tx.SelectContext(ctx, &transactionRecurrings, `
SELECT * SELECT *
FROM transaction_recurring FROM transaction_recurring
WHERE user_id = ? WHERE user_id = ?
@@ -275,7 +276,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
return transactionRecurrings, nil return transactionRecurrings, nil
} }
func (s TransactionRecurringImpl) Delete(user *types.User, id string) error { func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error {
if user == nil { if user == nil {
return ErrUnauthorized return ErrUnauthorized
} }
@@ -285,7 +286,7 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err) err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
if err != nil { if err != nil {
return nil return nil
@@ -295,13 +296,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
}() }()
var transactionRecurring types.TransactionRecurring var transactionRecurring types.TransactionRecurring
err = tx.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.GetContext(ctx, &transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err) err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
if err != nil { if err != nil {
return err return err
} }
r, err := tx.Exec("DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id) r, err := tx.ExecContext(ctx, "DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id)
err = db.TransformAndLogDbError("transactionRecurring Delete", r, err) err = db.TransformAndLogDbError("transactionRecurring Delete", r, err)
if err != nil { if err != nil {
return err return err
@@ -316,13 +317,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
return nil return nil
} }
func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error { func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error {
if user == nil { if user == nil {
return ErrUnauthorized return ErrUnauthorized
} }
now := s.clock.Now() now := s.clock.Now()
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err) err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
if err != nil { if err != nil {
return err return err
@@ -332,7 +333,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
}() }()
recurringTransactions := make([]*types.TransactionRecurring, 0) recurringTransactions := make([]*types.TransactionRecurring, 0)
err = tx.Select(&recurringTransactions, ` err = tx.SelectContext(ctx, &recurringTransactions, `
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`, SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
user.Id, now) user.Id, now)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err) err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
@@ -350,13 +351,13 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
Value: transactionRecurring.Value, Value: transactionRecurring.Value,
} }
_, err = s.transaction.Add(tx, user, transaction) _, err = s.transaction.Add(ctx, tx, user, transaction)
if err != nil { if err != nil {
return err return err
} }
nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0) nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0)
r, err := tx.Exec(`UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`, r, err := tx.ExecContext(ctx, `UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`,
nextExecution, transactionRecurring.Id, user.Id) nextExecution, transactionRecurring.Id, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err) err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err)
if err != nil { if err != nil {
@@ -373,6 +374,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
} }
func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
ctx context.Context,
tx *sqlx.Tx, tx *sqlx.Tx,
oldTransactionRecurring *types.TransactionRecurring, oldTransactionRecurring *types.TransactionRecurring,
userId uuid.UUID, userId uuid.UUID,
@@ -417,7 +419,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
} }
accountUuid = &temp accountUuid = &temp
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err) err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -438,7 +440,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
} }
treasureChestUuid = &temp treasureChestUuid = &temp
var treasureChest types.TreasureChest var treasureChest types.TreasureChest
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err) err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
@@ -13,11 +14,11 @@ import (
) )
type TreasureChest interface { type TreasureChest interface {
Add(user *types.User, parentId, name string) (*types.TreasureChest, error) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error)
Update(user *types.User, id, parentId, name string) (*types.TreasureChest, error) Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error)
Get(user *types.User, id string) (*types.TreasureChest, error) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error)
GetAll(user *types.User) ([]*types.TreasureChest, error) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error)
Delete(user *types.User, id string) error Delete(ctx context.Context, user *types.User, id string) error
} }
type TreasureChestImpl struct { type TreasureChestImpl struct {
@@ -34,7 +35,7 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
} }
} }
func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -51,7 +52,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
var parentUuid *uuid.UUID var parentUuid *uuid.UUID
if parentId != "" { if parentId != "" {
parent, err := s.Get(user, parentId) parent, err := s.Get(ctx, user, parentId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -76,7 +77,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
UpdatedBy: nil, UpdatedBy: nil,
} }
r, err := s.db.NamedExec(` r, err := s.db.NamedExecContext(ctx, `
INSERT INTO treasure_chest (id, parent_id, user_id, name, current_balance, created_at, created_by) INSERT INTO treasure_chest (id, parent_id, user_id, name, current_balance, created_at, created_by)
VALUES (:id, :parent_id, :user_id, :name, :current_balance, :created_at, :created_by)`, treasureChest) VALUES (:id, :parent_id, :user_id, :name, :current_balance, :created_at, :created_by)`, treasureChest)
err = db.TransformAndLogDbError("treasureChest Insert", r, err) err = db.TransformAndLogDbError("treasureChest Insert", r, err)
@@ -87,7 +88,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
return treasureChest, nil return treasureChest, nil
} }
func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -101,7 +102,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("treasureChest Update", nil, err) err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -111,7 +112,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
}() }()
treasureChest := &types.TreasureChest{} treasureChest := &types.TreasureChest{}
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id) err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Update", nil, err) err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
@@ -122,12 +123,12 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
var parentUuid *uuid.UUID var parentUuid *uuid.UUID
if parentId != "" { if parentId != "" {
parent, err := s.Get(user, parentId) parent, err := s.Get(ctx, user, parentId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var childCount int var childCount int
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id) err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Update", nil, err) err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -145,7 +146,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
treasureChest.UpdatedAt = &timestamp treasureChest.UpdatedAt = &timestamp
treasureChest.UpdatedBy = &user.Id treasureChest.UpdatedBy = &user.Id
r, err := tx.NamedExec(` r, err := tx.NamedExecContext(ctx, `
UPDATE treasure_chest UPDATE treasure_chest
SET SET
parent_id = :parent_id, parent_id = :parent_id,
@@ -169,7 +170,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
return treasureChest, nil return treasureChest, nil
} }
func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
@@ -180,7 +181,7 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
} }
var treasureChest types.TreasureChest var treasureChest types.TreasureChest
err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("treasureChest Get", nil, err) err = db.TransformAndLogDbError("treasureChest Get", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
@@ -192,13 +193,13 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
return &treasureChest, nil return &treasureChest, nil
} }
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) { func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
} }
treasureChests := make([]*types.TreasureChest, 0) treasureChests := make([]*types.TreasureChest, 0)
err := s.db.Select(&treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id) err := s.db.SelectContext(ctx, &treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("treasureChest GetAll", nil, err) err = db.TransformAndLogDbError("treasureChest GetAll", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -207,7 +208,7 @@ func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, err
return sortTree(treasureChests), nil return sortTree(treasureChests), nil
} }
func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error {
if user == nil { if user == nil {
return ErrUnauthorized return ErrUnauthorized
} }
@@ -217,7 +218,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
tx, err := s.db.Beginx() tx, err := s.db.BeginTxx(ctx, nil)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err) err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil { if err != nil {
return nil return nil
@@ -227,7 +228,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
}() }()
childCount := 0 childCount := 0
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id) err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err) err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil { if err != nil {
return err return err
@@ -238,7 +239,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
} }
transactionsCount := 0 transactionsCount := 0
err = tx.Get(&transactionsCount, err = tx.GetContext(ctx, &transactionsCount,
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
user.Id, id) user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err) err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
@@ -250,7 +251,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
} }
recurringCount := 0 recurringCount := 0
err = tx.Get(&recurringCount, ` err = tx.GetContext(ctx, &recurringCount, `
SELECT COUNT(*) FROM transaction_recurring WHERE user_id = ? AND treasure_chest_id = ?`, SELECT COUNT(*) FROM transaction_recurring WHERE user_id = ? AND treasure_chest_id = ?`,
user.Id, id) user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err) err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
@@ -261,7 +262,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest) return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest)
} }
r, err := tx.Exec(`DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id) r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
err = db.TransformAndLogDbError("treasureChest Delete", r, err) err = db.TransformAndLogDbError("treasureChest Delete", r, err)
if err != nil { if err != nil {
return err return err

View File

@@ -29,7 +29,7 @@ func DoRedirect(w http.ResponseWriter, r *http.Request, url string) {
} }
} }
func WaitMinimumTime[T interface{}](waitTime time.Duration, f func() (T, error)) (T, error) { func WaitMinimumTime[T any](waitTime time.Duration, f func() (T, error)) (T, error) {
start := time.Now() start := time.Now()
result, err := f() result, err := f()
time.Sleep(waitTime - time.Since(start)) time.Sleep(waitTime - time.Since(start))

View File

@@ -6,9 +6,11 @@ import (
"os" "os"
"spend-sparrow/internal" "spend-sparrow/internal"
"github.com/jmoiron/sqlx"
"github.com/joho/godotenv" "github.com/joho/godotenv"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/uptrace/opentelemetry-go-extra/otelsql"
"github.com/uptrace/opentelemetry-go-extra/otelsqlx"
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
) )
func main() { func main() {
@@ -18,7 +20,8 @@ func main() {
return return
} }
db, err := sqlx.Open("sqlite3", "./data/spend-sparrow.db") db, err := otelsqlx.Open("sqlite3", "./data/spend-sparrow.db",
otelsql.WithAttributes(semconv.DBSystemSqlite))
if err != nil { if err != nil {
slog.Error("Could not open Database data.db", "err", err) slog.Error("Could not open Database data.db", "err", err)
return return

View File

@@ -1,6 +1,7 @@
package test_test package test_test
import ( import (
"context"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"testing" "testing"
@@ -26,7 +27,7 @@ func setupDb(t *testing.T) *sqlx.DB {
} }
}) })
err = db.RunMigrations(d, "../") err = db.RunMigrations(context.Background(), d, "../")
if err != nil { if err != nil {
t.Fatalf("Error running migrations: %v", err) t.Fatalf("Error running migrations: %v", err)
} }
@@ -47,14 +48,14 @@ func TestUser(t *testing.T) {
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(expected) err := underTest.InsertUser(context.Background(), expected)
require.NoError(t, err) require.NoError(t, err)
actual, err := underTest.GetUser(expected.Id) actual, err := underTest.GetUser(context.Background(), expected.Id)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
actual, err = underTest.GetUserByEmail(expected.Email) actual, err = underTest.GetUserByEmail(context.Background(), expected.Email)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
}) })
@@ -64,7 +65,7 @@ func TestUser(t *testing.T) {
underTest := db.NewAuthSqlite(d) underTest := db.NewAuthSqlite(d)
_, err := underTest.GetUserByEmail("nonExistentEmail") _, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail")
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, db.ErrNotFound, err)
}) })
t.Run("should return ErrUserExist", func(t *testing.T) { t.Run("should return ErrUserExist", func(t *testing.T) {
@@ -77,10 +78,10 @@ func TestUser(t *testing.T) {
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(context.Background(), user)
require.NoError(t, err) require.NoError(t, err)
err = underTest.InsertUser(user) err = underTest.InsertUser(context.Background(), user)
assert.Equal(t, db.ErrAlreadyExists, err) assert.Equal(t, db.ErrAlreadyExists, err)
}) })
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
@@ -92,7 +93,7 @@ func TestUser(t *testing.T) {
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(context.Background(), user)
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, types.ErrInternal, err)
}) })
} }
@@ -110,21 +111,21 @@ func TestToken(t *testing.T) {
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt) expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(expected) err := underTest.InsertToken(context.Background(), expected)
require.NoError(t, err) require.NoError(t, err)
actual, err := underTest.GetToken(expected.Token) actual, err := underTest.GetToken(context.Background(), expected.Token)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
expected.SessionId = "" expected.SessionId = ""
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type) actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals) assert.Equal(t, []*types.Token{expected}, actuals)
expected.SessionId = "sessionId" expected.SessionId = "sessionId"
expected.UserId = uuid.Nil expected.UserId = uuid.Nil
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type) actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals) assert.Equal(t, []*types.Token{expected}, actuals)
}) })
@@ -140,14 +141,14 @@ func TestToken(t *testing.T) {
expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt) expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt)
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt) expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(expected1) err := underTest.InsertToken(context.Background(), expected1)
require.NoError(t, err) require.NoError(t, err)
err = underTest.InsertToken(expected2) err = underTest.InsertToken(context.Background(), expected2)
require.NoError(t, err) require.NoError(t, err)
expected1.UserId = uuid.Nil expected1.UserId = uuid.Nil
expected2.UserId = uuid.Nil expected2.UserId = uuid.Nil
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type) actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals) assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
@@ -155,7 +156,7 @@ func TestToken(t *testing.T) {
expected2.SessionId = "" expected2.SessionId = ""
expected1.UserId = userId expected1.UserId = userId
expected2.UserId = userId expected2.UserId = userId
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type) actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals) assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
}) })
@@ -165,13 +166,13 @@ func TestToken(t *testing.T) {
underTest := db.NewAuthSqlite(d) underTest := db.NewAuthSqlite(d)
_, err := underTest.GetToken("nonExistent") _, err := underTest.GetToken(context.Background(), "nonExistent")
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, db.ErrNotFound, err)
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify) _, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), types.TokenTypeEmailVerify)
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, db.ErrNotFound, err)
_, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify) _, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", types.TokenTypeEmailVerify)
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, db.ErrNotFound, err)
}) })
t.Run("should return ErrAlreadyExists", func(t *testing.T) { t.Run("should return ErrAlreadyExists", func(t *testing.T) {
@@ -184,10 +185,10 @@ func TestToken(t *testing.T) {
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(context.Background(), user)
require.NoError(t, err) require.NoError(t, err)
err = underTest.InsertUser(user) err = underTest.InsertUser(context.Background(), user)
assert.Equal(t, db.ErrAlreadyExists, err) assert.Equal(t, db.ErrAlreadyExists, err)
}) })
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
@@ -199,7 +200,7 @@ func TestToken(t *testing.T) {
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(context.Background(), user)
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, types.ErrInternal, err)
}) })
} }

View File

@@ -1,6 +1,7 @@
package test_test package test_test
import ( import (
"context"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/service" "spend-sparrow/internal/service"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
@@ -36,7 +37,7 @@ func TestSignUp(t *testing.T) {
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!") _, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!")
assert.Equal(t, service.ErrInvalidEmail, err) assert.Equal(t, service.ErrInvalidEmail, err)
}) })
@@ -58,7 +59,7 @@ func TestSignUp(t *testing.T) {
} }
for _, password := range weakPasswords { for _, password := range weakPasswords {
_, err := underTest.SignUp("some@valid.email", password) _, err := underTest.SignUp(context.Background(), "some@valid.email", password)
assert.Equal(t, service.ErrInvalidPassword, err) assert.Equal(t, service.ErrInvalidPassword, err)
} }
}) })
@@ -81,10 +82,10 @@ func TestSignUp(t *testing.T) {
mockRandom.EXPECT().UUID().Return(userId, nil) mockRandom.EXPECT().UUID().Return(userId, nil)
mockRandom.EXPECT().Bytes(16).Return(salt, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil)
mockClock.EXPECT().Now().Return(createTime) mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(expected).Return(nil) mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
actual, err := underTest.SignUp(email, password) actual, err := underTest.SignUp(context.Background(), email, password)
require.NoError(t, err) require.NoError(t, err)
@@ -109,11 +110,11 @@ func TestSignUp(t *testing.T) {
mockRandom.EXPECT().Bytes(16).Return(salt, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil)
mockClock.EXPECT().Now().Return(createTime) mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists) mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(db.ErrAlreadyExists)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp(user.Email, password) _, err := underTest.SignUp(context.Background(), user.Email, password)
assert.Equal(t, service.ErrAccountExists, err) assert.Equal(t, service.ErrAccountExists, err)
}) })
} }
@@ -140,7 +141,7 @@ func TestSendVerificationMail(t *testing.T) {
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify).Return(tokens, nil) mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, types.TokenTypeEmailVerify).Return(tokens, nil)
mockMail.EXPECT().SendMail(email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool { mockMail.EXPECT().SendMail(email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
return strings.Contains(message, token.Token) return strings.Contains(message, token.Token)
@@ -148,6 +149,6 @@ func TestSendVerificationMail(t *testing.T) {
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
underTest.SendVerificationMail(userId, email) underTest.SendVerificationMail(context.Background(), userId, email)
}) })
} }

View File

@@ -182,16 +182,16 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s
csrfToken := "my-verifying-token" + add csrfToken := "my-verifying-token" + add
email := add + "mail@mail.de" email := add + "mail@mail.de"
_, err := db.Exec(` _, err := db.ExecContext(context.Background(), `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt")) VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(context.Background(), `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(context.Background(), `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf) VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf)
require.NoError(t, err) require.NoError(t, err)

View File

@@ -112,11 +112,11 @@ func TestIntegrationAuth(t *testing.T) {
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -138,7 +138,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -165,7 +165,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -208,7 +208,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -248,7 +248,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -296,7 +296,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -415,7 +415,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -451,10 +451,10 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
err = db.QueryRow("SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
}) })
@@ -469,11 +469,11 @@ func TestIntegrationAuth(t *testing.T) {
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -548,7 +548,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -627,11 +627,11 @@ func TestIntegrationAuth(t *testing.T) {
assert.Contains(t, resp.Header.Get("Hx-Trigger"), "An activation link has been send to your email") assert.Contains(t, resp.Header.Get("Hx-Trigger"), "An activation link has been send to your email")
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE email = ? AND email_verified = FALSE", "mail@mail.de").Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE email = ? AND email_verified = FALSE", "mail@mail.de").Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
var token string var token string
err = db.QueryRow("SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token) err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, token) assert.NotEmpty(t, token)
}) })
@@ -644,7 +644,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -658,7 +658,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -670,11 +670,11 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
token := "my-outdated-verifying-token" token := "my-outdated-verifying-token"
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO token (token, user_id, type, created_at, expires_at) INSERT INTO token (token, user_id, type, created_at, expires_at)
VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify) VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify)
require.NoError(t, err) require.NoError(t, err)
@@ -688,7 +688,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -700,11 +700,11 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
token := "my-verifying-token" token := "my-verifying-token"
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify) VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify)
require.NoError(t, err) require.NoError(t, err)
@@ -718,7 +718,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = TRUE", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = TRUE", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -747,11 +747,11 @@ func TestIntegrationAuth(t *testing.T) {
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -765,7 +765,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var csrfToken string var csrfToken string
err = db.QueryRow("SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken) err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken)
require.NoError(t, err) require.NoError(t, err)
req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil) req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil)
@@ -785,7 +785,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, -1, cookie.MaxAge) assert.Equal(t, -1, cookie.MaxAge)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
}) })
@@ -825,13 +825,13 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
sessionId := "session-id" sessionId := "session-id"
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -871,13 +871,13 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
sessionId := "session-id" sessionId := "session-id"
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -964,22 +964,22 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ?", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ?", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ?", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ?", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
err = db.QueryRow("SELECT COUNT(*) FROM account WHERE user_id = ?", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM account WHERE user_id = ?", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
err = db.QueryRow("SELECT COUNT(*) FROM treasure_chest WHERE user_id = ?", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM treasure_chest WHERE user_id = ?", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
err = db.QueryRow("SELECT COUNT(*) FROM \"transaction\" WHERE user_id = ?", userId).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM \"transaction\" WHERE user_id = ?", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
}) })
@@ -1040,13 +1040,13 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
sessionId := "session-id" sessionId := "session-id"
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -1069,7 +1069,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1080,13 +1080,13 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
sessionId := "session-id" sessionId := "session-id"
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -1119,7 +1119,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1130,13 +1130,13 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
sessionId := "session-id" sessionId := "session-id"
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -1169,7 +1169,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1181,21 +1181,21 @@ func TestIntegrationAuth(t *testing.T) {
userIdOther := uuid.New() userIdOther := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
sessionId := "session-id" sessionId := "session-id"
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES ("second", ?, datetime(), datetime("now", "+1 day"))`, userId) VALUES ("second", ?, datetime(), datetime("now", "+1 day"))`, userId)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(` _, err = db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES ("other", ?, datetime(), datetime("now", "+1 day"))`, userIdOther) VALUES ("other", ?, datetime(), datetime("now", "+1 day"))`, userIdOther)
require.NoError(t, err) require.NoError(t, err)
@@ -1232,12 +1232,12 @@ func TestIntegrationAuth(t *testing.T) {
pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt")) pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt"))
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
var sessionIds []string var sessionIds []string
sessions, err := db.Query(`SELECT session_id FROM session WHERE NOT user_id = ? ORDER BY session_id`, uuid.Nil) sessions, err := db.QueryContext(ctx, `SELECT session_id FROM session WHERE NOT user_id = ? ORDER BY session_id`, uuid.Nil)
require.NoError(t, err) require.NoError(t, err)
for sessions.Next() { for sessions.Next() {
var sessionId string var sessionId string
@@ -1260,13 +1260,13 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := d.Exec(` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
sessionId := "session-id" sessionId := "session-id"
_, err = d.Exec(` _, err = d.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId) VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -1288,7 +1288,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := d.Exec(` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -1317,7 +1317,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = d.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
}) })
@@ -1363,7 +1363,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := db.Exec(` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -1399,7 +1399,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg) assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg)
var rows int var rows int
err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1413,7 +1413,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := d.Exec(` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -1445,7 +1445,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1456,7 +1456,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := d.Exec(` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -1473,7 +1473,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.NotEmpty(t, anonymousCsrfToken) assert.NotEmpty(t, anonymousCsrfToken)
token := "password-reset-token" token := "password-reset-token"
_, err = d.Exec(` _, err = d.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset) VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset)
require.NoError(t, err) require.NoError(t, err)
@@ -1494,7 +1494,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1505,7 +1505,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := d.Exec(` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
@@ -1522,7 +1522,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.NotEmpty(t, anonymousCsrfToken) assert.NotEmpty(t, anonymousCsrfToken)
token := "password-reset-token" token := "password-reset-token"
_, err = d.Exec(` _, err = d.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset) VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset)
require.NoError(t, err) require.NoError(t, err)
@@ -1543,7 +1543,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1554,12 +1554,12 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := service.GetHashPassword("password", []byte("salt"))
_, err := d.Exec(` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = d.Exec(` _, err = d.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId) VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -1590,7 +1590,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var token string var token string
err = d.QueryRow("SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token) err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token)
require.NoError(t, err) require.NoError(t, err)
formData = url.Values{ formData = url.Values{
@@ -1608,7 +1608,7 @@ func TestIntegrationAuth(t *testing.T) {
_ = resp.Body.Close() _ = resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
sessions, err := d.Query("SELECT session_id FROM session WHERE user_id = ?", userId) sessions, err := d.QueryContext(ctx, "SELECT session_id FROM session WHERE user_id = ?", userId)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, sessions.Next()) assert.False(t, sessions.Next())
}) })
@@ -1623,11 +1623,11 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" sessionId := "session-id"
_, err := d.Exec(` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = d.Exec(` _, err = d.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId) VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -1643,7 +1643,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.NotEqual(t, sessionId, newSession.Value) assert.NotEqual(t, sessionId, newSession.Value)
var rows int var rows int
err = d.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
}) })
@@ -1670,11 +1670,11 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" sessionId := "session-id"
_, err := d.Exec(` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
_, err = d.Exec(` _, err = d.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId) VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
require.NoError(t, err) require.NoError(t, err)
@@ -1769,7 +1769,7 @@ func TestIntegrationAccount(t *testing.T) {
_ = resp.Body.Close() _ = resp.Body.Close()
var id uuid.UUID var id uuid.UUID
err = db.Get(&id, "SELECT id FROM account") err = db.GetContext(ctx, &id, "SELECT id FROM account")
require.NoError(t, err) require.NoError(t, err)
// Update // Update

View File

@@ -22,7 +22,7 @@ func TestTreasureChestShouldNotDeleteIfTransactionRecurringExists(t *testing.T)
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var parentId string var parentId string
err := db.Get(&parentId, "SELECT id FROM treasure_chest") err := db.GetContext(ctx, &parentId, "SELECT id FROM treasure_chest")
require.NoError(t, err) require.NoError(t, err)
formData = url.Values{ formData = url.Values{
@@ -33,7 +33,7 @@ func TestTreasureChestShouldNotDeleteIfTransactionRecurringExists(t *testing.T)
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var childId string var childId string
err = db.Get(&childId, "SELECT id FROM treasure_chest WHERE parent_id = ?", parentId) err = db.GetContext(ctx, &childId, "SELECT id FROM treasure_chest WHERE parent_id = ?", parentId)
require.NoError(t, err) require.NoError(t, err)
formData = url.Values{ formData = url.Values{