feat: extract authentication to domain package #393

Merged
tim merged 1 commits from restructure-auth-to-domain-package into prod 2025-12-25 06:48:16 +00:00
43 changed files with 559 additions and 558 deletions

View File

@@ -3,11 +3,11 @@ dir: mocks/
outpkg: mocks outpkg: mocks
issue-845-fix: True issue-845-fix: True
packages: packages:
spend-sparrow/internal/service: spend-sparrow/internal/core:
interfaces: interfaces:
Random: Random:
Clock: Clock:
Mail: Mail:
spend-sparrow/internal/db: spend-sparrow/internal/authentication:
interfaces: interfaces:
Auth: Db:

View File

@@ -4,29 +4,31 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"log/slog" "log/slog"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/service" "spend-sparrow/internal/service"
"spend-sparrow/internal/types"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
) )
type Service interface { type Service interface {
Add(ctx context.Context, user *types.User, name string) (*Account, error) Add(ctx context.Context, user *auth_types.User, name string) (*Account, error)
UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error) UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error)
Get(ctx context.Context, user *types.User, id string) (*Account, error) Get(ctx context.Context, user *auth_types.User, id string) (*Account, error)
GetAll(ctx context.Context, user *types.User) ([]*Account, error) GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error)
Delete(ctx context.Context, user *types.User, id string) error Delete(ctx context.Context, user *auth_types.User, id string) error
} }
type ServiceImpl struct { type ServiceImpl struct {
db *sqlx.DB db *sqlx.DB
clock service.Clock clock core.Clock
random service.Random random core.Random
} }
func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Service { func NewServiceImpl(db *sqlx.DB, random core.Random, clock core.Clock) Service {
return ServiceImpl{ return ServiceImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -34,14 +36,14 @@ func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Ser
} }
} }
func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*Account, error) { func (s ServiceImpl) Add(ctx context.Context, user *auth_types.User, name string) (*Account, error) {
if user == nil { if user == nil {
return nil, types.ErrUnauthorized return nil, core.ErrUnauthorized
} }
newId, err := s.random.UUID(ctx) newId, err := s.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
err = service.ValidateString(name, "name") err = service.ValidateString(name, "name")
@@ -76,9 +78,9 @@ func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*A
return account, nil return account, nil
} }
func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error) { func (s ServiceImpl) UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error) {
if user == nil { if user == nil {
return nil, types.ErrUnauthorized return nil, core.ErrUnauthorized
} }
err := service.ValidateString(name, "name") err := service.ValidateString(name, "name")
if err != nil { if err != nil {
@@ -87,7 +89,7 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "account update", "err", err) slog.ErrorContext(ctx, "account update", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -103,10 +105,10 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError(ctx, "account Update", nil, err) err = db.TransformAndLogDbError(ctx, "account Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", id, service.ErrBadRequest) return nil, fmt.Errorf("account %v not found: %w", id, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
timestamp := s.clock.Now() timestamp := s.clock.Now()
@@ -136,14 +138,14 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
return &account, nil return &account, nil
} }
func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Account, error) { func (s ServiceImpl) Get(ctx context.Context, user *auth_types.User, id string) (*Account, error) {
if user == nil { if user == nil {
return nil, service.ErrUnauthorized return nil, core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "account get", "err", err) slog.ErrorContext(ctx, "account get", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
var account Account var account Account
@@ -158,9 +160,9 @@ func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Acc
return &account, nil return &account, nil
} }
func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account, error) { func (s ServiceImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error) {
if user == nil { if user == nil {
return nil, service.ErrUnauthorized return nil, core.ErrUnauthorized
} }
accounts := make([]*Account, 0) accounts := make([]*Account, 0)
@@ -174,14 +176,14 @@ func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account,
return accounts, nil return accounts, nil
} }
func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) error { func (s ServiceImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
if user == nil { if user == nil {
return service.ErrUnauthorized return core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "account delete", "err", err) slog.ErrorContext(ctx, "account delete", "err", err)
return fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -200,7 +202,7 @@ func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) er
return err return err
} }
if transactionsCount > 0 { if transactionsCount > 0 {
return fmt.Errorf("account has transactions, cannot delete: %w", service.ErrBadRequest) return fmt.Errorf("account has transactions, cannot delete: %w", core.ErrBadRequest)
} }
res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id) res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)

View File

@@ -1,4 +1,4 @@
package types package auth_types
import ( import (
"time" "time"

View File

@@ -1,11 +1,12 @@
package db package authentication
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"log/slog" "log/slog"
"spend-sparrow/internal/types" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"strings" "strings"
"time" "time"
@@ -13,36 +14,36 @@ import (
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
) )
type Auth interface { type Db interface {
InsertUser(ctx context.Context, user *types.User) error InsertUser(ctx context.Context, user *auth_types.User) error
UpdateUser(ctx context.Context, user *types.User) error UpdateUser(ctx context.Context, user *auth_types.User) error
GetUserByEmail(ctx context.Context, email string) (*types.User, error) GetUserByEmail(ctx context.Context, email string) (*auth_types.User, error)
GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) GetUser(ctx context.Context, userId uuid.UUID) (*auth_types.User, error)
DeleteUser(ctx context.Context, userId uuid.UUID) error DeleteUser(ctx context.Context, userId uuid.UUID) error
InsertToken(ctx context.Context, token *types.Token) error InsertToken(ctx context.Context, token *auth_types.Token) error
GetToken(ctx context.Context, token string) (*types.Token, error) GetToken(ctx context.Context, token string) (*auth_types.Token, error)
GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType auth_types.TokenType) ([]*auth_types.Token, error)
GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error)
DeleteToken(ctx context.Context, token string) error DeleteToken(ctx context.Context, token string) error
InsertSession(ctx context.Context, session *types.Session) error InsertSession(ctx context.Context, session *auth_types.Session) error
GetSession(ctx context.Context, sessionId string) (*types.Session, error) GetSession(ctx context.Context, sessionId string) (*auth_types.Session, error)
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) GetSessions(ctx context.Context, userId uuid.UUID) ([]*auth_types.Session, error)
DeleteSession(ctx context.Context, sessionId string) error DeleteSession(ctx context.Context, sessionId string) error
DeleteOldSessions(ctx context.Context) error DeleteOldSessions(ctx context.Context) error
DeleteOldTokens(ctx context.Context) error DeleteOldTokens(ctx context.Context) error
} }
type AuthSqlite struct { type DbSqlite struct {
db *sqlx.DB db *sqlx.DB
} }
func NewAuthSqlite(db *sqlx.DB) *AuthSqlite { func NewDbSqlite(db *sqlx.DB) *DbSqlite {
return &AuthSqlite{db: db} return &DbSqlite{db: db}
} }
func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error { func (db DbSqlite) InsertUser(ctx context.Context, user *auth_types.User) error {
_, err := db.db.ExecContext(ctx, ` _, err := db.db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
@@ -50,17 +51,17 @@ func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error {
if err != nil { if err != nil {
if strings.Contains(err.Error(), "email") { if strings.Contains(err.Error(), "email") {
return ErrAlreadyExists return core.ErrAlreadyExists
} }
slog.ErrorContext(ctx, "SQL error InsertUser", "err", err) slog.ErrorContext(ctx, "SQL error InsertUser", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error { func (db DbSqlite) UpdateUser(ctx context.Context, user *auth_types.User) error {
_, err := db.db.ExecContext(ctx, ` _, err := db.db.ExecContext(ctx, `
UPDATE user UPDATE user
SET email_verified = ?, email_verified_at = ?, password = ? SET email_verified = ?, email_verified_at = ?, password = ?
@@ -69,13 +70,13 @@ func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error {
if err != nil { if err != nil {
slog.ErrorContext(ctx, "SQL error UpdateUser", "err", err) slog.ErrorContext(ctx, "SQL error UpdateUser", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.User, error) { func (db DbSqlite) GetUserByEmail(ctx context.Context, email string) (*auth_types.User, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
emailVerified bool emailVerified bool
@@ -92,17 +93,17 @@ func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.U
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, core.ErrNotFound
} else { } else {
slog.ErrorContext(ctx, "SQL error GetUser", "err", err) slog.ErrorContext(ctx, "SQL error GetUser", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
} }
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil return auth_types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
} }
func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) { func (db DbSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*auth_types.User, error) {
var ( var (
email string email string
emailVerified bool emailVerified bool
@@ -119,92 +120,92 @@ func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, core.ErrNotFound
} else { } else {
slog.ErrorContext(ctx, "SQL error GetUser", "err", err) slog.ErrorContext(ctx, "SQL error GetUser", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
} }
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil return auth_types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
} }
func (db AuthSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error { func (db DbSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error {
tx, err := db.db.BeginTx(ctx, nil) tx, err := db.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not start transaction", "err", err) slog.ErrorContext(ctx, "Could not start transaction", "err", err)
return types.ErrInternal return core.ErrInternal
} }
_, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.ErrorContext(ctx, "Could not delete accounts", "err", err) slog.ErrorContext(ctx, "Could not delete accounts", "err", err)
return types.ErrInternal return core.ErrInternal
} }
_, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.ErrorContext(ctx, "Could not delete user tokens", "err", err) slog.ErrorContext(ctx, "Could not delete user tokens", "err", err)
return types.ErrInternal return core.ErrInternal
} }
_, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.ErrorContext(ctx, "Could not delete sessions", "err", err) slog.ErrorContext(ctx, "Could not delete sessions", "err", err)
return types.ErrInternal return core.ErrInternal
} }
_, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.ErrorContext(ctx, "Could not delete user", "err", err) slog.ErrorContext(ctx, "Could not delete user", "err", err)
return types.ErrInternal return core.ErrInternal
} }
_, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.ErrorContext(ctx, "Could not delete user", "err", err) slog.ErrorContext(ctx, "Could not delete user", "err", err)
return types.ErrInternal return core.ErrInternal
} }
_, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId) _, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
slog.ErrorContext(ctx, "Could not delete user", "err", err) slog.ErrorContext(ctx, "Could not delete user", "err", err)
return types.ErrInternal return core.ErrInternal
} }
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not commit transaction", "err", err) slog.ErrorContext(ctx, "Could not commit transaction", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (db AuthSqlite) InsertToken(ctx context.Context, token *types.Token) error { func (db DbSqlite) InsertToken(ctx context.Context, token *auth_types.Token) error {
_, err := db.db.ExecContext(ctx, ` _, err := db.db.ExecContext(ctx, `
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at) INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt) VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not insert token", "err", err) slog.ErrorContext(ctx, "Could not insert token", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, error) { func (db DbSqlite) GetToken(ctx context.Context, token string) (*auth_types.Token, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
sessionId string sessionId string
tokenType types.TokenType tokenType auth_types.TokenType
createdAtStr string createdAtStr string
expiresAtStr string expiresAtStr string
createdAt time.Time createdAt time.Time
@@ -219,29 +220,29 @@ func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token,
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
slog.InfoContext(ctx, "Token not found", "token", token) slog.InfoContext(ctx, "Token not found", "token", token)
return nil, ErrNotFound return nil, core.ErrNotFound
} else { } else {
slog.ErrorContext(ctx, "Could not get token", "err", err) slog.ErrorContext(ctx, "Could not get token", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
} }
createdAt, err = time.Parse(time.RFC3339, createdAtStr) createdAt, err = time.Parse(time.RFC3339, createdAtStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err) slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
expiresAt, err = time.Parse(time.RFC3339, expiresAtStr) expiresAt, err = time.Parse(time.RFC3339, expiresAtStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err) slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil return auth_types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
} }
func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) { func (db DbSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType auth_types.TokenType) ([]*auth_types.Token, error) {
query, err := db.db.QueryContext(ctx, ` query, err := db.db.QueryContext(ctx, `
SELECT token, created_at, expires_at SELECT token, created_at, expires_at
FROM token FROM token
@@ -250,13 +251,13 @@ func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.U
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not get token", "err", err) slog.ErrorContext(ctx, "Could not get token", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return getTokensFromQuery(ctx, query, userId, "", tokenType) return getTokensFromQuery(ctx, query, userId, "", tokenType)
} }
func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) { func (db DbSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error) {
query, err := db.db.QueryContext(ctx, ` query, err := db.db.QueryContext(ctx, `
SELECT token, created_at, expires_at SELECT token, created_at, expires_at
FROM token FROM token
@@ -265,14 +266,14 @@ func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not get token", "err", err) slog.ErrorContext(ctx, "Could not get token", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return getTokensFromQuery(ctx, query, uuid.Nil, sessionId, tokenType) return getTokensFromQuery(ctx, query, uuid.Nil, sessionId, tokenType)
} }
func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID, sessionId string, tokenType types.TokenType) ([]*types.Token, error) { func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error) {
var tokens []*types.Token var tokens []*auth_types.Token
hasRows := false hasRows := false
for query.Next() { for query.Next() {
@@ -289,54 +290,54 @@ func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID,
err := query.Scan(&token, &createdAtStr, &expiresAtStr) err := query.Scan(&token, &createdAtStr, &expiresAtStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not scan token", "err", err) slog.ErrorContext(ctx, "Could not scan token", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
createdAt, err = time.Parse(time.RFC3339, createdAtStr) createdAt, err = time.Parse(time.RFC3339, createdAtStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err) slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
expiresAt, err = time.Parse(time.RFC3339, expiresAtStr) expiresAt, err = time.Parse(time.RFC3339, expiresAtStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err) slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
tokens = append(tokens, types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt)) tokens = append(tokens, auth_types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt))
} }
if !hasRows { if !hasRows {
return nil, ErrNotFound return nil, core.ErrNotFound
} }
return tokens, nil return tokens, nil
} }
func (db AuthSqlite) DeleteToken(ctx context.Context, token string) error { func (db DbSqlite) DeleteToken(ctx context.Context, token string) error {
_, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token) _, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not delete token", "err", err) slog.ErrorContext(ctx, "Could not delete token", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (db AuthSqlite) InsertSession(ctx context.Context, session *types.Session) error { func (db DbSqlite) InsertSession(ctx context.Context, session *auth_types.Session) error {
_, err := db.db.ExecContext(ctx, ` _, err := db.db.ExecContext(ctx, `
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt) VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not insert new session", "err", err) slog.ErrorContext(ctx, "Could not insert new session", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.Session, error) { func (db DbSqlite) GetSession(ctx context.Context, sessionId string) (*auth_types.Session, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
createdAt time.Time createdAt time.Time
@@ -350,56 +351,56 @@ func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.S
if err != nil { if err != nil {
slog.WarnContext(ctx, "Session not found", "session-id", sessionId, "err", err) slog.WarnContext(ctx, "Session not found", "session-id", sessionId, "err", err)
return nil, ErrNotFound return nil, core.ErrNotFound
} }
return types.NewSession(sessionId, userId, createdAt, expiresAt), nil return auth_types.NewSession(sessionId, userId, createdAt, expiresAt), nil
} }
func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) { func (db DbSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*auth_types.Session, error) {
var sessions []*types.Session var sessions []*auth_types.Session
err := db.db.SelectContext(ctx, &sessions, ` err := db.db.SelectContext(ctx, &sessions, `
SELECT * SELECT *
FROM session FROM session
WHERE user_id = ?`, userId) WHERE user_id = ?`, userId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not get sessions", "err", err) slog.ErrorContext(ctx, "Could not get sessions", "err", err)
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return sessions, nil return sessions, nil
} }
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error { func (db DbSqlite) DeleteSession(ctx context.Context, sessionId string) error {
if sessionId != "" { if sessionId != "" {
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId) _, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not delete session", "err", err) slog.ErrorContext(ctx, "Could not delete session", "err", err)
return types.ErrInternal return core.ErrInternal
} }
} }
return nil return nil
} }
func (db AuthSqlite) DeleteOldSessions(ctx context.Context) error { func (db DbSqlite) DeleteOldSessions(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, ` _, err := db.db.ExecContext(ctx, `
DELETE FROM session DELETE FROM session
WHERE expires_at < datetime('now')`) WHERE expires_at < datetime('now')`)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not delete old sessions", "err", err) slog.ErrorContext(ctx, "Could not delete old sessions", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (db AuthSqlite) DeleteOldTokens(ctx context.Context) error { func (db DbSqlite) DeleteOldTokens(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, ` _, err := db.db.ExecContext(ctx, `
DELETE FROM token DELETE FROM token
WHERE expires_at < datetime('now')`) WHERE expires_at < datetime('now')`)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not delete old tokens", "err", err) slog.ErrorContext(ctx, "Could not delete old tokens", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }

View File

@@ -1,36 +1,34 @@
package handler package authentication
import ( import (
"errors" "errors"
"log/slog" "log/slog"
"net/http" "net/http"
"net/url" "net/url"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/authentication/template"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/handler/middleware"
"spend-sparrow/internal/service"
"spend-sparrow/internal/template/auth"
"spend-sparrow/internal/types"
"spend-sparrow/internal/utils" "spend-sparrow/internal/utils"
"time" "time"
) )
type Auth interface { type Handler interface {
Handle(router *http.ServeMux) Handle(router *http.ServeMux)
} }
type AuthImpl struct { type HandlerImpl struct {
service service.Auth service Service
render *core.Render render *core.Render
} }
func NewAuth(service service.Auth, render *core.Render) Auth { func NewHandler(service Service, render *core.Render) Handler {
return AuthImpl{ return HandlerImpl{
service: service, service: service,
render: render, render: render,
} }
} }
func (handler AuthImpl) Handle(router *http.ServeMux) { func (handler HandlerImpl) Handle(router *http.ServeMux) {
router.Handle("GET /auth/signin", handler.handleSignInPage()) router.Handle("GET /auth/signin", handler.handleSignInPage())
router.Handle("POST /api/auth/signin", handler.handleSignIn()) router.Handle("POST /api/auth/signin", handler.handleSignIn())
@@ -57,7 +55,7 @@ var (
securityWaitDuration = 250 * time.Millisecond securityWaitDuration = 250 * time.Millisecond
) )
func (handler AuthImpl) handleSignInPage() http.HandlerFunc { func (handler HandlerImpl) handleSignInPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -71,17 +69,17 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
return return
} }
comp := auth.SignInOrUpComp(true) comp := template.SignInOrUpComp(true)
handler.render.RenderLayout(r, w, comp, nil) handler.render.RenderLayout(r, w, comp, nil)
} }
} }
func (handler AuthImpl) handleSignIn() http.HandlerFunc { func (handler HandlerImpl) handleSignIn() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) { user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*auth_types.User, error) {
session := core.GetSession(r) session := core.GetSession(r)
email := r.FormValue("email") email := r.FormValue("email")
password := r.FormValue("password") password := r.FormValue("password")
@@ -91,14 +89,14 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
return nil, err return nil, err
} }
cookie := middleware.CreateSessionCookie(session.Id) cookie := core.CreateSessionCookie(session.Id)
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
return user, nil return user, nil
}) })
if err != nil { if err != nil {
if errors.Is(err, service.ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) {
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Invalid email or password", http.StatusUnauthorized) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Invalid email or password", http.StatusUnauthorized)
} else { } else {
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
@@ -114,7 +112,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
} }
} }
func (handler AuthImpl) handleSignUpPage() http.HandlerFunc { func (handler HandlerImpl) handleSignUpPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -129,12 +127,12 @@ func (handler AuthImpl) handleSignUpPage() http.HandlerFunc {
return return
} }
signUpComp := auth.SignInOrUpComp(false) signUpComp := template.SignInOrUpComp(false)
handler.render.RenderLayout(r, w, signUpComp, nil) handler.render.RenderLayout(r, w, signUpComp, nil)
} }
} }
func (handler AuthImpl) handleSignUpVerifyPage() http.HandlerFunc { func (handler HandlerImpl) handleSignUpVerifyPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -149,12 +147,12 @@ func (handler AuthImpl) handleSignUpVerifyPage() http.HandlerFunc {
return return
} }
signIn := auth.VerifyComp() signIn := template.VerifyComp()
handler.render.RenderLayout(r, w, signIn, user) handler.render.RenderLayout(r, w, signIn, user)
} }
} }
func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc { func (handler HandlerImpl) handleVerifyResendComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -173,7 +171,7 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
} }
} }
func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc { func (handler HandlerImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -182,7 +180,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
err := handler.service.VerifyUserEmail(r.Context(), token) err := handler.service.VerifyUserEmail(r.Context(), token)
isVerified := err == nil isVerified := err == nil
comp := auth.VerifyResponseComp(isVerified) comp := template.VerifyResponseComp(isVerified)
var status int var status int
if isVerified { if isVerified {
@@ -195,7 +193,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
} }
} }
func (handler AuthImpl) handleSignUp() http.HandlerFunc { func (handler HandlerImpl) handleSignUp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -216,14 +214,14 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, types.ErrInternal): case errors.Is(err, core.ErrInternal):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
return return
case errors.Is(err, service.ErrInvalidEmail): case errors.Is(err, ErrInvalidEmail):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "The email provided is invalid", http.StatusBadRequest) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "The email provided is invalid", http.StatusBadRequest)
return return
case errors.Is(err, service.ErrInvalidPassword): case errors.Is(err, ErrInvalidPassword):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest) utils.TriggerToastWithStatus(r.Context(), w, r, "error", ErrInvalidPassword.Error(), http.StatusBadRequest)
return return
} }
// If err is "service.ErrAccountExists", then just continue // If err is "service.ErrAccountExists", then just continue
@@ -233,7 +231,7 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
} }
} }
func (handler AuthImpl) handleSignOut() http.HandlerFunc { func (handler HandlerImpl) handleSignOut() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -262,7 +260,7 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc {
} }
} }
func (handler AuthImpl) handleDeleteAccountPage() http.HandlerFunc { func (handler HandlerImpl) handleDeleteAccountPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -272,12 +270,12 @@ func (handler AuthImpl) handleDeleteAccountPage() http.HandlerFunc {
return return
} }
comp := auth.DeleteAccountComp() comp := template.DeleteAccountComp()
handler.render.RenderLayout(r, w, comp, user) handler.render.RenderLayout(r, w, comp, user)
} }
} }
func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { func (handler HandlerImpl) handleDeleteAccountComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -291,7 +289,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
err := handler.service.DeleteAccount(r.Context(), user, password) err := handler.service.DeleteAccount(r.Context(), user, password)
if err != nil { if err != nil {
if errors.Is(err, service.ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) {
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Password not correct", http.StatusBadRequest) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Password not correct", http.StatusBadRequest)
} else { } else {
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Internal Server Error", http.StatusInternalServerError) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Internal Server Error", http.StatusInternalServerError)
@@ -303,7 +301,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
} }
} }
func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc { func (handler HandlerImpl) handleChangePasswordPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -316,12 +314,12 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
return return
} }
comp := auth.ChangePasswordComp(isPasswordReset) comp := template.ChangePasswordComp(isPasswordReset)
handler.render.RenderLayout(r, w, comp, user) handler.render.RenderLayout(r, w, comp, user)
} }
} }
func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { func (handler HandlerImpl) handleChangePasswordComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -345,7 +343,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
} }
} }
func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc { func (handler HandlerImpl) handleForgotPasswordPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -355,12 +353,12 @@ func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
return return
} }
comp := auth.ResetPasswordComp() comp := template.ResetPasswordComp()
handler.render.RenderLayout(r, w, comp, user) handler.render.RenderLayout(r, w, comp, user)
} }
} }
func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { func (handler HandlerImpl) handleForgotPasswordComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
@@ -383,7 +381,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
} }
} }
func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { func (handler HandlerImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)

View File

@@ -1,4 +1,4 @@
package service package authentication
import ( import (
"context" "context"
@@ -6,7 +6,8 @@ import (
"errors" "errors"
"log/slog" "log/slog"
"net/mail" "net/mail"
"spend-sparrow/internal/db" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
mailTemplate "spend-sparrow/internal/template/mail" mailTemplate "spend-sparrow/internal/template/mail"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"strings" "strings"
@@ -25,39 +26,39 @@ var (
ErrTokenInvalid = errors.New("token is invalid") ErrTokenInvalid = errors.New("token is invalid")
) )
type Auth interface { type Service interface {
SignUp(ctx context.Context, email string, password string) (*types.User, error) SignUp(ctx context.Context, email string, password string) (*auth_types.User, error)
SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string)
VerifyUserEmail(ctx context.Context, token string) error VerifyUserEmail(ctx context.Context, token string) error
SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_types.Session, *auth_types.User, error)
SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error)
SignInAnonymous(ctx context.Context) (*types.Session, error) SignInAnonymous(ctx context.Context) (*auth_types.Session, error)
SignOut(ctx context.Context, sessionId string) error SignOut(ctx context.Context, sessionId string) error
DeleteAccount(ctx context.Context, user *types.User, currPass string) error DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error
ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error
SendForgotPasswordMail(ctx context.Context, email string) error SendForgotPasswordMail(ctx context.Context, email string) error
ForgotPassword(ctx context.Context, token string, newPass string) error ForgotPassword(ctx context.Context, token string, newPass string) error
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
GetCsrfToken(ctx context.Context, session *types.Session) (string, error) GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error)
CleanupSessionsAndTokens(ctx context.Context) error CleanupSessionsAndTokens(ctx context.Context) error
} }
type AuthImpl struct { type ServiceImpl struct {
db db.Auth db Db
random Random random core.Random
clock Clock clock core.Clock
mail Mail mail core.Mail
serverSettings *types.Settings serverSettings *types.Settings
} }
func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *types.Settings) *AuthImpl { func NewService(db Db, random core.Random, clock core.Clock, mail core.Mail, serverSettings *types.Settings) *ServiceImpl {
return &AuthImpl{ return &ServiceImpl{
db: db, db: db,
random: random, random: random,
clock: clock, clock: clock,
@@ -66,13 +67,13 @@ func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *
} }
} }
func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) { func (service ServiceImpl) SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_types.Session, *auth_types.User, error) {
user, err := service.db.GetUserByEmail(ctx, email) user, err := service.db.GetUserByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, nil, ErrInvalidCredentials return nil, nil, ErrInvalidCredentials
} else { } else {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
} }
@@ -84,36 +85,36 @@ func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, emai
newSession, err := service.createSession(ctx, user.Id) newSession, err := service.createSession(ctx, user.Id)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
err = service.db.DeleteSession(ctx, session.Id) err = service.db.DeleteSession(ctx, session.Id)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
for _, token := range tokens { for _, token := range tokens {
err = service.db.DeleteToken(ctx, token.Token) err = service.db.DeleteToken(ctx, token.Token)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
} }
return newSession, user, nil return newSession, user, nil
} }
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) { func (service ServiceImpl) SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error) {
if sessionId == "" { if sessionId == "" {
return nil, nil, ErrSessionIdInvalid return nil, nil, ErrSessionIdInvalid
} }
session, err := service.db.GetSession(ctx, sessionId) session, err := service.db.GetSession(ctx, sessionId)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
if session.ExpiresAt.Before(service.clock.Now()) { if session.ExpiresAt.Before(service.clock.Now()) {
_ = service.db.DeleteSession(ctx, sessionId) _ = service.db.DeleteSession(ctx, sessionId)
@@ -126,16 +127,16 @@ func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*t
user, err := service.db.GetUser(ctx, session.UserId) user, err := service.db.GetUser(ctx, session.UserId)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
return session, user, nil return session, user, nil
} }
func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, error) { func (service ServiceImpl) SignInAnonymous(ctx context.Context) (*auth_types.Session, error) {
session, err := service.createSession(ctx, uuid.Nil) session, err := service.createSession(ctx, uuid.Nil)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
slog.InfoContext(ctx, "anonymous session created", "session-id", session.Id) slog.InfoContext(ctx, "anonymous session created", "session-id", session.Id)
@@ -143,7 +144,7 @@ func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, er
return session, nil return session, nil
} }
func (service AuthImpl) SignUp(ctx context.Context, email string, password string) (*types.User, error) { func (service ServiceImpl) SignUp(ctx context.Context, email string, password string) (*auth_types.User, error) {
_, err := mail.ParseAddress(email) _, err := mail.ParseAddress(email)
if err != nil { if err != nil {
return nil, ErrInvalidEmail return nil, ErrInvalidEmail
@@ -155,37 +156,37 @@ func (service AuthImpl) SignUp(ctx context.Context, email string, password strin
userId, err := service.random.UUID(ctx) userId, err := service.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
salt, err := service.random.Bytes(ctx, 16) salt, err := service.random.Bytes(ctx, 16)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
hash := GetHashPassword(password, salt) hash := GetHashPassword(password, salt)
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) user := auth_types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
err = service.db.InsertUser(ctx, user) err = service.db.InsertUser(ctx, user)
if err != nil { if err != nil {
if errors.Is(err, db.ErrAlreadyExists) { if errors.Is(err, core.ErrAlreadyExists) {
return nil, ErrAccountExists return nil, ErrAccountExists
} else { } else {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
} }
return user, nil return user, nil
} }
func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) { func (service ServiceImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify) tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, auth_types.TokenTypeEmailVerify)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return return
} }
var token *types.Token var token *auth_types.Token
if len(tokens) > 0 { if len(tokens) > 0 {
token = tokens[0] token = tokens[0]
@@ -197,11 +198,11 @@ func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UU
return return
} }
token = types.NewToken( token = auth_types.NewToken(
userId, userId,
"", "",
newTokenStr, newTokenStr,
types.TokenTypeEmailVerify, auth_types.TokenTypeEmailVerify,
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(24*time.Hour)) service.clock.Now().Add(24*time.Hour))
@@ -221,29 +222,29 @@ func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UU
service.mail.SendMail(ctx, email, "Welcome to spend-sparrow", w.String()) service.mail.SendMail(ctx, email, "Welcome to spend-sparrow", w.String())
} }
func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error { func (service ServiceImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
if tokenStr == "" { if tokenStr == "" {
return types.ErrInternal return core.ErrInternal
} }
token, err := service.db.GetToken(ctx, tokenStr) token, err := service.db.GetToken(ctx, tokenStr)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
user, err := service.db.GetUser(ctx, token.UserId) user, err := service.db.GetUser(ctx, token.UserId)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
if token.Type != types.TokenTypeEmailVerify { if token.Type != auth_types.TokenTypeEmailVerify {
return types.ErrInternal return core.ErrInternal
} }
now := service.clock.Now() now := service.clock.Now()
if token.ExpiresAt.Before(now) { if token.ExpiresAt.Before(now) {
return types.ErrInternal return core.ErrInternal
} }
user.EmailVerified = true user.EmailVerified = true
@@ -251,21 +252,21 @@ func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) er
err = service.db.UpdateUser(ctx, user) err = service.db.UpdateUser(ctx, user)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
_ = service.db.DeleteToken(ctx, token.Token) _ = service.db.DeleteToken(ctx, token.Token)
return nil return nil
} }
func (service AuthImpl) SignOut(ctx context.Context, sessionId string) error { func (service ServiceImpl) SignOut(ctx context.Context, sessionId string) error {
return service.db.DeleteSession(ctx, sessionId) return service.db.DeleteSession(ctx, sessionId)
} }
func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, currPass string) error { func (service ServiceImpl) DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error {
userDb, err := service.db.GetUser(ctx, user.Id) userDb, err := service.db.GetUser(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
currHash := GetHashPassword(currPass, userDb.Salt) currHash := GetHashPassword(currPass, userDb.Salt)
@@ -283,7 +284,7 @@ func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, cur
return nil return nil
} }
func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error { func (service ServiceImpl) ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
} }
@@ -308,13 +309,13 @@ func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, se
sessions, err := service.db.GetSessions(ctx, user.Id) sessions, err := service.db.GetSessions(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
for _, s := range sessions { for _, s := range sessions {
if s.Id != sessionId { if s.Id != sessionId {
err = service.db.DeleteSession(ctx, s.Id) err = service.db.DeleteSession(ctx, s.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
} }
} }
@@ -322,7 +323,7 @@ func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, se
return nil return nil
} }
func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string) error { func (service ServiceImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
tokenStr, err := service.random.String(ctx, 32) tokenStr, err := service.random.String(ctx, 32)
if err != nil { if err != nil {
return err return err
@@ -330,38 +331,38 @@ func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string
user, err := service.db.GetUserByEmail(ctx, email) user, err := service.db.GetUserByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil return nil
} else { } else {
return types.ErrInternal return core.ErrInternal
} }
} }
token := types.NewToken( token := auth_types.NewToken(
user.Id, user.Id,
"", "",
tokenStr, tokenStr,
types.TokenTypePasswordReset, auth_types.TokenTypePasswordReset,
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(15*time.Minute)) service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(ctx, token) err = service.db.InsertToken(ctx, token)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
var mail strings.Builder var mail strings.Builder
err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail) err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not render reset password email", "err", err) slog.ErrorContext(ctx, "Could not render reset password email", "err", err)
return types.ErrInternal return core.ErrInternal
} }
service.mail.SendMail(ctx, email, "Reset Password", mail.String()) service.mail.SendMail(ctx, email, "Reset Password", mail.String())
return nil return nil
} }
func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error { func (service ServiceImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
} }
@@ -376,7 +377,7 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new
return err return err
} }
if token.Type != types.TokenTypePasswordReset || if token.Type != auth_types.TokenTypePasswordReset ||
token.ExpiresAt.Before(service.clock.Now()) { token.ExpiresAt.Before(service.clock.Now()) {
return ErrTokenInvalid return ErrTokenInvalid
} }
@@ -384,7 +385,7 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new
user, err := service.db.GetUser(ctx, token.UserId) user, err := service.db.GetUser(ctx, token.UserId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not get user from token", "err", err) slog.ErrorContext(ctx, "Could not get user from token", "err", err)
return types.ErrInternal return core.ErrInternal
} }
passHash := GetHashPassword(newPass, user.Salt) passHash := GetHashPassword(newPass, user.Salt)
@@ -397,26 +398,26 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new
sessions, err := service.db.GetSessions(ctx, user.Id) sessions, err := service.db.GetSessions(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
for _, session := range sessions { for _, session := range sessions {
err = service.db.DeleteSession(ctx, session.Id) err = service.db.DeleteSession(ctx, session.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
} }
return nil return nil
} }
func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool { func (service ServiceImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
token, err := service.db.GetToken(ctx, tokenStr) token, err := service.db.GetToken(ctx, tokenStr)
if err != nil { if err != nil {
return false return false
} }
if token.Type != types.TokenTypeCsrf || if token.Type != auth_types.TokenTypeCsrf ||
token.SessionId != sessionId || token.SessionId != sessionId ||
token.ExpiresAt.Before(service.clock.Now()) { token.ExpiresAt.Before(service.clock.Now()) {
return false return false
@@ -425,12 +426,12 @@ func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, s
return true return true
} }
func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session) (string, error) { func (service ServiceImpl) GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error) {
if session == nil { if session == nil {
return "", types.ErrInternal return "", core.ErrInternal
} }
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf)
if len(tokens) > 0 { if len(tokens) > 0 {
return tokens[0].Token, nil return tokens[0].Token, nil
@@ -438,19 +439,19 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session
tokenStr, err := service.random.String(ctx, 32) tokenStr, err := service.random.String(ctx, 32)
if err != nil { if err != nil {
return "", types.ErrInternal return "", core.ErrInternal
} }
token := types.NewToken( token := auth_types.NewToken(
session.UserId, session.UserId,
session.Id, session.Id,
tokenStr, tokenStr,
types.TokenTypeCsrf, auth_types.TokenTypeCsrf,
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(8*time.Hour)) service.clock.Now().Add(8*time.Hour))
err = service.db.InsertToken(ctx, token) err = service.db.InsertToken(ctx, token)
if err != nil { if err != nil {
return "", types.ErrInternal return "", core.ErrInternal
} }
slog.InfoContext(ctx, "CSRF-Token created", "token", tokenStr) slog.InfoContext(ctx, "CSRF-Token created", "token", tokenStr)
@@ -458,34 +459,34 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session
return tokenStr, nil return tokenStr, nil
} }
func (service AuthImpl) CleanupSessionsAndTokens(ctx context.Context) error { func (service ServiceImpl) CleanupSessionsAndTokens(ctx context.Context) error {
err := service.db.DeleteOldSessions(ctx) err := service.db.DeleteOldSessions(ctx)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
err = service.db.DeleteOldTokens(ctx) err = service.db.DeleteOldTokens(ctx)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*types.Session, error) { func (service ServiceImpl) createSession(ctx context.Context, userId uuid.UUID) (*auth_types.Session, error) {
sessionId, err := service.random.String(ctx, 32) sessionId, err := service.random.String(ctx, 32)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
createAt := service.clock.Now() createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
session := types.NewSession(sessionId, userId, createAt, expiresAt) session := auth_types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(ctx, session) err = service.db.InsertSession(ctx, session)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return session, nil return session, nil

View File

@@ -1,4 +1,4 @@
package auth package template
templ ChangePasswordComp(isPasswordReset bool) { templ ChangePasswordComp(isPasswordReset bool) {
<form <form

View File

@@ -0,0 +1 @@
package template

View File

@@ -1,4 +1,4 @@
package auth package template
templ DeleteAccountComp() { templ DeleteAccountComp() {
<form <form

View File

@@ -1,4 +1,4 @@
package auth package template
templ ResetPasswordComp() { templ ResetPasswordComp() {
<form <form

View File

@@ -1,13 +1,13 @@
package auth package template
templ SignInOrUpComp(isSignIn bool) { templ SignInOrUpComp(isSignIn bool) {
{{ {{
var postUrl string var postUrl string
if isSignIn { if isSignIn {
postUrl = "/api/auth/signin" postUrl = "/api/auth/signin"
} else { } else {
postUrl = "/api/auth/signup" postUrl = "/api/auth/signup"
} }
}} }}
<form <form
class="max-w-xl px-2 mx-auto flex flex-col gap-4 h-full justify-center" class="max-w-xl px-2 mx-auto flex flex-col gap-4 h-full justify-center"

View File

@@ -1,4 +1,4 @@
package auth package template
templ VerifyComp() { templ VerifyComp() {
<main class="h-full"> <main class="h-full">

View File

@@ -1,4 +1,4 @@
package auth package template
templ VerifyResponseComp(isVerified bool) { templ VerifyResponseComp(isVerified bool) {
<main> <main>

View File

@@ -2,7 +2,7 @@ package core
import ( import (
"net/http" "net/http"
"spend-sparrow/internal/types" "spend-sparrow/internal/auth_types"
) )
type ContextKey string type ContextKey string
@@ -10,13 +10,13 @@ type ContextKey string
var SessionKey ContextKey = "session" var SessionKey ContextKey = "session"
var UserKey ContextKey = "user" var UserKey ContextKey = "user"
func GetUser(r *http.Request) *types.User { func GetUser(r *http.Request) *auth_types.User {
obj := r.Context().Value(UserKey) obj := r.Context().Value(UserKey)
if obj == nil { if obj == nil {
return nil return nil
} }
user, ok := obj.(*types.User) user, ok := obj.(*auth_types.User)
if !ok { if !ok {
return nil return nil
} }
@@ -24,16 +24,28 @@ func GetUser(r *http.Request) *types.User {
return user return user
} }
func GetSession(r *http.Request) *types.Session { func GetSession(r *http.Request) *auth_types.Session {
obj := r.Context().Value(SessionKey) obj := r.Context().Value(SessionKey)
if obj == nil { if obj == nil {
return nil return nil
} }
session, ok := obj.(*types.Session) session, ok := obj.(*auth_types.Session)
if !ok { if !ok {
return nil return nil
} }
return session return session
} }
func CreateSessionCookie(sessionId string) http.Cookie {
return http.Cookie{
Name: "id",
Value: sessionId,
MaxAge: 60 * 60 * 8, // 8 hours
Secure: true,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
}
}

View File

@@ -1,4 +1,4 @@
package service package core
import "time" import "time"

View File

@@ -3,8 +3,6 @@ package core
import ( import (
"errors" "errors"
"net/http" "net/http"
"spend-sparrow/internal/db"
"spend-sparrow/internal/service"
"spend-sparrow/internal/utils" "spend-sparrow/internal/utils"
"strings" "strings"
@@ -14,13 +12,13 @@ import (
func HandleError(w http.ResponseWriter, r *http.Request, err error) { func HandleError(w http.ResponseWriter, r *http.Request, err error) {
switch { switch {
case errors.Is(err, service.ErrUnauthorized): case errors.Is(err, ErrUnauthorized):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized)
return return
case errors.Is(err, service.ErrBadRequest): case errors.Is(err, ErrBadRequest):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusBadRequest) utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
return return
case errors.Is(err, db.ErrNotFound): case errors.Is(err, ErrNotFound):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusNotFound) utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusNotFound)
return return
} }

13
internal/core/error.go Normal file
View File

@@ -0,0 +1,13 @@
package core
import "errors"
var (
ErrNotFound = errors.New("the value does not exist")
ErrAlreadyExists = errors.New("row already exists")
ErrInternal = errors.New("internal server error")
ErrUnauthorized = errors.New("you are not authorized to perform this action")
ErrBadRequest = errors.New("bad request")
)

View File

@@ -1,4 +1,4 @@
package service package core
import ( import (
"context" "context"

View File

@@ -1,11 +1,10 @@
package service package core
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"log/slog" "log/slog"
"spend-sparrow/internal/types"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -28,7 +27,7 @@ func (r *RandomImpl) Bytes(ctx context.Context, tsize int) ([]byte, error) {
_, err := rand.Read(b) _, err := rand.Read(b)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Error generating random bytes", "err", err) slog.ErrorContext(ctx, "Error generating random bytes", "err", err)
return []byte{}, types.ErrInternal return []byte{}, ErrInternal
} }
return b, nil return b, nil
@@ -38,7 +37,7 @@ func (r *RandomImpl) String(ctx context.Context, size int) (string, error) {
bytes, err := r.Bytes(ctx, size) bytes, err := r.Bytes(ctx, size)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Error generating random string", "err", err) slog.ErrorContext(ctx, "Error generating random string", "err", err)
return "", types.ErrInternal return "", ErrInternal
} }
return base64.StdEncoding.EncodeToString(bytes), nil return base64.StdEncoding.EncodeToString(bytes), nil
@@ -48,7 +47,7 @@ func (r *RandomImpl) UUID(ctx context.Context) (uuid.UUID, error) {
id, err := uuid.NewRandom() id, err := uuid.NewRandom()
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Error generating random UUID", "err", err) slog.ErrorContext(ctx, "Error generating random UUID", "err", err)
return uuid.Nil, types.ErrInternal return uuid.Nil, ErrInternal
} }
return id, nil return id, nil

View File

@@ -1,10 +1,11 @@
package core package core
import ( import (
"github.com/a-h/templ"
"log/slog" "log/slog"
"net/http" "net/http"
"spend-sparrow/internal/types" "spend-sparrow/internal/auth_types"
"github.com/a-h/templ"
) )
type Render struct { type Render struct {
@@ -28,18 +29,18 @@ func (render *Render) Render(r *http.Request, w http.ResponseWriter, comp templ.
render.RenderWithStatus(r, w, comp, http.StatusOK) render.RenderWithStatus(r, w, comp, http.StatusOK)
} }
func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *types.User) { func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *auth_types.User) {
render.RenderLayoutWithStatus(r, w, slot, user, http.StatusOK) render.RenderLayoutWithStatus(r, w, slot, user, http.StatusOK)
} }
func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWriter, slot templ.Component, user *types.User, status int) { func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWriter, slot templ.Component, user *auth_types.User, status int) {
userComp := render.getUserComp(user) userComp := render.getUserComp(user)
layout := Layout(slot, userComp, user != nil, r.URL.Path) layout := Layout(slot, userComp, user != nil, r.URL.Path)
render.RenderWithStatus(r, w, layout, status) render.RenderWithStatus(r, w, layout, status)
} }
func (render *Render) getUserComp(user *types.User) templ.Component { func (render *Render) getUserComp(user *auth_types.User) templ.Component {
if user != nil { if user != nil {
return UserComp(user.Email) return UserComp(user.Email)
} else { } else {

View File

@@ -5,33 +5,28 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"log/slog" "log/slog"
"spend-sparrow/internal/types" "spend-sparrow/internal/core"
)
var (
ErrNotFound = errors.New("the value does not exist")
ErrAlreadyExists = errors.New("row already exists")
) )
func TransformAndLogDbError(ctx context.Context, module string, r sql.Result, err error) error { func TransformAndLogDbError(ctx context.Context, module string, r sql.Result, err error) error {
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return ErrNotFound return core.ErrNotFound
} }
slog.ErrorContext(ctx, "database sql", "module", module, "err", err) slog.ErrorContext(ctx, "database sql", "module", module, "err", err)
return types.ErrInternal return core.ErrInternal
} }
if r != nil { if r != nil {
rows, err := r.RowsAffected() rows, err := r.RowsAffected()
if err != nil { if err != nil {
slog.ErrorContext(ctx, "database rows affected", "module", module, "err", err) slog.ErrorContext(ctx, "database rows affected", "module", module, "err", err)
return types.ErrInternal return core.ErrInternal
} }
if rows == 0 { if rows == 0 {
slog.InfoContext(ctx, "row not found", "module", module) slog.InfoContext(ctx, "row not found", "module", module)
return ErrNotFound return core.ErrNotFound
} }
} }

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"log/slog" "log/slog"
"spend-sparrow/internal/types" "spend-sparrow/internal/core"
"github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3" "github.com/golang-migrate/migrate/v4/database/sqlite3"
@@ -25,7 +25,7 @@ func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error {
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not create Migration instance", "err", err) slog.ErrorContext(ctx, "Could not create Migration instance", "err", err)
return types.ErrInternal return core.ErrInternal
} }
m, err := migrate.NewWithDatabaseInstance( m, err := migrate.NewWithDatabaseInstance(
@@ -34,14 +34,14 @@ func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error {
driver) driver)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not create migrations instance", "err", err) slog.ErrorContext(ctx, "Could not create migrations instance", "err", err)
return types.ErrInternal return core.ErrInternal
} }
m.Log = migrationLogger{} m.Log = migrationLogger{}
if err = m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { if err = m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
slog.ErrorContext(ctx, "Could not run migrations", "err", err) slog.ErrorContext(ctx, "Could not run migrations", "err", err)
return types.ErrInternal return core.ErrInternal
} }
return nil return nil

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"os/signal" "os/signal"
"spend-sparrow/internal/account" "spend-sparrow/internal/account"
"spend-sparrow/internal/authentication"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/handler" "spend-sparrow/internal/handler"
@@ -107,13 +108,13 @@ func shutdownServer(ctx context.Context, s *http.Server, wg *sync.WaitGroup) {
func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler { func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
var router = http.NewServeMux() var router = http.NewServeMux()
authDb := db.NewAuthSqlite(d) authDb := authentication.NewDbSqlite(d)
randomService := service.NewRandom() randomService := core.NewRandom()
clockService := service.NewClock() clockService := core.NewClock()
mailService := service.NewMail(serverSettings) mailService := core.NewMail(serverSettings)
authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings) authService := authentication.NewService(authDb, randomService, clockService, mailService, serverSettings)
accountService := account.NewServiceImpl(d, randomService, clockService) accountService := account.NewServiceImpl(d, randomService, clockService)
treasureChestService := service.NewTreasureChest(d, randomService, clockService) treasureChestService := service.NewTreasureChest(d, randomService, clockService)
transactionService := service.NewTransaction(d, randomService, clockService) transactionService := service.NewTransaction(d, randomService, clockService)
@@ -123,7 +124,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *
render := core.NewRender() render := core.NewRender()
indexHandler := handler.NewIndex(render, clockService) indexHandler := handler.NewIndex(render, clockService)
dashboardHandler := handler.NewDashboard(render, dashboardService, treasureChestService) dashboardHandler := handler.NewDashboard(render, dashboardService, treasureChestService)
authHandler := handler.NewAuth(authService, render) authHandler := authentication.NewHandler(authService, render)
accountHandler := account.NewHandler(accountService, render) accountHandler := account.NewHandler(accountService, render)
treasureChestHandler := handler.NewTreasureChest(treasureChestService, transactionRecurringService, render) treasureChestHandler := handler.NewTreasureChest(treasureChestService, transactionRecurringService, render)
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render) transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
@@ -157,7 +158,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *
return wrapper return wrapper
} }
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) { func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) {
runDailyTasks(ctx, transactionRecurring, auth) runDailyTasks(ctx, transactionRecurring, auth)
ticker := time.NewTicker(24 * time.Hour) ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop() defer ticker.Stop()
@@ -172,7 +173,7 @@ func dailyTaskTimer(ctx context.Context, transactionRecurring service.Transactio
} }
} }
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) { func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) {
slog.InfoContext(ctx, "Running daily tasks") slog.InfoContext(ctx, "Running daily tasks")
_ = transactionRecurring.GenerateTransactions(ctx) _ = transactionRecurring.GenerateTransactions(ctx)
_ = auth.CleanupSessionsAndTokens(ctx) _ = auth.CleanupSessionsAndTokens(ctx)

View File

@@ -193,7 +193,7 @@ func (handler DashboardImpl) handleDashboardTreasureChest() http.HandlerFunc {
if treasureChestStr != "" { if treasureChestStr != "" {
id, err := uuid.Parse(treasureChestStr) id, err := uuid.Parse(treasureChestStr)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", core.ErrBadRequest))
return return
} }

View File

@@ -3,12 +3,12 @@ package middleware
import ( import (
"context" "context"
"net/http" "net/http"
"spend-sparrow/internal/authentication"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/service"
"strings" "strings"
) )
func Authenticate(service service.Auth) func(http.Handler) http.Handler { func Authenticate(service authentication.Service) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
@@ -31,7 +31,7 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
return return
} }
cookie := CreateSessionCookie(session.Id) cookie := core.CreateSessionCookie(session.Id)
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
} }

View File

@@ -3,8 +3,8 @@ package middleware
import ( import (
"log/slog" "log/slog"
"net/http" "net/http"
"spend-sparrow/internal/authentication"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/service"
"spend-sparrow/internal/utils" "spend-sparrow/internal/utils"
"strings" "strings"
) )
@@ -31,7 +31,7 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
return rr.ResponseWriter.Write([]byte(dataStr)) return rr.ResponseWriter.Write([]byte(dataStr))
} }
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { func CrossSiteRequestForgery(auth authentication.Service) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()

View File

@@ -1,15 +1 @@
package middleware package middleware
import "net/http"
func CreateSessionCookie(sessionId string) http.Cookie {
return http.Cookie{
Name: "id",
Value: sessionId,
MaxAge: 60 * 60 * 8, // 8 hours
Secure: true,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
}
}

View File

@@ -3,7 +3,6 @@ package handler
import ( import (
"net/http" "net/http"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/service"
"spend-sparrow/internal/template" "spend-sparrow/internal/template"
"spend-sparrow/internal/utils" "spend-sparrow/internal/utils"
@@ -16,10 +15,10 @@ type Index interface {
type IndexImpl struct { type IndexImpl struct {
r *core.Render r *core.Render
c service.Clock c core.Clock
} }
func NewIndex(r *core.Render, c service.Clock) Index { func NewIndex(r *core.Render, c core.Clock) Index {
return IndexImpl{ return IndexImpl{
r: r, r: r,
c: c, c: c,

View File

@@ -157,7 +157,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
if idStr != "new" { if idStr != "new" {
id, err = uuid.Parse(idStr) id, err = uuid.Parse(idStr)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest))
return return
} }
} }
@@ -167,7 +167,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
if accountIdStr != "" { if accountIdStr != "" {
i, err := uuid.Parse(accountIdStr) i, err := uuid.Parse(accountIdStr)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", core.ErrBadRequest))
return return
} }
accountId = &i accountId = &i
@@ -178,7 +178,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
if treasureChestIdStr != "" { if treasureChestIdStr != "" {
i, err := uuid.Parse(treasureChestIdStr) i, err := uuid.Parse(treasureChestIdStr)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", core.ErrBadRequest))
return return
} }
treasureChestId = &i treasureChestId = &i
@@ -186,14 +186,14 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
valueF, err := strconv.ParseFloat(r.FormValue("value"), 64) valueF, err := strconv.ParseFloat(r.FormValue("value"), 64)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse value: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse value: %w", core.ErrBadRequest))
return return
} }
value := int64(math.Round(valueF * service.DECIMALS_MULTIPLIER)) value := int64(math.Round(valueF * service.DECIMALS_MULTIPLIER))
timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp")) timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp"))
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest))
return return
} }

View File

@@ -2,6 +2,7 @@ package handler
import ( import (
"net/http" "net/http"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/service" "spend-sparrow/internal/service"
t "spend-sparrow/internal/template/transaction_recurring" t "spend-sparrow/internal/template/transaction_recurring"
@@ -111,7 +112,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
} }
} }
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) { func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *auth_types.User, id, accountId, treasureChestId string) {
var transactionsRecurring []*types.TransactionRecurring var transactionsRecurring []*types.TransactionRecurring
var err error var err error
if accountId == "" && treasureChestId == "" { if accountId == "" && treasureChestId == "" {

View File

@@ -2,6 +2,8 @@ package service
import ( import (
"context" "context"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"time" "time"
@@ -22,10 +24,10 @@ func NewDashboard(db *sqlx.DB) *Dashboard {
func (s Dashboard) MainChart( func (s Dashboard) MainChart(
ctx context.Context, ctx context.Context,
user *types.User, user *auth_types.User,
) ([]types.DashboardMainChartEntry, error) { ) ([]types.DashboardMainChartEntry, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
transactions := make([]types.Transaction, 0) transactions := make([]types.Transaction, 0)
@@ -82,10 +84,10 @@ func (s Dashboard) MainChart(
func (s Dashboard) TreasureChests( func (s Dashboard) TreasureChests(
ctx context.Context, ctx context.Context,
user *types.User, user *auth_types.User,
) ([]*types.DashboardTreasureChest, error) { ) ([]*types.DashboardTreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
treasureChests := make([]*types.TreasureChest, 0) treasureChests := make([]*types.TreasureChest, 0)
@@ -120,11 +122,11 @@ func (s Dashboard) TreasureChests(
func (s Dashboard) TreasureChest( func (s Dashboard) TreasureChest(
ctx context.Context, ctx context.Context,
user *types.User, user *auth_types.User,
treausureChestId *uuid.UUID, treausureChestId *uuid.UUID,
) ([]types.DashboardMainChartEntry, error) { ) ([]types.DashboardMainChartEntry, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
transactions := make([]types.Transaction, 0) transactions := make([]types.Transaction, 0)

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"fmt" "fmt"
"regexp" "regexp"
"spend-sparrow/internal/core"
) )
const ( const (
@@ -16,9 +17,9 @@ var (
func ValidateString(value string, fieldName string) error { func ValidateString(value string, fieldName string) error {
switch { switch {
case value == "": case value == "":
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest) return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, core.ErrBadRequest)
case !safeInputRegex.MatchString(value): case !safeInputRegex.MatchString(value):
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest) return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, core.ErrBadRequest)
default: default:
return nil return nil
} }

View File

@@ -1,8 +0,0 @@
package service
import "errors"
var (
ErrBadRequest = errors.New("bad request")
ErrUnauthorized = errors.New("unauthorized")
)

View File

@@ -5,6 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"strconv" "strconv"
@@ -17,22 +19,22 @@ import (
const page_size = 25 const page_size = 25
type Transaction interface { type Transaction interface {
Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error) Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error)
Update(ctx context.Context, user *types.User, transaction types.Transaction) (*types.Transaction, error) Update(ctx context.Context, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error)
Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error)
GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
Delete(ctx context.Context, user *types.User, id string) error Delete(ctx context.Context, user *auth_types.User, id string) error
RecalculateBalances(ctx context.Context, user *types.User) error RecalculateBalances(ctx context.Context, user *auth_types.User) error
} }
type TransactionImpl struct { type TransactionImpl struct {
db *sqlx.DB db *sqlx.DB
clock Clock clock core.Clock
random Random random core.Random
} }
func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction { func NewTransaction(db *sqlx.DB, random core.Random, clock core.Clock) Transaction {
return TransactionImpl{ return TransactionImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -40,9 +42,9 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
} }
} }
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transactionInput types.Transaction) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
var err error var err error
@@ -107,9 +109,9 @@ func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User,
return transaction, nil return transaction, nil
} }
func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) { func (s TransactionImpl) Update(ctx context.Context, user *auth_types.User, input types.Transaction) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -125,10 +127,10 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ
err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id) err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
err = db.TransformAndLogDbError(ctx, "transaction Update", nil, err) err = db.TransformAndLogDbError(ctx, "transaction Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest) return nil, fmt.Errorf("transaction %v not found: %w", input.Id, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
if transaction.Error == nil && transaction.AccountId != nil { if transaction.Error == nil && transaction.AccountId != nil {
@@ -206,32 +208,32 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ
return transaction, nil return transaction, nil
} }
func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) { func (s TransactionImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transaction get", "err", err) slog.ErrorContext(ctx, "transaction get", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
var transaction types.Transaction var transaction types.Transaction
err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError(ctx, "transaction Get", nil, err) err = db.TransformAndLogDbError(ctx, "transaction Get", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest) return nil, fmt.Errorf("transaction %v not found: %w", id, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return &transaction, nil return &transaction, nil
} }
func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) { func (s TransactionImpl) GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
var ( var (
@@ -277,14 +279,14 @@ func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter ty
return transactions, nil return transactions, nil
} }
func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error { func (s TransactionImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
if user == nil { if user == nil {
return ErrUnauthorized return core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transaction delete", "err", err) slog.ErrorContext(ctx, "transaction delete", "err", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -310,7 +312,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
WHERE id = ? WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err) err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
} }
@@ -322,7 +324,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
WHERE id = ? WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err) err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
} }
@@ -342,9 +344,9 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
return nil return nil
} }
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error { func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *auth_types.User) error {
if user == nil { if user == nil {
return ErrUnauthorized return core.ErrUnauthorized
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -361,7 +363,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
SET current_balance = 0 SET current_balance = 0
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
@@ -370,7 +372,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
SET current_balance = 0 SET current_balance = 0
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
@@ -379,7 +381,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
FROM "transaction" FROM "transaction"
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", nil, err) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", nil, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
defer func() { defer func() {
@@ -458,7 +460,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
if oldTransaction == nil { if oldTransaction == nil {
id, err = s.random.UUID(ctx) id, err = s.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
createdAt = s.clock.Now() createdAt = s.clock.Now()
createdBy = userId createdBy = userId
@@ -479,7 +481,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
} }
if rowCount == 0 { if rowCount == 0 {
slog.ErrorContext(ctx, "transaction validate", "err", err) slog.ErrorContext(ctx, "transaction validate", "err", err)
return nil, fmt.Errorf("account not found: %w", ErrBadRequest) return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest)
} }
} }
@@ -488,13 +490,13 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId) err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
err = db.TransformAndLogDbError(ctx, "transaction validate", nil, err) err = db.TransformAndLogDbError(ctx, "transaction validate", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest)
} }
return nil, err return nil, err
} }
if treasureChest.ParentId == nil { if treasureChest.ParentId == nil {
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest)
} }
} }

View File

@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"strconv" "strconv"
@@ -16,24 +18,24 @@ import (
) )
type TransactionRecurring interface { type TransactionRecurring interface {
Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Add(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Update(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error)
GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error)
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) GetAllByTreasureChest(ctx context.Context, user *auth_types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(ctx context.Context, user *types.User, id string) error Delete(ctx context.Context, user *auth_types.User, id string) error
GenerateTransactions(ctx context.Context) error GenerateTransactions(ctx context.Context) error
} }
type TransactionRecurringImpl struct { type TransactionRecurringImpl struct {
db *sqlx.DB db *sqlx.DB
clock Clock clock core.Clock
random Random random core.Random
transaction Transaction transaction Transaction
} }
func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transaction Transaction) TransactionRecurring { func NewTransactionRecurring(db *sqlx.DB, random core.Random, clock core.Clock, transaction Transaction) TransactionRecurring {
return TransactionRecurringImpl{ return TransactionRecurringImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -43,11 +45,11 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio
} }
func (s TransactionRecurringImpl) Add(ctx context.Context, func (s TransactionRecurringImpl) Add(ctx context.Context,
user *types.User, user *auth_types.User,
transactionRecurringInput types.TransactionRecurringInput, transactionRecurringInput types.TransactionRecurringInput,
) (*types.TransactionRecurring, error) { ) (*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -85,16 +87,16 @@ func (s TransactionRecurringImpl) Add(ctx context.Context,
} }
func (s TransactionRecurringImpl) Update(ctx context.Context, func (s TransactionRecurringImpl) Update(ctx context.Context,
user *types.User, user *auth_types.User,
input types.TransactionRecurringInput, input types.TransactionRecurringInput,
) (*types.TransactionRecurring, error) { ) (*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
uuid, err := uuid.Parse(input.Id) uuid, err := uuid.Parse(input.Id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring update", "err", err) slog.ErrorContext(ctx, "transactionRecurring update", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -110,10 +112,10 @@ func (s TransactionRecurringImpl) Update(ctx context.Context,
err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError(ctx, "transactionRecurring Update", nil, err) err = db.TransformAndLogDbError(ctx, "transactionRecurring Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest) return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input) transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
@@ -149,9 +151,9 @@ func (s TransactionRecurringImpl) Update(ctx context.Context,
return transactionRecurring, nil return transactionRecurring, nil
} }
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
@@ -169,15 +171,15 @@ func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User)
return transactionRecurrings, nil return transactionRecurrings, nil
} }
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
accountUuid, err := uuid.Parse(accountId) accountUuid, err := uuid.Parse(accountId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring GetAllByAccount", "err", err) slog.ErrorContext(ctx, "transactionRecurring GetAllByAccount", "err", err)
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -193,10 +195,10 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByAccount", nil, err) err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByAccount", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest) return nil, fmt.Errorf("account %v not found: %w", accountId, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
@@ -222,17 +224,17 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ
} }
func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context, func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
user *types.User, user *auth_types.User,
treasureChestId string, treasureChestId string,
) ([]*types.TransactionRecurring, error) { ) ([]*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
treasureChestUuid, err := uuid.Parse(treasureChestId) treasureChestUuid, err := uuid.Parse(treasureChestId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring GetAllByTreasureChest", "err", err) slog.ErrorContext(ctx, "transactionRecurring GetAllByTreasureChest", "err", err)
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -248,10 +250,10 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByTreasureChest", nil, err) err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest) return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
@@ -276,14 +278,14 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
return transactionRecurrings, nil return transactionRecurrings, nil
} }
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error { func (s TransactionRecurringImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
if user == nil { if user == nil {
return ErrUnauthorized return core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring delete", "err", err) slog.ErrorContext(ctx, "transactionRecurring delete", "err", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -339,7 +341,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) erro
} }
for _, transactionRecurring := range recurringTransactions { for _, transactionRecurring := range recurringTransactions {
user := &types.User{ user := &auth_types.User{
Id: transactionRecurring.UserId, Id: transactionRecurring.UserId,
} }
transaction := types.Transaction{ transaction := types.Transaction{
@@ -397,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
if oldTransactionRecurring == nil { if oldTransactionRecurring == nil {
id, err = s.random.UUID(ctx) id, err = s.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
createdAt = s.clock.Now() createdAt = s.clock.Now()
createdBy = userId createdBy = userId
@@ -416,7 +418,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
temp, err := uuid.Parse(input.AccountId) temp, err := uuid.Parse(input.AccountId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest)
} }
accountUuid = &temp accountUuid = &temp
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
@@ -426,7 +428,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
} }
if rowCount == 0 { if rowCount == 0 {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("account not found: %w", ErrBadRequest) return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest)
} }
hasAccount = true hasAccount = true
@@ -436,37 +438,37 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
temp, err := uuid.Parse(input.TreasureChestId) temp, err := uuid.Parse(input.TreasureChestId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest)
} }
treasureChestUuid = &temp treasureChestUuid = &temp
var treasureChest types.TreasureChest var treasureChest types.TreasureChest
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError(ctx, "transactionRecurring validate", nil, err) err = db.TransformAndLogDbError(ctx, "transactionRecurring validate", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest)
} }
return nil, err return nil, err
} }
if treasureChest.ParentId == nil { if treasureChest.ParentId == nil {
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest)
} }
hasTreasureChest = true hasTreasureChest = true
} }
if !hasAccount && !hasTreasureChest { if !hasAccount && !hasTreasureChest {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("either account or treasure chest is required: %w", ErrBadRequest) return nil, fmt.Errorf("either account or treasure chest is required: %w", core.ErrBadRequest)
} }
if hasAccount && hasTreasureChest { if hasAccount && hasTreasureChest {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", ErrBadRequest) return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", core.ErrBadRequest)
} }
valueFloat, err := strconv.ParseFloat(input.Value, 64) valueFloat, err := strconv.ParseFloat(input.Value, 64)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse value: %w", core.ErrBadRequest)
} }
value := int64(math.Round(valueFloat * DECIMALS_MULTIPLIER)) value := int64(math.Round(valueFloat * DECIMALS_MULTIPLIER))
@@ -485,18 +487,18 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0) intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("could not parse intervalMonths: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse intervalMonths: %w", core.ErrBadRequest)
} }
if intervalMonths < 1 { if intervalMonths < 1 {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", ErrBadRequest) return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", core.ErrBadRequest)
} }
var nextExecution *time.Time = nil var nextExecution *time.Time = nil
if input.NextExecution != "" { if input.NextExecution != "" {
t, err := time.Parse("2006-01-02", input.NextExecution) t, err := time.Parse("2006-01-02", input.NextExecution)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transaction validate", "err", err) slog.ErrorContext(ctx, "transaction validate", "err", err)
return nil, fmt.Errorf("could not parse timestamp: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest)
} }
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location()) t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())

View File

@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"slices" "slices"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
@@ -14,20 +16,20 @@ import (
) )
type TreasureChest interface { type TreasureChest interface {
Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error)
Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error) Update(ctx context.Context, user *auth_types.User, id, parentId, name string) (*types.TreasureChest, error)
Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error)
GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error)
Delete(ctx context.Context, user *types.User, id string) error Delete(ctx context.Context, user *auth_types.User, id string) error
} }
type TreasureChestImpl struct { type TreasureChestImpl struct {
db *sqlx.DB db *sqlx.DB
clock Clock clock core.Clock
random Random random core.Random
} }
func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest { func NewTreasureChest(db *sqlx.DB, random core.Random, clock core.Clock) TreasureChest {
return TreasureChestImpl{ return TreasureChestImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -35,14 +37,14 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
} }
} }
func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
newId, err := s.random.UUID(ctx) newId, err := s.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
err = ValidateString(name, "name") err = ValidateString(name, "name")
@@ -57,7 +59,7 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId,
return nil, err return nil, err
} }
if parent.ParentId != nil { if parent.ParentId != nil {
return nil, fmt.Errorf("only a depth of 1 allowed: %w", ErrBadRequest) return nil, fmt.Errorf("only a depth of 1 allowed: %w", core.ErrBadRequest)
} }
parentUuid = &parent.Id parentUuid = &parent.Id
} }
@@ -88,9 +90,9 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId,
return treasureChest, nil return treasureChest, nil
} }
func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Update(ctx context.Context, user *auth_types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
err := ValidateString(name, "name") err := ValidateString(name, "name")
if err != nil { if err != nil {
@@ -99,7 +101,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
id, err := uuid.Parse(idStr) id, err := uuid.Parse(idStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "treasureChest update", "err", err) slog.ErrorContext(ctx, "treasureChest update", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -115,10 +117,10 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id) err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
err = db.TransformAndLogDbError(ctx, "treasureChest Update", nil, err) err = db.TransformAndLogDbError(ctx, "treasureChest Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err) return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
var parentUuid *uuid.UUID var parentUuid *uuid.UUID
@@ -134,7 +136,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
return nil, err return nil, err
} }
if parent.ParentId != nil || childCount > 0 { if parent.ParentId != nil || childCount > 0 {
return nil, fmt.Errorf("only one level allowed: %w", ErrBadRequest) return nil, fmt.Errorf("only one level allowed: %w", core.ErrBadRequest)
} }
parentUuid = &parent.Id parentUuid = &parent.Id
@@ -170,32 +172,32 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
return treasureChest, nil return treasureChest, nil
} }
func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "treasureChest get", "err", err) slog.ErrorContext(ctx, "treasureChest get", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
var treasureChest types.TreasureChest var treasureChest types.TreasureChest
err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError(ctx, "treasureChest Get", nil, err) err = db.TransformAndLogDbError(ctx, "treasureChest Get", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err) return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return &treasureChest, nil return &treasureChest, nil
} }
func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) { func (s TreasureChestImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
treasureChests := make([]*types.TreasureChest, 0) treasureChests := make([]*types.TreasureChest, 0)
@@ -208,14 +210,14 @@ func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*typ
return sortTreasureChests(treasureChests), nil return sortTreasureChests(treasureChests), nil
} }
func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error { func (s TreasureChestImpl) Delete(ctx context.Context, user *auth_types.User, idStr string) error {
if user == nil { if user == nil {
return ErrUnauthorized return core.ErrUnauthorized
} }
id, err := uuid.Parse(idStr) id, err := uuid.Parse(idStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "treasureChest delete", "err", err) slog.ErrorContext(ctx, "treasureChest delete", "err", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -235,7 +237,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
} }
if childCount > 0 { if childCount > 0 {
return fmt.Errorf("treasure chest has children: %w", ErrBadRequest) return fmt.Errorf("treasure chest has children: %w", core.ErrBadRequest)
} }
transactionsCount := 0 transactionsCount := 0
@@ -247,7 +249,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
return err return err
} }
if transactionsCount > 0 { if transactionsCount > 0 {
return fmt.Errorf("treasure chest has transactions: %w", ErrBadRequest) return fmt.Errorf("treasure chest has transactions: %w", core.ErrBadRequest)
} }
recurringCount := 0 recurringCount := 0
@@ -259,7 +261,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
return err return err
} }
if recurringCount > 0 { if recurringCount > 0 {
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest) return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", core.ErrBadRequest)
} }
r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id) r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)

View File

@@ -1 +0,0 @@
package auth

View File

@@ -1,10 +0,0 @@
package types
import (
"errors"
)
var (
ErrInternal = errors.New("internal server error")
ErrUnauthorized = errors.New("you are not authorized to perform this action")
)

View File

@@ -1 +0,0 @@
package mocks

View File

@@ -2,8 +2,10 @@ package test_test
import ( import (
"context" "context"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/authentication"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types"
"testing" "testing"
"time" "time"
@@ -42,11 +44,11 @@ func TestUser(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) expected := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(context.Background(), expected) err := underTest.InsertUser(context.Background(), expected)
require.NoError(t, err) require.NoError(t, err)
@@ -63,38 +65,38 @@ func TestUser(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
_, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail") _, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail")
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, core.ErrNotFound, err)
}) })
t.Run("should return ErrUserExist", func(t *testing.T) { t.Run("should return ErrUserExist", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(context.Background(), user) err := underTest.InsertUser(context.Background(), user)
require.NoError(t, err) require.NoError(t, err)
err = underTest.InsertUser(context.Background(), user) err = underTest.InsertUser(context.Background(), user)
assert.Equal(t, db.ErrAlreadyExists, err) assert.Equal(t, core.ErrAlreadyExists, err)
}) })
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := auth_types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
err := underTest.InsertUser(context.Background(), user) err := underTest.InsertUser(context.Background(), user)
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, core.ErrInternal, err)
}) })
} }
@@ -105,11 +107,11 @@ func TestToken(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt) expected := auth_types.NewToken(uuid.New(), "sessionId", "token", auth_types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(context.Background(), expected) err := underTest.InsertToken(context.Background(), expected)
require.NoError(t, err) require.NoError(t, err)
@@ -121,25 +123,25 @@ func TestToken(t *testing.T) {
expected.SessionId = "" expected.SessionId = ""
actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type) actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals) assert.Equal(t, []*auth_types.Token{expected}, actuals)
expected.SessionId = "sessionId" expected.SessionId = "sessionId"
expected.UserId = uuid.Nil expected.UserId = uuid.Nil
actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type) actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals) assert.Equal(t, []*auth_types.Token{expected}, actuals)
}) })
t.Run("should insert and return multiple tokens", func(t *testing.T) { t.Run("should insert and return multiple tokens", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
userId := uuid.New() userId := uuid.New()
expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt) expected1 := auth_types.NewToken(userId, "sessionId", "token1", auth_types.TokenTypeCsrf, createAt, expiresAt)
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt) expected2 := auth_types.NewToken(userId, "sessionId", "token2", auth_types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(context.Background(), expected1) err := underTest.InsertToken(context.Background(), expected1)
require.NoError(t, err) require.NoError(t, err)
@@ -150,7 +152,7 @@ func TestToken(t *testing.T) {
expected2.UserId = uuid.Nil expected2.UserId = uuid.Nil
actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type) actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals) assert.Equal(t, []*auth_types.Token{expected1, expected2}, actuals)
expected1.SessionId = "" expected1.SessionId = ""
expected2.SessionId = "" expected2.SessionId = ""
@@ -158,49 +160,49 @@ func TestToken(t *testing.T) {
expected2.UserId = userId expected2.UserId = userId
actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type) actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals) assert.Equal(t, []*auth_types.Token{expected1, expected2}, actuals)
}) })
t.Run("should return ErrNotFound", func(t *testing.T) { t.Run("should return ErrNotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
_, err := underTest.GetToken(context.Background(), "nonExistent") _, err := underTest.GetToken(context.Background(), "nonExistent")
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, core.ErrNotFound, err)
_, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), types.TokenTypeEmailVerify) _, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), auth_types.TokenTypeEmailVerify)
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, core.ErrNotFound, err)
_, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", types.TokenTypeEmailVerify) _, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", auth_types.TokenTypeEmailVerify)
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, core.ErrNotFound, err)
}) })
t.Run("should return ErrAlreadyExists", func(t *testing.T) { t.Run("should return ErrAlreadyExists", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(context.Background(), user) err := underTest.InsertUser(context.Background(), user)
require.NoError(t, err) require.NoError(t, err)
err = underTest.InsertUser(context.Background(), user) err = underTest.InsertUser(context.Background(), user)
assert.Equal(t, db.ErrAlreadyExists, err) assert.Equal(t, core.ErrAlreadyExists, err)
}) })
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := auth_types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
err := underTest.InsertUser(context.Background(), user) err := underTest.InsertUser(context.Background(), user)
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, core.ErrInternal, err)
}) })
} }

View File

@@ -2,8 +2,9 @@ package test_test
import ( import (
"context" "context"
"spend-sparrow/internal/db" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/service" "spend-sparrow/internal/authentication"
"spend-sparrow/internal/core"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"spend-sparrow/mocks" "spend-sparrow/mocks"
"strings" "strings"
@@ -30,26 +31,26 @@ func TestSignUp(t *testing.T) {
t.Run("should check for correct email address", func(t *testing.T) { t.Run("should check for correct email address", func(t *testing.T) {
t.Parallel() t.Parallel()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!") _, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!")
assert.Equal(t, service.ErrInvalidEmail, err) assert.Equal(t, authentication.ErrInvalidEmail, err)
}) })
t.Run("should check for password complexity", func(t *testing.T) { t.Run("should check for password complexity", func(t *testing.T) {
t.Parallel() t.Parallel()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
weakPasswords := []string{ weakPasswords := []string{
"123!ab", // too short "123!ab", // too short
@@ -60,13 +61,13 @@ func TestSignUp(t *testing.T) {
for _, password := range weakPasswords { for _, password := range weakPasswords {
_, err := underTest.SignUp(context.Background(), "some@valid.email", password) _, err := underTest.SignUp(context.Background(), "some@valid.email", password)
assert.Equal(t, service.ErrInvalidPassword, err) assert.Equal(t, authentication.ErrInvalidPassword, err)
} }
}) })
t.Run("should signup correctly", func(t *testing.T) { t.Run("should signup correctly", func(t *testing.T) {
t.Parallel() t.Parallel()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
@@ -77,7 +78,7 @@ func TestSignUp(t *testing.T) {
salt := []byte("salt") salt := []byte("salt")
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime) expected := auth_types.NewUser(userId, email, false, nil, false, authentication.GetHashPassword(password, salt), salt, createTime)
ctx := context.Background() ctx := context.Background()
@@ -86,7 +87,7 @@ func TestSignUp(t *testing.T) {
mockClock.EXPECT().Now().Return(createTime) mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil) mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
actual, err := underTest.SignUp(context.Background(), email, password) actual, err := underTest.SignUp(context.Background(), email, password)
require.NoError(t, err) require.NoError(t, err)
@@ -96,7 +97,7 @@ func TestSignUp(t *testing.T) {
t.Run("should return ErrAccountExists", func(t *testing.T) { t.Run("should return ErrAccountExists", func(t *testing.T) {
t.Parallel() t.Parallel()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
@@ -106,19 +107,19 @@ func TestSignUp(t *testing.T) {
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
password := "SomeStrongPassword123!" password := "SomeStrongPassword123!"
salt := []byte("salt") salt := []byte("salt")
user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime) user := auth_types.NewUser(userId, email, false, nil, false, authentication.GetHashPassword(password, salt), salt, createTime)
ctx := context.Background() ctx := context.Background()
mockRandom.EXPECT().UUID(ctx).Return(user.Id, nil) mockRandom.EXPECT().UUID(ctx).Return(user.Id, nil)
mockRandom.EXPECT().Bytes(ctx, 16).Return(salt, nil) mockRandom.EXPECT().Bytes(ctx, 16).Return(salt, nil)
mockClock.EXPECT().Now().Return(createTime) mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(db.ErrAlreadyExists) mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(core.ErrAlreadyExists)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp(context.Background(), user.Email, password) _, err := underTest.SignUp(context.Background(), user.Email, password)
assert.Equal(t, service.ErrAccountExists, err) assert.Equal(t, authentication.ErrAccountExists, err)
}) })
} }
@@ -127,30 +128,30 @@ func TestSendVerificationMail(t *testing.T) {
t.Run("should use stored token and send mail", func(t *testing.T) { t.Run("should use stored token and send mail", func(t *testing.T) {
t.Parallel() t.Parallel()
token := types.NewToken( token := auth_types.NewToken(
uuid.New(), uuid.New(),
"sessionId", "sessionId",
"someRandomTokenToUse", "someRandomTokenToUse",
types.TokenTypeEmailVerify, auth_types.TokenTypeEmailVerify,
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
tokens := []*types.Token{token} tokens := []*auth_types.Token{token}
email := "some@email.de" email := "some@email.de"
userId := uuid.New() userId := uuid.New()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
ctx := context.Background() ctx := context.Background()
mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, types.TokenTypeEmailVerify).Return(tokens, nil) mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, auth_types.TokenTypeEmailVerify).Return(tokens, nil)
mockMail.EXPECT().SendMail(ctx, email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool { mockMail.EXPECT().SendMail(ctx, email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
return strings.Contains(message, token.Token) return strings.Contains(message, token.Token)
})).Return() })).Return()
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
underTest.SendVerificationMail(context.Background(), userId, email) underTest.SendVerificationMail(context.Background(), userId, email)
}) })

View File

@@ -7,8 +7,9 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"spend-sparrow/internal" "spend-sparrow/internal"
"spend-sparrow/internal/service" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/types" "spend-sparrow/internal/authentication"
"spend-sparrow/internal/core"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -117,7 +118,7 @@ func waitForReady(
default: default:
if time.Since(startTime) >= timeout { if time.Since(startTime) >= timeout {
t.Fatal("timeout reached while waiting for endpoint") t.Fatal("timeout reached while waiting for endpoint")
return types.ErrInternal return core.ErrInternal
} }
// wait a little while between checks // wait a little while between checks
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
@@ -178,7 +179,7 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s
t.Helper() t.Helper()
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" + add sessionId := "session-id" + add
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
csrfToken := "my-verifying-token" + add csrfToken := "my-verifying-token" + add
email := add + "mail@mail.de" email := add + "mail@mail.de"
@@ -193,7 +194,7 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s
_, err = db.ExecContext(context.Background(), ` _, err = db.ExecContext(context.Background(), `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf) VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, auth_types.TokenTypeCsrf)
require.NoError(t, err) require.NoError(t, err)
return userId, csrfToken, sessionId return userId, csrfToken, sessionId

View File

@@ -3,8 +3,8 @@ package test_test
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"spend-sparrow/internal/service" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/types" "spend-sparrow/internal/authentication"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -110,7 +110,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -136,7 +136,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -163,7 +163,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -206,7 +206,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -247,7 +247,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
@@ -295,7 +295,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
@@ -414,7 +414,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
@@ -467,7 +467,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -550,7 +550,7 @@ func TestIntegrationAuth(t *testing.T) {
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), authentication.GetHashPassword("password", []byte("salt")), []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil)
@@ -631,7 +631,7 @@ func TestIntegrationAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
var token string var token string
err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token) err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", auth_types.TokenTypeEmailVerify).Scan(&token)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, token) assert.NotEmpty(t, token)
}) })
@@ -676,7 +676,7 @@ func TestIntegrationAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.ExecContext(ctx, ` _, err = db.ExecContext(ctx, `
INSERT INTO token (token, user_id, type, created_at, expires_at) INSERT INTO token (token, user_id, type, created_at, expires_at)
VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify) VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, auth_types.TokenTypeEmailVerify)
require.NoError(t, err) require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil)
@@ -706,7 +706,7 @@ func TestIntegrationAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.ExecContext(ctx, ` _, err = db.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify) VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, auth_types.TokenTypeEmailVerify)
require.NoError(t, err) require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil)
@@ -746,7 +746,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -765,7 +765,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var csrfToken string var csrfToken string
err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken) err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypeCsrf).Scan(&csrfToken)
require.NoError(t, err) require.NoError(t, err)
req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil) req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil)
@@ -824,7 +824,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -870,7 +870,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1039,7 +1039,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1078,7 +1078,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -1128,7 +1128,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -1180,7 +1180,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
userIdOther := uuid.New() userIdOther := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1230,7 +1230,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt")) pass = authentication.GetHashPassword("MyNewSecurePassword1!", []byte("salt"))
var rows int var rows int
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
@@ -1259,7 +1259,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1287,7 +1287,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1317,7 +1317,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypePasswordReset).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
}) })
@@ -1362,7 +1362,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1399,7 +1399,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg) assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg)
var rows int var rows int
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypePasswordReset).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1412,7 +1412,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1455,7 +1455,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1475,7 +1475,7 @@ func TestIntegrationAuth(t *testing.T) {
token := "password-reset-token" token := "password-reset-token"
_, err = d.ExecContext(ctx, ` _, err = d.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset) VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", auth_types.TokenTypePasswordReset)
require.NoError(t, err) require.NoError(t, err)
formData := url.Values{ formData := url.Values{
@@ -1504,7 +1504,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1524,7 +1524,7 @@ func TestIntegrationAuth(t *testing.T) {
token := "password-reset-token" token := "password-reset-token"
_, err = d.ExecContext(ctx, ` _, err = d.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset) VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", auth_types.TokenTypePasswordReset)
require.NoError(t, err) require.NoError(t, err)
formData := url.Values{ formData := url.Values{
@@ -1553,7 +1553,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1590,7 +1590,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var token string var token string
err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token) err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", auth_types.TokenTypePasswordReset).Scan(&token)
require.NoError(t, err) require.NoError(t, err)
formData = url.Values{ formData = url.Values{