feat(observabillity): #153 instrument sqlx
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 8m11s
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 8m11s
This commit is contained in:
@@ -24,6 +24,7 @@ linters:
|
|||||||
- cyclop
|
- cyclop
|
||||||
- contextcheck
|
- contextcheck
|
||||||
- bodyclose # i don't care in the tests, the implementation itself doesn't do http requests
|
- bodyclose # i don't care in the tests, the implementation itself doesn't do http requests
|
||||||
|
- containedctx
|
||||||
settings:
|
settings:
|
||||||
nestif:
|
nestif:
|
||||||
min-complexity: 6
|
min-complexity: 6
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -12,6 +12,7 @@ require (
|
|||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/mattn/go-sqlite3 v1.14.28
|
github.com/mattn/go-sqlite3 v1.14.28
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
|
github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2
|
||||||
go.opentelemetry.io/contrib/bridges/otelslog v0.11.0
|
go.opentelemetry.io/contrib/bridges/otelslog v0.11.0
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0
|
||||||
go.opentelemetry.io/otel v1.36.0
|
go.opentelemetry.io/otel v1.36.0
|
||||||
@@ -38,6 +39,7 @@ require (
|
|||||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/stretchr/objx v0.5.2 // indirect
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
|
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 // indirect
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -51,6 +51,10 @@ github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
|||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2 h1:ZjUj9BLYf9PEqBn8W/OapxhPjVRdC6CsXTdULHsyk5c=
|
||||||
|
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2/go.mod h1:O8bHQfyinKwTXKkiKNGmLQS7vRsqRxIQTFZpYpHK3IQ=
|
||||||
|
github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2 h1:zA9ZXfdtowo0EKt+t7uqXNlHxPeygrxuFSIroiBVgPU=
|
||||||
|
github.com/uptrace/opentelemetry-go-extra/otelsqlx v0.3.2/go.mod h1:ySXmuW9JLCm/TjsQksuMY/7MNiWqfHnhH2xeT34uOLU=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
go.opentelemetry.io/contrib/bridges/otelslog v0.11.0 h1:EMIiYTms4Z4m3bBuKp1VmMNRLZcl6j4YbvOPL1IhlWo=
|
go.opentelemetry.io/contrib/bridges/otelslog v0.11.0 h1:EMIiYTms4Z4m3bBuKp1VmMNRLZcl6j4YbvOPL1IhlWo=
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -13,23 +14,23 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Auth interface {
|
type Auth interface {
|
||||||
InsertUser(user *types.User) error
|
InsertUser(ctx context.Context, user *types.User) error
|
||||||
UpdateUser(user *types.User) error
|
UpdateUser(ctx context.Context, user *types.User) error
|
||||||
GetUserByEmail(email string) (*types.User, error)
|
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
|
||||||
GetUser(userId uuid.UUID) (*types.User, error)
|
GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error)
|
||||||
DeleteUser(userId uuid.UUID) error
|
DeleteUser(ctx context.Context, userId uuid.UUID) error
|
||||||
|
|
||||||
InsertToken(token *types.Token) error
|
InsertToken(ctx context.Context, token *types.Token) error
|
||||||
GetToken(token string) (*types.Token, error)
|
GetToken(ctx context.Context, token string) (*types.Token, error)
|
||||||
GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
|
GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
|
||||||
GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error)
|
GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error)
|
||||||
DeleteToken(token string) error
|
DeleteToken(ctx context.Context, token string) error
|
||||||
|
|
||||||
InsertSession(session *types.Session) error
|
InsertSession(ctx context.Context, session *types.Session) error
|
||||||
GetSession(sessionId string) (*types.Session, error)
|
GetSession(ctx context.Context, sessionId string) (*types.Session, error)
|
||||||
GetSessions(userId uuid.UUID) ([]*types.Session, error)
|
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
|
||||||
DeleteSession(sessionId string) error
|
DeleteSession(ctx context.Context, sessionId string) error
|
||||||
DeleteOldSessions(userId uuid.UUID) error
|
DeleteOldSessions(ctx context.Context, userId uuid.UUID) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthSqlite struct {
|
type AuthSqlite struct {
|
||||||
@@ -40,8 +41,8 @@ func NewAuthSqlite(db *sqlx.DB) *AuthSqlite {
|
|||||||
return &AuthSqlite{db: db}
|
return &AuthSqlite{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) InsertUser(user *types.User) error {
|
func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error {
|
||||||
_, err := db.db.Exec(`
|
_, err := db.db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt)
|
user.Id, user.Email, user.EmailVerified, user.EmailVerifiedAt, user.IsAdmin, user.Password, user.Salt, user.CreateAt)
|
||||||
@@ -58,8 +59,8 @@ func (db AuthSqlite) InsertUser(user *types.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) UpdateUser(user *types.User) error {
|
func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error {
|
||||||
_, err := db.db.Exec(`
|
_, err := db.db.ExecContext(ctx, `
|
||||||
UPDATE user
|
UPDATE user
|
||||||
SET email_verified = ?, email_verified_at = ?, password = ?
|
SET email_verified = ?, email_verified_at = ?, password = ?
|
||||||
WHERE user_id = ?`,
|
WHERE user_id = ?`,
|
||||||
@@ -73,7 +74,7 @@ func (db AuthSqlite) UpdateUser(user *types.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
|
func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
||||||
var (
|
var (
|
||||||
userId uuid.UUID
|
userId uuid.UUID
|
||||||
emailVerified bool
|
emailVerified bool
|
||||||
@@ -84,7 +85,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
|
|||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
err := db.db.QueryRow(`
|
err := db.db.QueryRowContext(ctx, `
|
||||||
SELECT user_id, email_verified, email_verified_at, password, salt, created_at
|
SELECT user_id, email_verified, email_verified_at, password, salt, created_at
|
||||||
FROM user
|
FROM user
|
||||||
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
||||||
@@ -100,7 +101,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
|
|||||||
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
|
func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) {
|
||||||
var (
|
var (
|
||||||
email string
|
email string
|
||||||
emailVerified bool
|
emailVerified bool
|
||||||
@@ -111,7 +112,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
|
|||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
err := db.db.QueryRow(`
|
err := db.db.QueryRowContext(ctx, `
|
||||||
SELECT email, email_verified, email_verified_at, password, salt, created_at
|
SELECT email, email_verified, email_verified_at, password, salt, created_at
|
||||||
FROM user
|
FROM user
|
||||||
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
||||||
@@ -127,49 +128,49 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
|
|||||||
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
|
func (db AuthSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error {
|
||||||
tx, err := db.db.Begin()
|
tx, err := db.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Could not start transaction", "err", err)
|
slog.Error("Could not start transaction", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("DELETE FROM account WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.Error("Could not delete accounts", "err", err)
|
slog.Error("Could not delete accounts", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("DELETE FROM token WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.Error("Could not delete user tokens", "err", err)
|
slog.Error("Could not delete user tokens", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("DELETE FROM session WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.Error("Could not delete sessions", "err", err)
|
slog.Error("Could not delete sessions", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("DELETE FROM user WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.Error("Could not delete user", "err", err)
|
slog.Error("Could not delete user", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("DELETE FROM treasure_chest WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.Error("Could not delete user", "err", err)
|
slog.Error("Could not delete user", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.Exec("DELETE FROM \"transaction\" WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.Error("Could not delete user", "err", err)
|
slog.Error("Could not delete user", "err", err)
|
||||||
@@ -185,8 +186,8 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) InsertToken(token *types.Token) error {
|
func (db AuthSqlite) InsertToken(ctx context.Context, token *types.Token) error {
|
||||||
_, err := db.db.Exec(`
|
_, err := db.db.ExecContext(ctx, `
|
||||||
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
|
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
|
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
|
||||||
|
|
||||||
@@ -198,7 +199,7 @@ func (db AuthSqlite) InsertToken(token *types.Token) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
|
func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, error) {
|
||||||
var (
|
var (
|
||||||
userId uuid.UUID
|
userId uuid.UUID
|
||||||
sessionId string
|
sessionId string
|
||||||
@@ -209,7 +210,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
|
|||||||
expiresAt time.Time
|
expiresAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
err := db.db.QueryRow(`
|
err := db.db.QueryRowContext(ctx, `
|
||||||
SELECT user_id, session_id, type, created_at, expires_at
|
SELECT user_id, session_id, type, created_at, expires_at
|
||||||
FROM token
|
FROM token
|
||||||
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
|
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
|
||||||
@@ -239,8 +240,8 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
|
|||||||
return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
|
return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
|
func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
|
||||||
query, err := db.db.Query(`
|
query, err := db.db.QueryContext(ctx, `
|
||||||
SELECT token, created_at, expires_at
|
SELECT token, created_at, expires_at
|
||||||
FROM token
|
FROM token
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
@@ -254,8 +255,8 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.
|
|||||||
return getTokensFromQuery(query, userId, "", tokenType)
|
return getTokensFromQuery(query, userId, "", tokenType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
|
func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
|
||||||
query, err := db.db.Query(`
|
query, err := db.db.QueryContext(ctx, `
|
||||||
SELECT token, created_at, expires_at
|
SELECT token, created_at, expires_at
|
||||||
FROM token
|
FROM token
|
||||||
WHERE session_id = ?
|
WHERE session_id = ?
|
||||||
@@ -312,8 +313,8 @@ func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tok
|
|||||||
return tokens, nil
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteToken(token string) error {
|
func (db AuthSqlite) DeleteToken(ctx context.Context, token string) error {
|
||||||
_, err := db.db.Exec("DELETE FROM token WHERE token = ?", token)
|
_, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Could not delete token", "err", err)
|
slog.Error("Could not delete token", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
@@ -321,8 +322,8 @@ func (db AuthSqlite) DeleteToken(token string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) InsertSession(session *types.Session) error {
|
func (db AuthSqlite) InsertSession(ctx context.Context, session *types.Session) error {
|
||||||
_, err := db.db.Exec(`
|
_, err := db.db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
|
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
|
||||||
|
|
||||||
@@ -334,14 +335,14 @@ func (db AuthSqlite) InsertSession(session *types.Session) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.Session, error) {
|
||||||
var (
|
var (
|
||||||
userId uuid.UUID
|
userId uuid.UUID
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
expiresAt time.Time
|
expiresAt time.Time
|
||||||
)
|
)
|
||||||
|
|
||||||
err := db.db.QueryRow(`
|
err := db.db.QueryRowContext(ctx, `
|
||||||
SELECT user_id, created_at, expires_at
|
SELECT user_id, created_at, expires_at
|
||||||
FROM session
|
FROM session
|
||||||
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
|
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
|
||||||
@@ -354,9 +355,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
|||||||
return types.NewSession(sessionId, userId, createdAt, expiresAt), nil
|
return types.NewSession(sessionId, userId, createdAt, expiresAt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
|
func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) {
|
||||||
var sessions []*types.Session
|
var sessions []*types.Session
|
||||||
err := db.db.Select(&sessions, `
|
err := db.db.SelectContext(ctx, &sessions, `
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM session
|
FROM session
|
||||||
WHERE user_id = ?`, userId)
|
WHERE user_id = ?`, userId)
|
||||||
@@ -368,8 +369,8 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
|
|||||||
return sessions, nil
|
return sessions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
|
func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error {
|
||||||
_, err := db.db.Exec(`
|
_, err := db.db.ExecContext(ctx, `
|
||||||
DELETE FROM session
|
DELETE FROM session
|
||||||
WHERE expires_at < datetime('now')
|
WHERE expires_at < datetime('now')
|
||||||
AND user_id = ?`, userId)
|
AND user_id = ?`, userId)
|
||||||
@@ -380,9 +381,9 @@ func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteSession(sessionId string) error {
|
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
|
||||||
if sessionId != "" {
|
if sessionId != "" {
|
||||||
_, err := db.db.Exec("DELETE FROM session WHERE session_id = ?", sessionId)
|
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Could not delete session", "err", err)
|
slog.Error("Could not delete session", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
@@ -20,7 +21,7 @@ func (l migrationLogger) Verbose() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunMigrations(db *sqlx.DB, pathPrefix string) error {
|
func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error {
|
||||||
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Could not create Migration instance", "err", err)
|
slog.Error("Could not create Migration instance", "err", err)
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// init db
|
// init db
|
||||||
err = db.RunMigrations(database, migrationsPrefix)
|
err = db.RunMigrations(ctx, database, migrationsPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not run migrations: %w", err)
|
return fmt.Errorf("could not run migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func (h AccountImpl) handleAccountPage() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := h.s.GetAll(user)
|
accounts, err := h.s.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -72,7 +72,7 @@ func (h AccountImpl) handleAccountItemComp() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := h.s.Get(user, id)
|
account, err := h.s.Get(r.Context(), user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -105,13 +105,13 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc {
|
|||||||
id := r.PathValue("id")
|
id := r.PathValue("id")
|
||||||
name := r.FormValue("name")
|
name := r.FormValue("name")
|
||||||
if id == "new" {
|
if id == "new" {
|
||||||
account, err = h.s.Add(user, name)
|
account, err = h.s.Add(r.Context(), user, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
account, err = h.s.UpdateName(user, id, name)
|
account, err = h.s.UpdateName(r.Context(), user, id, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -135,7 +135,7 @@ func (h AccountImpl) handleDeleteAccount() http.HandlerFunc {
|
|||||||
|
|
||||||
id := r.PathValue("id")
|
id := r.PathValue("id")
|
||||||
|
|
||||||
err := h.s.Delete(user, id)
|
err := h.s.Delete(r.Context(), user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
|||||||
email := r.FormValue("email")
|
email := r.FormValue("email")
|
||||||
password := r.FormValue("password")
|
password := r.FormValue("password")
|
||||||
|
|
||||||
session, user, err := handler.service.SignIn(session, email, password)
|
session, user, err := handler.service.SignIn(r.Context(), session, email, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -163,7 +163,7 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
go handler.service.SendVerificationMail(user.Id, user.Email)
|
go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email)
|
||||||
|
|
||||||
_, err := w.Write([]byte("<p class=\"mt-8\">Verification email sent</p>"))
|
_, err := w.Write([]byte("<p class=\"mt-8\">Verification email sent</p>"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -178,7 +178,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
|||||||
|
|
||||||
token := r.URL.Query().Get("token")
|
token := r.URL.Query().Get("token")
|
||||||
|
|
||||||
err := handler.service.VerifyUserEmail(token)
|
err := handler.service.VerifyUserEmail(r.Context(), token)
|
||||||
|
|
||||||
isVerified := err == nil
|
isVerified := err == nil
|
||||||
comp := auth.VerifyResponseComp(isVerified)
|
comp := auth.VerifyResponseComp(isVerified)
|
||||||
@@ -203,13 +203,13 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
|
|||||||
|
|
||||||
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
|
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
|
||||||
slog.Info("signing up", "email", email)
|
slog.Info("signing up", "email", email)
|
||||||
user, err := handler.service.SignUp(email, password)
|
user, err := handler.service.SignUp(r.Context(), email, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("Sending verification email", "to", user.Email)
|
slog.Info("Sending verification email", "to", user.Email)
|
||||||
go handler.service.SendVerificationMail(user.Id, user.Email)
|
go handler.service.SendVerificationMail(r.Context(), user.Id, user.Email)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -239,7 +239,7 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc {
|
|||||||
session := middleware.GetSession(r)
|
session := middleware.GetSession(r)
|
||||||
|
|
||||||
if session != nil {
|
if session != nil {
|
||||||
err := handler.service.SignOut(session.Id)
|
err := handler.service.SignOut(r.Context(), session.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "An error occurred", http.StatusInternalServerError)
|
http.Error(w, "An error occurred", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
@@ -288,7 +288,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
|
|||||||
|
|
||||||
password := r.FormValue("password")
|
password := r.FormValue("password")
|
||||||
|
|
||||||
err := handler.service.DeleteAccount(user, password)
|
err := handler.service.DeleteAccount(r.Context(), user, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrInvalidCredentials) {
|
if errors.Is(err, service.ErrInvalidCredentials) {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
|
||||||
@@ -334,7 +334,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
|
|||||||
currPass := r.FormValue("current-password")
|
currPass := r.FormValue("current-password")
|
||||||
newPass := r.FormValue("new-password")
|
newPass := r.FormValue("new-password")
|
||||||
|
|
||||||
err := handler.service.ChangePassword(user, session.Id, currPass, newPass)
|
err := handler.service.ChangePassword(r.Context(), user, session.Id, currPass, newPass)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
@@ -370,7 +370,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
|
_, err := utils.WaitMinimumTime(securityWaitDuration, func() (any, error) {
|
||||||
err := handler.service.SendForgotPasswordMail(email)
|
err := handler.service.SendForgotPasswordMail(r.Context(), email)
|
||||||
return nil, err
|
return nil, err
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -396,7 +396,7 @@ func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
|
|||||||
token := pageUrl.Query().Get("token")
|
token := pageUrl.Query().Get("token")
|
||||||
newPass := r.FormValue("new-password")
|
newPass := r.FormValue("new-password")
|
||||||
|
|
||||||
err = handler.service.ForgotPassword(token, newPass)
|
err = handler.service.ForgotPassword(r.Context(), token, newPass)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusBadRequest)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -17,13 +17,13 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
|||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
sessionId := getSessionID(r)
|
sessionId := getSessionID(r)
|
||||||
session, user, _ := service.SignInSession(sessionId)
|
session, user, _ := service.SignInSession(r.Context(), sessionId)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
// Always sign in anonymous
|
// Always sign in anonymous
|
||||||
// This way, we can always generate csrf tokens
|
// This way, we can always generate csrf tokens
|
||||||
if session == nil {
|
if session == nil {
|
||||||
session, err = service.SignInAnonymous()
|
session, err = service.SignInAnonymous(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"spend-sparrow/internal/service"
|
"spend-sparrow/internal/service"
|
||||||
@@ -11,13 +12,15 @@ import (
|
|||||||
|
|
||||||
type csrfResponseWriter struct {
|
type csrfResponseWriter struct {
|
||||||
http.ResponseWriter
|
http.ResponseWriter
|
||||||
|
ctx context.Context
|
||||||
auth service.Auth
|
auth service.Auth
|
||||||
session *types.Session
|
session *types.Session
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter {
|
func newCsrfResponseWriter(ctx context.Context, w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter {
|
||||||
return &csrfResponseWriter{
|
return &csrfResponseWriter{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
|
ctx: ctx,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
session: session,
|
session: session,
|
||||||
}
|
}
|
||||||
@@ -25,7 +28,7 @@ func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *ty
|
|||||||
|
|
||||||
func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
|
func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
|
||||||
dataStr := string(data)
|
dataStr := string(data)
|
||||||
csrfToken, err := rr.auth.GetCsrfToken(rr.session)
|
csrfToken, err := rr.auth.GetCsrfToken(rr.ctx, rr.session)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken)
|
dataStr = strings.ReplaceAll(dataStr, "CSRF_TOKEN", csrfToken)
|
||||||
}
|
}
|
||||||
@@ -37,6 +40,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
|||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
session := GetSession(r)
|
session := GetSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
if r.Method == http.MethodPost ||
|
if r.Method == http.MethodPost ||
|
||||||
r.Method == http.MethodPut ||
|
r.Method == http.MethodPut ||
|
||||||
@@ -44,7 +48,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
|||||||
r.Method == http.MethodPatch {
|
r.Method == http.MethodPatch {
|
||||||
csrfToken := r.Header.Get("Csrf-Token")
|
csrfToken := r.Header.Get("Csrf-Token")
|
||||||
|
|
||||||
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
|
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(ctx, csrfToken, session.Id) {
|
||||||
slog.Info("CSRF-Token not correct", "token", csrfToken)
|
slog.Info("CSRF-Token not correct", "token", csrfToken)
|
||||||
if r.Header.Get("Hx-Request") == "true" {
|
if r.Header.Get("Hx-Request") == "true" {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
|
||||||
@@ -55,7 +59,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
responseWriter := newCsrfResponseWriter(w, auth, session)
|
responseWriter := newCsrfResponseWriter(ctx, w, auth, session)
|
||||||
next.ServeHTTP(responseWriter, r)
|
next.ServeHTTP(responseWriter, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ func GenerateRecurringTransactions(transactionRecurring service.TransactionRecur
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = transactionRecurring.GenerateTransactions(user)
|
_ = transactionRecurring.GenerateTransactions(r.Context(), user)
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -65,19 +65,19 @@ func (h TransactionImpl) handleTransactionPage() http.HandlerFunc {
|
|||||||
Error: r.URL.Query().Get("error"),
|
Error: r.URL.Query().Get("error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
transactions, err := h.s.GetAll(user, filter)
|
transactions, err := h.s.GetAll(r.Context(), user, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := h.account.GetAll(user)
|
accounts, err := h.account.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChests, err := h.treasureChest.GetAll(user)
|
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -105,13 +105,13 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := h.account.GetAll(user)
|
accounts, err := h.account.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChests, err := h.treasureChest.GetAll(user)
|
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -124,7 +124,7 @@ func (h TransactionImpl) handleTransactionItemComp() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
transaction, err := h.s.Get(user, id)
|
transaction, err := h.s.Get(r.Context(), user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -212,26 +212,26 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
|||||||
|
|
||||||
var transaction *types.Transaction
|
var transaction *types.Transaction
|
||||||
if idStr == "new" {
|
if idStr == "new" {
|
||||||
transaction, err = h.s.Add(nil, user, input)
|
transaction, err = h.s.Add(r.Context(), nil, user, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
transaction, err = h.s.Update(user, input)
|
transaction, err = h.s.Update(r.Context(), user, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := h.account.GetAll(user)
|
accounts, err := h.account.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChests, err := h.treasureChest.GetAll(user)
|
treasureChests, err := h.treasureChest.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -253,7 +253,7 @@ func (h TransactionImpl) handleRecalculate() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.s.RecalculateBalances(user)
|
err := h.s.RecalculateBalances(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -275,7 +275,7 @@ func (h TransactionImpl) handleDeleteTransaction() http.HandlerFunc {
|
|||||||
|
|
||||||
id := r.PathValue("id")
|
id := r.PathValue("id")
|
||||||
|
|
||||||
err := h.s.Delete(user, id)
|
err := h.s.Delete(r.Context(), user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -70,13 +70,13 @@ func (h TransactionRecurringImpl) handleUpdateTransactionRecurring() http.Handle
|
|||||||
}
|
}
|
||||||
|
|
||||||
if input.Id == "new" {
|
if input.Id == "new" {
|
||||||
_, err := h.s.Add(user, input)
|
_, err := h.s.Add(r.Context(), user, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, err := h.s.Update(user, input)
|
_, err := h.s.Update(r.Context(), user, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -101,7 +101,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
|
|||||||
accountId := r.URL.Query().Get("account-id")
|
accountId := r.URL.Query().Get("account-id")
|
||||||
treasureChestId := r.URL.Query().Get("treasure-chest-id")
|
treasureChestId := r.URL.Query().Get("treasure-chest-id")
|
||||||
|
|
||||||
err := h.s.Delete(user, id)
|
err := h.s.Delete(r.Context(), user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -118,13 +118,13 @@ func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Req
|
|||||||
utils.TriggerToastWithStatus(w, r, "error", "Please select an account or treasure chest", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", "Please select an account or treasure chest", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
if accountId != "" {
|
if accountId != "" {
|
||||||
transactionsRecurring, err = h.s.GetAllByAccount(user, accountId)
|
transactionsRecurring, err = h.s.GetAllByAccount(r.Context(), user, accountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
transactionsRecurring, err = h.s.GetAllByTreasureChest(user, treasureChestId)
|
transactionsRecurring, err = h.s.GetAllByTreasureChest(r.Context(), user, treasureChestId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -48,13 +48,13 @@ func (h TreasureChestImpl) handleTreasureChestPage() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChests, err := h.s.GetAll(user)
|
treasureChests, err := h.s.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionsRecurring, err := h.transactionRecurring.GetAll(user)
|
transactionsRecurring, err := h.transactionRecurring.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -77,7 +77,7 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChests, err := h.s.GetAll(user)
|
treasureChests, err := h.s.GetAll(r.Context(), user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -90,13 +90,13 @@ func (h TreasureChestImpl) handleTreasureChestItemComp() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChest, err := h.s.Get(user, id)
|
treasureChest, err := h.s.Get(r.Context(), user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String())
|
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -132,20 +132,20 @@ func (h TreasureChestImpl) handleUpdateTreasureChest() http.HandlerFunc {
|
|||||||
parentId := r.FormValue("parent-id")
|
parentId := r.FormValue("parent-id")
|
||||||
name := r.FormValue("name")
|
name := r.FormValue("name")
|
||||||
if id == "new" {
|
if id == "new" {
|
||||||
treasureChest, err = h.s.Add(user, parentId, name)
|
treasureChest, err = h.s.Add(r.Context(), user, parentId, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
treasureChest, err = h.s.Update(user, id, parentId, name)
|
treasureChest, err = h.s.Update(r.Context(), user, id, parentId, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(user, treasureChest.Id.String())
|
transactionsRecurring, err := h.transactionRecurring.GetAllByTreasureChest(r.Context(), user, treasureChest.Id.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -171,7 +171,7 @@ func (h TreasureChestImpl) handleDeleteTreasureChest() http.HandlerFunc {
|
|||||||
|
|
||||||
id := r.PathValue("id")
|
id := r.PathValue("id")
|
||||||
|
|
||||||
err := h.s.Delete(user, id)
|
err := h.s.Delete(r.Context(), user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(w, r, err)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -12,11 +13,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Account interface {
|
type Account interface {
|
||||||
Add(user *types.User, name string) (*types.Account, error)
|
Add(ctx context.Context, user *types.User, name string) (*types.Account, error)
|
||||||
UpdateName(user *types.User, id string, name string) (*types.Account, error)
|
UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error)
|
||||||
Get(user *types.User, id string) (*types.Account, error)
|
Get(ctx context.Context, user *types.User, id string) (*types.Account, error)
|
||||||
GetAll(user *types.User) ([]*types.Account, error)
|
GetAll(ctx context.Context, user *types.User) ([]*types.Account, error)
|
||||||
Delete(user *types.User, id string) error
|
Delete(ctx context.Context, user *types.User, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountImpl struct {
|
type AccountImpl struct {
|
||||||
@@ -33,7 +34,7 @@ func NewAccount(db *sqlx.DB, random Random, clock Clock) Account {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) {
|
func (s AccountImpl) Add(ctx context.Context, user *types.User, name string) (*types.Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -64,7 +65,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
|
|||||||
UpdatedBy: nil,
|
UpdatedBy: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := s.db.NamedExec(`
|
r, err := s.db.NamedExecContext(ctx, `
|
||||||
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
|
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
|
||||||
VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account)
|
VALUES (:id, :user_id, :name, :current_balance, :oink_balance, :created_at, :created_by)`, account)
|
||||||
err = db.TransformAndLogDbError("account Insert", r, err)
|
err = db.TransformAndLogDbError("account Insert", r, err)
|
||||||
@@ -75,7 +76,7 @@ func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error)
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*types.Account, error) {
|
func (s AccountImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*types.Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -89,7 +90,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
|||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("account Update", nil, err)
|
err = db.TransformAndLogDbError("account Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -99,7 +100,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var account types.Account
|
var account types.Account
|
||||||
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("account Update", nil, err)
|
err = db.TransformAndLogDbError("account Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -113,7 +114,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
|||||||
account.UpdatedAt = ×tamp
|
account.UpdatedAt = ×tamp
|
||||||
account.UpdatedBy = &user.Id
|
account.UpdatedBy = &user.Id
|
||||||
|
|
||||||
r, err := tx.NamedExec(`
|
r, err := tx.NamedExecContext(ctx, `
|
||||||
UPDATE account
|
UPDATE account
|
||||||
SET
|
SET
|
||||||
name = :name,
|
name = :name,
|
||||||
@@ -135,7 +136,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
|||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
func (s AccountImpl) Get(ctx context.Context, user *types.User, id string) (*types.Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -146,7 +147,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var account types.Account
|
var account types.Account
|
||||||
err = s.db.Get(&account, `
|
err = s.db.GetContext(ctx, &account, `
|
||||||
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("account Get", nil, err)
|
err = db.TransformAndLogDbError("account Get", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -157,13 +158,13 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
|||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
func (s AccountImpl) GetAll(ctx context.Context, user *types.User) ([]*types.Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts := make([]*types.Account, 0)
|
accounts := make([]*types.Account, 0)
|
||||||
err := s.db.Select(&accounts, `
|
err := s.db.SelectContext(ctx, &accounts, `
|
||||||
SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id)
|
SELECT * FROM account WHERE user_id = ? ORDER BY name`, user.Id)
|
||||||
err = db.TransformAndLogDbError("account GetAll", nil, err)
|
err = db.TransformAndLogDbError("account GetAll", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -173,7 +174,7 @@ func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
|||||||
return accounts, nil
|
return accounts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AccountImpl) Delete(user *types.User, id string) error {
|
func (s AccountImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -183,7 +184,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
|
|||||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("account Delete", nil, err)
|
err = db.TransformAndLogDbError("account Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -193,7 +194,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
transactionsCount := 0
|
transactionsCount := 0
|
||||||
err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid)
|
err = tx.GetContext(ctx, &transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND account_id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("account Delete", nil, err)
|
err = db.TransformAndLogDbError("account Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -202,7 +203,7 @@ func (s AccountImpl) Delete(user *types.User, id string) error {
|
|||||||
return fmt.Errorf("account has transactions, cannot delete: %w", ErrBadRequest)
|
return fmt.Errorf("account has transactions, cannot delete: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := tx.Exec("DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
|
res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
|
||||||
err = db.TransformAndLogDbError("account Delete", res, err)
|
err = db.TransformAndLogDbError("account Delete", res, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -26,24 +26,24 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Auth interface {
|
type Auth interface {
|
||||||
SignUp(email string, password string) (*types.User, error)
|
SignUp(ctx context.Context, email string, password string) (*types.User, error)
|
||||||
SendVerificationMail(userId uuid.UUID, email string)
|
SendVerificationMail(ctx context.Context, userId uuid.UUID, email string)
|
||||||
VerifyUserEmail(token string) error
|
VerifyUserEmail(ctx context.Context, token string) error
|
||||||
|
|
||||||
SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error)
|
SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error)
|
||||||
SignInSession(sessionId string) (*types.Session, *types.User, error)
|
SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error)
|
||||||
SignInAnonymous() (*types.Session, error)
|
SignInAnonymous(ctx context.Context) (*types.Session, error)
|
||||||
SignOut(sessionId string) error
|
SignOut(ctx context.Context, sessionId string) error
|
||||||
|
|
||||||
DeleteAccount(user *types.User, currPass string) error
|
DeleteAccount(ctx context.Context, user *types.User, currPass string) error
|
||||||
|
|
||||||
ChangePassword(user *types.User, sessionId string, currPass, newPass string) error
|
ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error
|
||||||
|
|
||||||
SendForgotPasswordMail(email string) error
|
SendForgotPasswordMail(ctx context.Context, email string) error
|
||||||
ForgotPassword(token string, newPass string) error
|
ForgotPassword(ctx context.Context, token string, newPass string) error
|
||||||
|
|
||||||
IsCsrfTokenValid(tokenStr string, sessionId string) bool
|
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
|
||||||
GetCsrfToken(session *types.Session) (string, error)
|
GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthImpl struct {
|
type AuthImpl struct {
|
||||||
@@ -64,8 +64,8 @@ func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) {
|
func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) {
|
||||||
user, err := service.db.GetUserByEmail(email)
|
user, err := service.db.GetUserByEmail(ctx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, nil, ErrInvalidCredentials
|
return nil, nil, ErrInvalidCredentials
|
||||||
@@ -80,12 +80,12 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
|
|||||||
return nil, nil, ErrInvalidCredentials
|
return nil, nil, ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.cleanUpSessionWithTokens(session)
|
err = service.cleanUpSessionWithTokens(ctx, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err = service.createSession(user.Id)
|
session, err = service.createSession(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -93,17 +93,17 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
|
|||||||
return session, user, nil
|
return session, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
|
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
|
||||||
if sessionId == "" {
|
if sessionId == "" {
|
||||||
return nil, nil, ErrSessionIdInvalid
|
return nil, nil, ErrSessionIdInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := service.db.GetSession(sessionId)
|
session, err := service.db.GetSession(ctx, sessionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
if session.ExpiresAt.Before(service.clock.Now()) {
|
if session.ExpiresAt.Before(service.clock.Now()) {
|
||||||
_ = service.db.DeleteSession(sessionId)
|
_ = service.db.DeleteSession(ctx, sessionId)
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,7 +111,7 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
|
|||||||
return session, nil, nil
|
return session, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := service.db.GetUser(session.UserId)
|
user, err := service.db.GetUser(ctx, session.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -119,8 +119,8 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
|
|||||||
return session, user, nil
|
return session, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
|
func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, error) {
|
||||||
session, err := service.createSession(uuid.Nil)
|
session, err := service.createSession(ctx, uuid.Nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -130,7 +130,7 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
|
|||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
|
func (service AuthImpl) SignUp(ctx context.Context, email string, password string) (*types.User, error) {
|
||||||
_, err := mail.ParseAddress(email)
|
_, err := mail.ParseAddress(email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidEmail
|
return nil, ErrInvalidEmail
|
||||||
@@ -154,7 +154,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
|
|||||||
|
|
||||||
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
|
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
|
||||||
|
|
||||||
err = service.db.InsertUser(user)
|
err = service.db.InsertUser(ctx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrAlreadyExists) {
|
if errors.Is(err, db.ErrAlreadyExists) {
|
||||||
return nil, ErrAccountExists
|
return nil, ErrAccountExists
|
||||||
@@ -166,8 +166,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
|
||||||
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
|
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify)
|
||||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -192,7 +192,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
|||||||
service.clock.Now(),
|
service.clock.Now(),
|
||||||
service.clock.Now().Add(24*time.Hour))
|
service.clock.Now().Add(24*time.Hour))
|
||||||
|
|
||||||
err = service.db.InsertToken(token)
|
err = service.db.InsertToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -208,17 +208,17 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
|||||||
service.mail.SendMail(email, "Welcome to spend-sparrow", w.String())
|
service.mail.SendMail(email, "Welcome to spend-sparrow", w.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
|
func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
|
||||||
if tokenStr == "" {
|
if tokenStr == "" {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := service.db.GetToken(tokenStr)
|
token, err := service.db.GetToken(ctx, tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := service.db.GetUser(token.UserId)
|
user, err := service.db.GetUser(ctx, token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -236,21 +236,21 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
|
|||||||
user.EmailVerified = true
|
user.EmailVerified = true
|
||||||
user.EmailVerifiedAt = &now
|
user.EmailVerifiedAt = &now
|
||||||
|
|
||||||
err = service.db.UpdateUser(user)
|
err = service.db.UpdateUser(ctx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = service.db.DeleteToken(token.Token)
|
_ = service.db.DeleteToken(ctx, token.Token)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignOut(sessionId string) error {
|
func (service AuthImpl) SignOut(ctx context.Context, sessionId string) error {
|
||||||
return service.db.DeleteSession(sessionId)
|
return service.db.DeleteSession(ctx, sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, currPass string) error {
|
||||||
userDb, err := service.db.GetUser(user.Id)
|
userDb, err := service.db.GetUser(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -260,7 +260,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
|||||||
return ErrInvalidCredentials
|
return ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.DeleteUser(user.Id)
|
err = service.db.DeleteUser(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -270,7 +270,7 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
|
func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error {
|
||||||
if !isPasswordValid(newPass) {
|
if !isPasswordValid(newPass) {
|
||||||
return ErrInvalidPassword
|
return ErrInvalidPassword
|
||||||
}
|
}
|
||||||
@@ -288,18 +288,18 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP
|
|||||||
newHash := GetHashPassword(newPass, user.Salt)
|
newHash := GetHashPassword(newPass, user.Salt)
|
||||||
user.Password = newHash
|
user.Password = newHash
|
||||||
|
|
||||||
err := service.db.UpdateUser(user)
|
err := service.db.UpdateUser(ctx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sessions, err := service.db.GetSessions(user.Id)
|
sessions, err := service.db.GetSessions(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
for _, s := range sessions {
|
for _, s := range sessions {
|
||||||
if s.Id != sessionId {
|
if s.Id != sessionId {
|
||||||
err = service.db.DeleteSession(s.Id)
|
err = service.db.DeleteSession(ctx, s.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -309,13 +309,13 @@ func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currP
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SendForgotPasswordMail(email string) error {
|
func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
|
||||||
tokenStr, err := service.random.String(32)
|
tokenStr, err := service.random.String(32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := service.db.GetUserByEmail(email)
|
user, err := service.db.GetUserByEmail(ctx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil
|
return nil
|
||||||
@@ -332,7 +332,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
|
|||||||
service.clock.Now(),
|
service.clock.Now(),
|
||||||
service.clock.Now().Add(15*time.Minute))
|
service.clock.Now().Add(15*time.Minute))
|
||||||
|
|
||||||
err = service.db.InsertToken(token)
|
err = service.db.InsertToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -348,17 +348,17 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
|
||||||
if !isPasswordValid(newPass) {
|
if !isPasswordValid(newPass) {
|
||||||
return ErrInvalidPassword
|
return ErrInvalidPassword
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := service.db.GetToken(tokenStr)
|
token, err := service.db.GetToken(ctx, tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrTokenInvalid
|
return ErrTokenInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.DeleteToken(tokenStr)
|
err = service.db.DeleteToken(ctx, tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -368,7 +368,7 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
|||||||
return ErrTokenInvalid
|
return ErrTokenInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := service.db.GetUser(token.UserId)
|
user, err := service.db.GetUser(ctx, token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Could not get user from token", "err", err)
|
slog.Error("Could not get user from token", "err", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
@@ -377,18 +377,18 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
|||||||
passHash := GetHashPassword(newPass, user.Salt)
|
passHash := GetHashPassword(newPass, user.Salt)
|
||||||
|
|
||||||
user.Password = passHash
|
user.Password = passHash
|
||||||
err = service.db.UpdateUser(user)
|
err = service.db.UpdateUser(ctx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sessions, err := service.db.GetSessions(user.Id)
|
sessions, err := service.db.GetSessions(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, session := range sessions {
|
for _, session := range sessions {
|
||||||
err = service.db.DeleteSession(session.Id)
|
err = service.db.DeleteSession(ctx, session.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -397,8 +397,8 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool {
|
func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
|
||||||
token, err := service.db.GetToken(tokenStr)
|
token, err := service.db.GetToken(ctx, tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -412,12 +412,12 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
|
func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session) (string, error) {
|
||||||
if session == nil {
|
if session == nil {
|
||||||
return "", types.ErrInternal
|
return "", types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
|
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||||
|
|
||||||
if len(tokens) > 0 {
|
if len(tokens) > 0 {
|
||||||
return tokens[0].Token, nil
|
return tokens[0].Token, nil
|
||||||
@@ -435,7 +435,7 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
|
|||||||
types.TokenTypeCsrf,
|
types.TokenTypeCsrf,
|
||||||
service.clock.Now(),
|
service.clock.Now(),
|
||||||
service.clock.Now().Add(8*time.Hour))
|
service.clock.Now().Add(8*time.Hour))
|
||||||
err = service.db.InsertToken(token)
|
err = service.db.InsertToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", types.ErrInternal
|
return "", types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -445,22 +445,22 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
|
|||||||
return tokenStr, nil
|
return tokenStr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
|
func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error {
|
||||||
if session == nil {
|
if session == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := service.db.DeleteSession(session.Id)
|
err := service.db.DeleteSession(ctx, session.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
|
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
for _, token := range tokens {
|
for _, token := range tokens {
|
||||||
err = service.db.DeleteToken(token.Token)
|
err = service.db.DeleteToken(ctx, token.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -469,13 +469,13 @@ func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
|
func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*types.Session, error) {
|
||||||
sessionId, err := service.random.String(32)
|
sessionId, err := service.random.String(32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.DeleteOldSessions(userId)
|
err = service.db.DeleteOldSessions(ctx, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -485,7 +485,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error)
|
|||||||
|
|
||||||
session := types.NewSession(sessionId, userId, createAt, expiresAt)
|
session := types.NewSession(sessionId, userId, createAt, expiresAt)
|
||||||
|
|
||||||
err = service.db.InsertSession(session)
|
err = service.db.InsertSession(ctx, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -13,13 +14,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Transaction interface {
|
type Transaction interface {
|
||||||
Add(tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||||
Update(user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
Update(ctx context.Context, user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||||
Get(user *types.User, id string) (*types.Transaction, error)
|
Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error)
|
||||||
GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
|
GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
|
||||||
Delete(user *types.User, id string) error
|
Delete(ctx context.Context, user *types.User, id string) error
|
||||||
|
|
||||||
RecalculateBalances(user *types.User) error
|
RecalculateBalances(ctx context.Context, user *types.User) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransactionImpl struct {
|
type TransactionImpl struct {
|
||||||
@@ -36,7 +37,7 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
|
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -45,7 +46,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
|||||||
ownsTransaction := false
|
ownsTransaction := false
|
||||||
if tx == nil {
|
if tx == nil {
|
||||||
ownsTransaction = true
|
ownsTransaction = true
|
||||||
tx, err = s.db.Beginx()
|
tx, err = s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transaction Add", nil, err)
|
err = db.TransformAndLogDbError("transaction Add", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -55,12 +56,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
transaction, err := s.validateAndEnrichTransaction(tx, nil, user.Id, transactionInput)
|
transaction, err := s.validateAndEnrichTransaction(ctx, tx, nil, user.Id, transactionInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.NamedExec(`
|
r, err := tx.NamedExecContext(ctx, `
|
||||||
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
|
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
|
||||||
party, description, error, created_at, created_by)
|
party, description, error, created_at, created_by)
|
||||||
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
|
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
|
||||||
@@ -71,8 +72,8 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
|||||||
}
|
}
|
||||||
|
|
||||||
if transaction.Error == nil && transaction.AccountId != nil {
|
if transaction.Error == nil && transaction.AccountId != nil {
|
||||||
r, err = tx.Exec(`
|
r, err = tx.ExecContext(ctx, `
|
||||||
UPDATE account
|
UPDATE actx context.Context,ccount
|
||||||
SET current_balance = current_balance + ?
|
SET current_balance = current_balance + ?
|
||||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||||
err = db.TransformAndLogDbError("transaction Add", r, err)
|
err = db.TransformAndLogDbError("transaction Add", r, err)
|
||||||
@@ -82,7 +83,7 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
|||||||
}
|
}
|
||||||
|
|
||||||
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
||||||
r, err = tx.Exec(`
|
r, err = tx.ExecContext(ctx, `
|
||||||
UPDATE treasure_chest
|
UPDATE treasure_chest
|
||||||
SET current_balance = current_balance + ?
|
SET current_balance = current_balance + ?
|
||||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||||
@@ -103,12 +104,12 @@ func (s TransactionImpl) Add(tx *sqlx.Tx, user *types.User, transactionInput typ
|
|||||||
return transaction, nil
|
return transaction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*types.Transaction, error) {
|
func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transaction Update", nil, err)
|
err = db.TransformAndLogDbError("transaction Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -118,7 +119,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
transaction := &types.Transaction{}
|
transaction := &types.Transaction{}
|
||||||
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
|
err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
|
||||||
err = db.TransformAndLogDbError("transaction Update", nil, err)
|
err = db.TransformAndLogDbError("transaction Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -128,7 +129,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
|||||||
}
|
}
|
||||||
|
|
||||||
if transaction.Error == nil && transaction.AccountId != nil {
|
if transaction.Error == nil && transaction.AccountId != nil {
|
||||||
r, err := tx.Exec(`
|
r, err := tx.ExecContext(ctx, `
|
||||||
UPDATE account
|
UPDATE account
|
||||||
SET current_balance = current_balance - ?
|
SET current_balance = current_balance - ?
|
||||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||||
@@ -138,7 +139,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
||||||
r, err := tx.Exec(`
|
r, err := tx.ExecContext(ctx, `
|
||||||
UPDATE treasure_chest
|
UPDATE treasure_chest
|
||||||
SET current_balance = current_balance - ?
|
SET current_balance = current_balance - ?
|
||||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||||
@@ -148,13 +149,13 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
transaction, err = s.validateAndEnrichTransaction(tx, transaction, user.Id, input)
|
transaction, err = s.validateAndEnrichTransaction(ctx, tx, transaction, user.Id, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if transaction.Error == nil && transaction.AccountId != nil {
|
if transaction.Error == nil && transaction.AccountId != nil {
|
||||||
r, err := tx.Exec(`
|
r, err := tx.ExecContext(ctx, `
|
||||||
UPDATE account
|
UPDATE account
|
||||||
SET current_balance = current_balance + ?
|
SET current_balance = current_balance + ?
|
||||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||||
@@ -164,7 +165,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
||||||
r, err := tx.Exec(`
|
r, err := tx.ExecContext(ctx, `
|
||||||
UPDATE treasure_chest
|
UPDATE treasure_chest
|
||||||
SET current_balance = current_balance + ?
|
SET current_balance = current_balance + ?
|
||||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||||
@@ -174,7 +175,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.NamedExec(`
|
r, err := tx.NamedExecContext(ctx, `
|
||||||
UPDATE "transaction"
|
UPDATE "transaction"
|
||||||
SET
|
SET
|
||||||
account_id = :account_id,
|
account_id = :account_id,
|
||||||
@@ -202,7 +203,7 @@ func (s TransactionImpl) Update(user *types.User, input types.Transaction) (*typ
|
|||||||
return transaction, nil
|
return transaction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, error) {
|
func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -213,7 +214,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
var transaction types.Transaction
|
var transaction types.Transaction
|
||||||
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("transaction Get", nil, err)
|
err = db.TransformAndLogDbError("transaction Get", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -225,13 +226,13 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
|
|||||||
return &transaction, nil
|
return &transaction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
|
func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
transactions := make([]*types.Transaction, 0)
|
transactions := make([]*types.Transaction, 0)
|
||||||
err := s.db.Select(&transactions, `
|
err := s.db.SelectContext(ctx, &transactions, `
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM "transaction"
|
FROM "transaction"
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
@@ -254,7 +255,7 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF
|
|||||||
return transactions, nil
|
return transactions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) Delete(user *types.User, id string) error {
|
func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -264,7 +265,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
|||||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transaction Delete", nil, err)
|
err = db.TransformAndLogDbError("transaction Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -274,14 +275,14 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var transaction types.Transaction
|
var transaction types.Transaction
|
||||||
err = tx.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("transaction Delete", nil, err)
|
err = db.TransformAndLogDbError("transaction Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if transaction.Error == nil && transaction.AccountId != nil {
|
if transaction.Error == nil && transaction.AccountId != nil {
|
||||||
r, err := tx.Exec(`
|
r, err := tx.ExecContext(ctx, `
|
||||||
UPDATE account
|
UPDATE account
|
||||||
SET current_balance = current_balance - ?
|
SET current_balance = current_balance - ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
@@ -293,7 +294,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
if transaction.Error == nil && transaction.TreasureChestId != nil {
|
||||||
r, err := tx.Exec(`
|
r, err := tx.ExecContext(ctx, `
|
||||||
UPDATE treasure_chest
|
UPDATE treasure_chest
|
||||||
SET current_balance = current_balance - ?
|
SET current_balance = current_balance - ?
|
||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
@@ -304,7 +305,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.Exec("DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id)
|
r, err := tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE id = ? AND user_id = ?", uuid, user.Id)
|
||||||
err = db.TransformAndLogDbError("transaction Delete", r, err)
|
err = db.TransformAndLogDbError("transaction Delete", r, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -319,12 +320,12 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
|
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -333,7 +334,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
r, err := tx.Exec(`
|
r, err := tx.ExecContext(ctx, `
|
||||||
UPDATE account
|
UPDATE account
|
||||||
SET current_balance = 0
|
SET current_balance = 0
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
@@ -342,7 +343,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err = tx.Exec(`
|
r, err = tx.ExecContext(ctx, `
|
||||||
UPDATE treasure_chest
|
UPDATE treasure_chest
|
||||||
SET current_balance = 0
|
SET current_balance = 0
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
@@ -351,7 +352,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := tx.Queryx(`
|
rows, err := tx.QueryxContext(ctx, `
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM "transaction"
|
FROM "transaction"
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
@@ -375,7 +376,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.updateErrors(&transaction)
|
s.updateErrors(&transaction)
|
||||||
r, err = tx.Exec(`
|
r, err = tx.ExecContext(ctx, `
|
||||||
UPDATE "transaction"
|
UPDATE "transaction"
|
||||||
SET error = ?
|
SET error = ?
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
@@ -390,7 +391,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if transaction.AccountId != nil {
|
if transaction.AccountId != nil {
|
||||||
r, err = tx.Exec(`
|
r, err = tx.ExecContext(ctx, `
|
||||||
UPDATE account
|
UPDATE account
|
||||||
SET current_balance = current_balance + ?
|
SET current_balance = current_balance + ?
|
||||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||||
@@ -400,7 +401,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if transaction.TreasureChestId != nil {
|
if transaction.TreasureChestId != nil {
|
||||||
r, err = tx.Exec(`
|
r, err = tx.ExecContext(ctx, `
|
||||||
UPDATE treasure_chest
|
UPDATE treasure_chest
|
||||||
SET current_balance = current_balance + ?
|
SET current_balance = current_balance + ?
|
||||||
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||||
@@ -420,7 +421,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) {
|
func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.Transaction) (*types.Transaction, error) {
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
@@ -449,7 +450,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
|||||||
}
|
}
|
||||||
|
|
||||||
if input.AccountId != nil {
|
if input.AccountId != nil {
|
||||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId)
|
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, input.AccountId, userId)
|
||||||
err = db.TransformAndLogDbError("transaction validate", nil, err)
|
err = db.TransformAndLogDbError("transaction validate", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -462,7 +463,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
|||||||
|
|
||||||
if input.TreasureChestId != nil {
|
if input.TreasureChestId != nil {
|
||||||
var treasureChest types.TreasureChest
|
var treasureChest types.TreasureChest
|
||||||
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
|
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
|
||||||
err = db.TransformAndLogDbError("transaction validate", nil, err)
|
err = db.TransformAndLogDbError("transaction validate", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -15,14 +16,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TransactionRecurring interface {
|
type TransactionRecurring interface {
|
||||||
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||||
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||||
GetAll(user *types.User) ([]*types.TransactionRecurring, error)
|
GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error)
|
||||||
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error)
|
GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error)
|
||||||
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||||
Delete(user *types.User, id string) error
|
Delete(ctx context.Context, user *types.User, id string) error
|
||||||
|
|
||||||
GenerateTransactions(user *types.User) error
|
GenerateTransactions(ctx context.Context, user *types.User) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransactionRecurringImpl struct {
|
type TransactionRecurringImpl struct {
|
||||||
@@ -41,7 +42,7 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Add(
|
func (s TransactionRecurringImpl) Add(ctx context.Context,
|
||||||
user *types.User,
|
user *types.User,
|
||||||
transactionRecurringInput types.TransactionRecurringInput,
|
transactionRecurringInput types.TransactionRecurringInput,
|
||||||
) (*types.TransactionRecurring, error) {
|
) (*types.TransactionRecurring, error) {
|
||||||
@@ -49,7 +50,7 @@ func (s TransactionRecurringImpl) Add(
|
|||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Add", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring Add", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -58,12 +59,12 @@ func (s TransactionRecurringImpl) Add(
|
|||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
transactionRecurring, err := s.validateAndEnrichTransactionRecurring(tx, nil, user.Id, transactionRecurringInput)
|
transactionRecurring, err := s.validateAndEnrichTransactionRecurring(ctx, tx, nil, user.Id, transactionRecurringInput)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.NamedExec(`
|
r, err := tx.NamedExecContext(ctx, `
|
||||||
INSERT INTO "transaction_recurring" (id, user_id, interval_months,
|
INSERT INTO "transaction_recurring" (id, user_id, interval_months,
|
||||||
next_execution, party, description, account_id, treasure_chest_id, value, created_at, created_by)
|
next_execution, party, description, account_id, treasure_chest_id, value, created_at, created_by)
|
||||||
VALUES (:id, :user_id, :interval_months,
|
VALUES (:id, :user_id, :interval_months,
|
||||||
@@ -83,7 +84,7 @@ func (s TransactionRecurringImpl) Add(
|
|||||||
return transactionRecurring, nil
|
return transactionRecurring, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Update(
|
func (s TransactionRecurringImpl) Update(ctx context.Context,
|
||||||
user *types.User,
|
user *types.User,
|
||||||
input types.TransactionRecurringInput,
|
input types.TransactionRecurringInput,
|
||||||
) (*types.TransactionRecurring, error) {
|
) (*types.TransactionRecurring, error) {
|
||||||
@@ -96,7 +97,7 @@ func (s TransactionRecurringImpl) Update(
|
|||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -106,7 +107,7 @@ func (s TransactionRecurringImpl) Update(
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
transactionRecurring := &types.TransactionRecurring{}
|
transactionRecurring := &types.TransactionRecurring{}
|
||||||
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -115,12 +116,12 @@ func (s TransactionRecurringImpl) Update(
|
|||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(tx, transactionRecurring, user.Id, input)
|
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.NamedExec(`
|
r, err := tx.NamedExecContext(ctx, `
|
||||||
UPDATE transaction_recurring
|
UPDATE transaction_recurring
|
||||||
SET
|
SET
|
||||||
interval_months = :interval_months,
|
interval_months = :interval_months,
|
||||||
@@ -148,13 +149,13 @@ func (s TransactionRecurringImpl) Update(
|
|||||||
return transactionRecurring, nil
|
return transactionRecurring, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.TransactionRecurring, error) {
|
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||||
err := s.db.Select(&transactionRecurrings, `
|
err := s.db.SelectContext(ctx, &transactionRecurrings, `
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM transaction_recurring
|
FROM transaction_recurring
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
@@ -168,7 +169,7 @@ func (s TransactionRecurringImpl) GetAll(user *types.User) ([]*types.Transaction
|
|||||||
return transactionRecurrings, nil
|
return transactionRecurrings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -179,7 +180,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
|||||||
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -189,7 +190,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var rowCount int
|
var rowCount int
|
||||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -199,7 +200,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
|||||||
}
|
}
|
||||||
|
|
||||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||||
err = tx.Select(&transactionRecurrings, `
|
err = tx.SelectContext(ctx, &transactionRecurrings, `
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM transaction_recurring
|
FROM transaction_recurring
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
@@ -220,7 +221,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
|||||||
return transactionRecurrings, nil
|
return transactionRecurrings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
||||||
user *types.User,
|
user *types.User,
|
||||||
treasureChestId string,
|
treasureChestId string,
|
||||||
) ([]*types.TransactionRecurring, error) {
|
) ([]*types.TransactionRecurring, error) {
|
||||||
@@ -234,7 +235,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
|||||||
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -244,7 +245,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var rowCount int
|
var rowCount int
|
||||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -254,7 +255,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||||
err = tx.Select(&transactionRecurrings, `
|
err = tx.SelectContext(ctx, &transactionRecurrings, `
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM transaction_recurring
|
FROM transaction_recurring
|
||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
@@ -275,7 +276,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(
|
|||||||
return transactionRecurrings, nil
|
return transactionRecurrings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
|
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -285,7 +286,7 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
|
|||||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -295,13 +296,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var transactionRecurring types.TransactionRecurring
|
var transactionRecurring types.TransactionRecurring
|
||||||
err = tx.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.GetContext(ctx, &transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.Exec("DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id)
|
r, err := tx.ExecContext(ctx, "DELETE FROM transaction_recurring WHERE id = ? AND user_id = ?", uuid, user.Id)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Delete", r, err)
|
err = db.TransformAndLogDbError("transactionRecurring Delete", r, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -316,13 +317,13 @@ func (s TransactionRecurringImpl) Delete(user *types.User, id string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
|
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return ErrUnauthorized
|
||||||
}
|
}
|
||||||
now := s.clock.Now()
|
now := s.clock.Now()
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -332,7 +333,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
recurringTransactions := make([]*types.TransactionRecurring, 0)
|
recurringTransactions := make([]*types.TransactionRecurring, 0)
|
||||||
err = tx.Select(&recurringTransactions, `
|
err = tx.SelectContext(ctx, &recurringTransactions, `
|
||||||
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
|
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
|
||||||
user.Id, now)
|
user.Id, now)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
||||||
@@ -350,13 +351,13 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
|
|||||||
Value: transactionRecurring.Value,
|
Value: transactionRecurring.Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.transaction.Add(tx, user, transaction)
|
_, err = s.transaction.Add(ctx, tx, user, transaction)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0)
|
nextExecution := transactionRecurring.NextExecution.AddDate(0, int(transactionRecurring.IntervalMonths), 0)
|
||||||
r, err := tx.Exec(`UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`,
|
r, err := tx.ExecContext(ctx, `UPDATE transaction_recurring SET next_execution = ? WHERE id = ? AND user_id = ?`,
|
||||||
nextExecution, transactionRecurring.Id, user.Id)
|
nextExecution, transactionRecurring.Id, user.Id)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err)
|
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", r, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -373,6 +374,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(user *types.User) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||||
|
ctx context.Context,
|
||||||
tx *sqlx.Tx,
|
tx *sqlx.Tx,
|
||||||
oldTransactionRecurring *types.TransactionRecurring,
|
oldTransactionRecurring *types.TransactionRecurring,
|
||||||
userId uuid.UUID,
|
userId uuid.UUID,
|
||||||
@@ -417,7 +419,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
accountUuid = &temp
|
accountUuid = &temp
|
||||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
|
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -438,7 +440,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
}
|
}
|
||||||
treasureChestUuid = &temp
|
treasureChestUuid = &temp
|
||||||
var treasureChest types.TreasureChest
|
var treasureChest types.TreasureChest
|
||||||
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -13,11 +14,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TreasureChest interface {
|
type TreasureChest interface {
|
||||||
Add(user *types.User, parentId, name string) (*types.TreasureChest, error)
|
Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error)
|
||||||
Update(user *types.User, id, parentId, name string) (*types.TreasureChest, error)
|
Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error)
|
||||||
Get(user *types.User, id string) (*types.TreasureChest, error)
|
Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error)
|
||||||
GetAll(user *types.User) ([]*types.TreasureChest, error)
|
GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error)
|
||||||
Delete(user *types.User, id string) error
|
Delete(ctx context.Context, user *types.User, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TreasureChestImpl struct {
|
type TreasureChestImpl struct {
|
||||||
@@ -34,7 +35,7 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.TreasureChest, error) {
|
func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -51,7 +52,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
|
|||||||
|
|
||||||
var parentUuid *uuid.UUID
|
var parentUuid *uuid.UUID
|
||||||
if parentId != "" {
|
if parentId != "" {
|
||||||
parent, err := s.Get(user, parentId)
|
parent, err := s.Get(ctx, user, parentId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -76,7 +77,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
|
|||||||
UpdatedBy: nil,
|
UpdatedBy: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := s.db.NamedExec(`
|
r, err := s.db.NamedExecContext(ctx, `
|
||||||
INSERT INTO treasure_chest (id, parent_id, user_id, name, current_balance, created_at, created_by)
|
INSERT INTO treasure_chest (id, parent_id, user_id, name, current_balance, created_at, created_by)
|
||||||
VALUES (:id, :parent_id, :user_id, :name, :current_balance, :created_at, :created_by)`, treasureChest)
|
VALUES (:id, :parent_id, :user_id, :name, :current_balance, :created_at, :created_by)`, treasureChest)
|
||||||
err = db.TransformAndLogDbError("treasureChest Insert", r, err)
|
err = db.TransformAndLogDbError("treasureChest Insert", r, err)
|
||||||
@@ -87,7 +88,7 @@ func (s TreasureChestImpl) Add(user *types.User, parentId, name string) (*types.
|
|||||||
return treasureChest, nil
|
return treasureChest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
|
func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -101,7 +102,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
|||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -111,7 +112,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
treasureChest := &types.TreasureChest{}
|
treasureChest := &types.TreasureChest{}
|
||||||
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -122,12 +123,12 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
|||||||
|
|
||||||
var parentUuid *uuid.UUID
|
var parentUuid *uuid.UUID
|
||||||
if parentId != "" {
|
if parentId != "" {
|
||||||
parent, err := s.Get(user, parentId)
|
parent, err := s.Get(ctx, user, parentId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var childCount int
|
var childCount int
|
||||||
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
|
err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -145,7 +146,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
|||||||
treasureChest.UpdatedAt = ×tamp
|
treasureChest.UpdatedAt = ×tamp
|
||||||
treasureChest.UpdatedBy = &user.Id
|
treasureChest.UpdatedBy = &user.Id
|
||||||
|
|
||||||
r, err := tx.NamedExec(`
|
r, err := tx.NamedExecContext(ctx, `
|
||||||
UPDATE treasure_chest
|
UPDATE treasure_chest
|
||||||
SET
|
SET
|
||||||
parent_id = :parent_id,
|
parent_id = :parent_id,
|
||||||
@@ -169,7 +170,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
|||||||
return treasureChest, nil
|
return treasureChest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChest, error) {
|
func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -180,7 +181,7 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
|
|||||||
}
|
}
|
||||||
|
|
||||||
var treasureChest types.TreasureChest
|
var treasureChest types.TreasureChest
|
||||||
err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("treasureChest Get", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Get", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -192,13 +193,13 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
|
|||||||
return &treasureChest, nil
|
return &treasureChest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) {
|
func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChests := make([]*types.TreasureChest, 0)
|
treasureChests := make([]*types.TreasureChest, 0)
|
||||||
err := s.db.Select(&treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id)
|
err := s.db.SelectContext(ctx, &treasureChests, `SELECT * FROM treasure_chest WHERE user_id = ?`, user.Id)
|
||||||
err = db.TransformAndLogDbError("treasureChest GetAll", nil, err)
|
err = db.TransformAndLogDbError("treasureChest GetAll", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -207,7 +208,7 @@ func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, err
|
|||||||
return sortTree(treasureChests), nil
|
return sortTree(treasureChests), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return ErrUnauthorized
|
||||||
}
|
}
|
||||||
@@ -217,7 +218,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
|||||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.Beginx()
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -227,7 +228,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
childCount := 0
|
childCount := 0
|
||||||
err = tx.Get(&childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
|
err = tx.GetContext(ctx, &childCount, `SELECT COUNT(*) FROM treasure_chest WHERE user_id = ? AND parent_id = ?`, user.Id, id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -238,7 +239,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
transactionsCount := 0
|
transactionsCount := 0
|
||||||
err = tx.Get(&transactionsCount,
|
err = tx.GetContext(ctx, &transactionsCount,
|
||||||
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
|
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
|
||||||
user.Id, id)
|
user.Id, id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||||
@@ -250,7 +251,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
recurringCount := 0
|
recurringCount := 0
|
||||||
err = tx.Get(&recurringCount, `
|
err = tx.GetContext(ctx, &recurringCount, `
|
||||||
SELECT COUNT(*) FROM transaction_recurring WHERE user_id = ? AND treasure_chest_id = ?`,
|
SELECT COUNT(*) FROM transaction_recurring WHERE user_id = ? AND treasure_chest_id = ?`,
|
||||||
user.Id, id)
|
user.Id, id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||||
@@ -261,7 +262,7 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
|||||||
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest)
|
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.Exec(`DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Delete", r, err)
|
err = db.TransformAndLogDbError("treasureChest Delete", r, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func DoRedirect(w http.ResponseWriter, r *http.Request, url string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WaitMinimumTime[T interface{}](waitTime time.Duration, f func() (T, error)) (T, error) {
|
func WaitMinimumTime[T any](waitTime time.Duration, f func() (T, error)) (T, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
result, err := f()
|
result, err := f()
|
||||||
time.Sleep(waitTime - time.Since(start))
|
time.Sleep(waitTime - time.Since(start))
|
||||||
|
|||||||
7
main.go
7
main.go
@@ -6,9 +6,11 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"spend-sparrow/internal"
|
"spend-sparrow/internal"
|
||||||
|
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"github.com/uptrace/opentelemetry-go-extra/otelsql"
|
||||||
|
"github.com/uptrace/opentelemetry-go-extra/otelsqlx"
|
||||||
|
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -18,7 +20,8 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := sqlx.Open("sqlite3", "./data/spend-sparrow.db")
|
db, err := otelsqlx.Open("sqlite3", "./data/spend-sparrow.db",
|
||||||
|
otelsql.WithAttributes(semconv.DBSystemSqlite))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Could not open Database data.db", "err", err)
|
slog.Error("Could not open Database data.db", "err", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package test_test
|
package test_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -26,7 +27,7 @@ func setupDb(t *testing.T) *sqlx.DB {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
err = db.RunMigrations(d, "../")
|
err = db.RunMigrations(context.Background(), d, "../")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error running migrations: %v", err)
|
t.Fatalf("Error running migrations: %v", err)
|
||||||
}
|
}
|
||||||
@@ -47,14 +48,14 @@ func TestUser(t *testing.T) {
|
|||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(expected)
|
err := underTest.InsertUser(context.Background(), expected)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := underTest.GetUser(expected.Id)
|
actual, err := underTest.GetUser(context.Background(), expected.Id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
|
|
||||||
actual, err = underTest.GetUserByEmail(expected.Email)
|
actual, err = underTest.GetUserByEmail(context.Background(), expected.Email)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
})
|
})
|
||||||
@@ -64,7 +65,7 @@ func TestUser(t *testing.T) {
|
|||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
_, err := underTest.GetUserByEmail("nonExistentEmail")
|
_, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail")
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrUserExist", func(t *testing.T) {
|
t.Run("should return ErrUserExist", func(t *testing.T) {
|
||||||
@@ -77,10 +78,10 @@ func TestUser(t *testing.T) {
|
|||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
err := underTest.InsertUser(context.Background(), user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = underTest.InsertUser(user)
|
err = underTest.InsertUser(context.Background(), user)
|
||||||
assert.Equal(t, db.ErrAlreadyExists, err)
|
assert.Equal(t, db.ErrAlreadyExists, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
||||||
@@ -92,7 +93,7 @@ func TestUser(t *testing.T) {
|
|||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
err := underTest.InsertUser(context.Background(), user)
|
||||||
assert.Equal(t, types.ErrInternal, err)
|
assert.Equal(t, types.ErrInternal, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -110,21 +111,21 @@ func TestToken(t *testing.T) {
|
|||||||
expiresAt := createAt.Add(24 * time.Hour)
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
|
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
|
||||||
err := underTest.InsertToken(expected)
|
err := underTest.InsertToken(context.Background(), expected)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := underTest.GetToken(expected.Token)
|
actual, err := underTest.GetToken(context.Background(), expected.Token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
|
|
||||||
expected.SessionId = ""
|
expected.SessionId = ""
|
||||||
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type)
|
actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected}, actuals)
|
assert.Equal(t, []*types.Token{expected}, actuals)
|
||||||
|
|
||||||
expected.SessionId = "sessionId"
|
expected.SessionId = "sessionId"
|
||||||
expected.UserId = uuid.Nil
|
expected.UserId = uuid.Nil
|
||||||
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type)
|
actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected}, actuals)
|
assert.Equal(t, []*types.Token{expected}, actuals)
|
||||||
})
|
})
|
||||||
@@ -140,14 +141,14 @@ func TestToken(t *testing.T) {
|
|||||||
expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt)
|
expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
|
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
|
||||||
err := underTest.InsertToken(expected1)
|
err := underTest.InsertToken(context.Background(), expected1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = underTest.InsertToken(expected2)
|
err = underTest.InsertToken(context.Background(), expected2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected1.UserId = uuid.Nil
|
expected1.UserId = uuid.Nil
|
||||||
expected2.UserId = uuid.Nil
|
expected2.UserId = uuid.Nil
|
||||||
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type)
|
actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
||||||
|
|
||||||
@@ -155,7 +156,7 @@ func TestToken(t *testing.T) {
|
|||||||
expected2.SessionId = ""
|
expected2.SessionId = ""
|
||||||
expected1.UserId = userId
|
expected1.UserId = userId
|
||||||
expected2.UserId = userId
|
expected2.UserId = userId
|
||||||
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type)
|
actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
||||||
})
|
})
|
||||||
@@ -165,13 +166,13 @@ func TestToken(t *testing.T) {
|
|||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
_, err := underTest.GetToken("nonExistent")
|
_, err := underTest.GetToken(context.Background(), "nonExistent")
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
|
|
||||||
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify)
|
_, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), types.TokenTypeEmailVerify)
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
|
|
||||||
_, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify)
|
_, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", types.TokenTypeEmailVerify)
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
|
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
|
||||||
@@ -184,10 +185,10 @@ func TestToken(t *testing.T) {
|
|||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
err := underTest.InsertUser(context.Background(), user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = underTest.InsertUser(user)
|
err = underTest.InsertUser(context.Background(), user)
|
||||||
assert.Equal(t, db.ErrAlreadyExists, err)
|
assert.Equal(t, db.ErrAlreadyExists, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
||||||
@@ -199,7 +200,7 @@ func TestToken(t *testing.T) {
|
|||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
err := underTest.InsertUser(context.Background(), user)
|
||||||
assert.Equal(t, types.ErrInternal, err)
|
assert.Equal(t, types.ErrInternal, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package test_test
|
package test_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/service"
|
"spend-sparrow/internal/service"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
@@ -36,7 +37,7 @@ func TestSignUp(t *testing.T) {
|
|||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
_, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!")
|
_, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!")
|
||||||
|
|
||||||
assert.Equal(t, service.ErrInvalidEmail, err)
|
assert.Equal(t, service.ErrInvalidEmail, err)
|
||||||
})
|
})
|
||||||
@@ -58,7 +59,7 @@ func TestSignUp(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, password := range weakPasswords {
|
for _, password := range weakPasswords {
|
||||||
_, err := underTest.SignUp("some@valid.email", password)
|
_, err := underTest.SignUp(context.Background(), "some@valid.email", password)
|
||||||
assert.Equal(t, service.ErrInvalidPassword, err)
|
assert.Equal(t, service.ErrInvalidPassword, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -81,10 +82,10 @@ func TestSignUp(t *testing.T) {
|
|||||||
mockRandom.EXPECT().UUID().Return(userId, nil)
|
mockRandom.EXPECT().UUID().Return(userId, nil)
|
||||||
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
||||||
mockClock.EXPECT().Now().Return(createTime)
|
mockClock.EXPECT().Now().Return(createTime)
|
||||||
mockAuthDb.EXPECT().InsertUser(expected).Return(nil)
|
mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil)
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
actual, err := underTest.SignUp(email, password)
|
actual, err := underTest.SignUp(context.Background(), email, password)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -109,11 +110,11 @@ func TestSignUp(t *testing.T) {
|
|||||||
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
||||||
mockClock.EXPECT().Now().Return(createTime)
|
mockClock.EXPECT().Now().Return(createTime)
|
||||||
|
|
||||||
mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists)
|
mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(db.ErrAlreadyExists)
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
_, err := underTest.SignUp(user.Email, password)
|
_, err := underTest.SignUp(context.Background(), user.Email, password)
|
||||||
assert.Equal(t, service.ErrAccountExists, err)
|
assert.Equal(t, service.ErrAccountExists, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -140,7 +141,7 @@ func TestSendVerificationMail(t *testing.T) {
|
|||||||
mockClock := mocks.NewMockClock(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
mockMail := mocks.NewMockMail(t)
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify).Return(tokens, nil)
|
mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, types.TokenTypeEmailVerify).Return(tokens, nil)
|
||||||
|
|
||||||
mockMail.EXPECT().SendMail(email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
|
mockMail.EXPECT().SendMail(email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
|
||||||
return strings.Contains(message, token.Token)
|
return strings.Contains(message, token.Token)
|
||||||
@@ -148,6 +149,6 @@ func TestSendVerificationMail(t *testing.T) {
|
|||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
underTest.SendVerificationMail(userId, email)
|
underTest.SendVerificationMail(context.Background(), userId, email)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -182,16 +182,16 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s
|
|||||||
csrfToken := "my-verifying-token" + add
|
csrfToken := "my-verifying-token" + add
|
||||||
email := add + "mail@mail.de"
|
email := add + "mail@mail.de"
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(context.Background(), `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt"))
|
VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(context.Background(), `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(context.Background(), `
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf)
|
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -112,11 +112,11 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -138,7 +138,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -165,7 +165,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -208,7 +208,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -248,7 +248,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -296,7 +296,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -415,7 +415,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -451,10 +451,10 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
})
|
})
|
||||||
@@ -469,11 +469,11 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -548,7 +548,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -627,11 +627,11 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Contains(t, resp.Header.Get("Hx-Trigger"), "An activation link has been send to your email")
|
assert.Contains(t, resp.Header.Get("Hx-Trigger"), "An activation link has been send to your email")
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE email = ? AND email_verified = FALSE", "mail@mail.de").Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE email = ? AND email_verified = FALSE", "mail@mail.de").Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
var token string
|
var token string
|
||||||
err = db.QueryRow("SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token)
|
err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, token)
|
assert.NotEmpty(t, token)
|
||||||
})
|
})
|
||||||
@@ -644,7 +644,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -658,7 +658,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -670,11 +670,11 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
token := "my-outdated-verifying-token"
|
token := "my-outdated-verifying-token"
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO token (token, user_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify)
|
VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -688,7 +688,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -700,11 +700,11 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
token := "my-verifying-token"
|
token := "my-verifying-token"
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify)
|
VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -718,7 +718,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = TRUE", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = TRUE", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -747,11 +747,11 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -765,7 +765,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var csrfToken string
|
var csrfToken string
|
||||||
err = db.QueryRow("SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken)
|
err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil)
|
req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil)
|
||||||
@@ -785,7 +785,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, -1, cookie.MaxAge)
|
assert.Equal(t, -1, cookie.MaxAge)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
})
|
})
|
||||||
@@ -825,13 +825,13 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
|
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -871,13 +871,13 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
|
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -964,22 +964,22 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ?", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ?", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ?", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ?", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM account WHERE user_id = ?", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM account WHERE user_id = ?", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM treasure_chest WHERE user_id = ?", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM treasure_chest WHERE user_id = ?", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM \"transaction\" WHERE user_id = ?", userId).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM \"transaction\" WHERE user_id = ?", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
})
|
})
|
||||||
@@ -1040,13 +1040,13 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
|
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1069,7 +1069,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -1080,13 +1080,13 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
|
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1119,7 +1119,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -1130,13 +1130,13 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
|
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1169,7 +1169,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -1181,21 +1181,21 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userIdOther := uuid.New()
|
userIdOther := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
|
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES ("second", ?, datetime(), datetime("now", "+1 day"))`, userId)
|
VALUES ("second", ?, datetime(), datetime("now", "+1 day"))`, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec(`
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES ("other", ?, datetime(), datetime("now", "+1 day"))`, userIdOther)
|
VALUES ("other", ?, datetime(), datetime("now", "+1 day"))`, userIdOther)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1232,12 +1232,12 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt"))
|
pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt"))
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
|
|
||||||
var sessionIds []string
|
var sessionIds []string
|
||||||
sessions, err := db.Query(`SELECT session_id FROM session WHERE NOT user_id = ? ORDER BY session_id`, uuid.Nil)
|
sessions, err := db.QueryContext(ctx, `SELECT session_id FROM session WHERE NOT user_id = ? ORDER BY session_id`, uuid.Nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
for sessions.Next() {
|
for sessions.Next() {
|
||||||
var sessionId string
|
var sessionId string
|
||||||
@@ -1260,13 +1260,13 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.Exec(`
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
_, err = d.Exec(`
|
_, err = d.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId)
|
VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1288,7 +1288,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.Exec(`
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1317,7 +1317,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = d.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows)
|
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
})
|
})
|
||||||
@@ -1363,7 +1363,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.Exec(`
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1399,7 +1399,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg)
|
assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -1413,7 +1413,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.Exec(`
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1445,7 +1445,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -1456,7 +1456,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.Exec(`
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1473,7 +1473,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.NotEmpty(t, anonymousCsrfToken)
|
assert.NotEmpty(t, anonymousCsrfToken)
|
||||||
|
|
||||||
token := "password-reset-token"
|
token := "password-reset-token"
|
||||||
_, err = d.Exec(`
|
_, err = d.ExecContext(ctx, `
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset)
|
VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1494,7 +1494,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -1505,7 +1505,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.Exec(`
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1522,7 +1522,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.NotEmpty(t, anonymousCsrfToken)
|
assert.NotEmpty(t, anonymousCsrfToken)
|
||||||
|
|
||||||
token := "password-reset-token"
|
token := "password-reset-token"
|
||||||
_, err = d.Exec(`
|
_, err = d.ExecContext(ctx, `
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset)
|
VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1543,7 +1543,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -1554,12 +1554,12 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.Exec(`
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = d.Exec(`
|
_, err = d.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId)
|
VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1590,7 +1590,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var token string
|
var token string
|
||||||
err = d.QueryRow("SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token)
|
err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
formData = url.Values{
|
formData = url.Values{
|
||||||
@@ -1608,7 +1608,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
sessions, err := d.Query("SELECT session_id FROM session WHERE user_id = ?", userId)
|
sessions, err := d.QueryContext(ctx, "SELECT session_id FROM session WHERE user_id = ?", userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, sessions.Next())
|
assert.False(t, sessions.Next())
|
||||||
})
|
})
|
||||||
@@ -1623,11 +1623,11 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
|
|
||||||
_, err := d.Exec(`
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = d.Exec(`
|
_, err = d.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
|
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1643,7 +1643,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.NotEqual(t, sessionId, newSession.Value)
|
assert.NotEqual(t, sessionId, newSession.Value)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = d.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
|
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
})
|
})
|
||||||
@@ -1670,11 +1670,11 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
|
|
||||||
_, err := d.Exec(`
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = d.Exec(`
|
_, err = d.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
|
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1769,7 +1769,7 @@ func TestIntegrationAccount(t *testing.T) {
|
|||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
var id uuid.UUID
|
var id uuid.UUID
|
||||||
err = db.Get(&id, "SELECT id FROM account")
|
err = db.GetContext(ctx, &id, "SELECT id FROM account")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Update
|
// Update
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func TestTreasureChestShouldNotDeleteIfTransactionRecurringExists(t *testing.T)
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var parentId string
|
var parentId string
|
||||||
err := db.Get(&parentId, "SELECT id FROM treasure_chest")
|
err := db.GetContext(ctx, &parentId, "SELECT id FROM treasure_chest")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
formData = url.Values{
|
formData = url.Values{
|
||||||
@@ -33,7 +33,7 @@ func TestTreasureChestShouldNotDeleteIfTransactionRecurringExists(t *testing.T)
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var childId string
|
var childId string
|
||||||
err = db.Get(&childId, "SELECT id FROM treasure_chest WHERE parent_id = ?", parentId)
|
err = db.GetContext(ctx, &childId, "SELECT id FROM treasure_chest WHERE parent_id = ?", parentId)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
formData = url.Values{
|
formData = url.Values{
|
||||||
|
|||||||
Reference in New Issue
Block a user