From 75433834ed8c65c634e7406c7be78634b831cca9 Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Thu, 25 Dec 2025 07:39:48 +0100 Subject: [PATCH] feat: extract authentication to domain package --- .mockery.yaml | 6 +- internal/account/service.go | 60 ++++--- .../{types/auth.go => auth_types/types.go} | 2 +- internal/{db/auth.go => authentication/db.go} | 151 ++++++++-------- .../auth.go => authentication/handler.go} | 80 ++++----- .../auth.go => authentication/service.go} | 169 +++++++++--------- .../template}/change_password.templ | 2 +- internal/authentication/template/default.go | 1 + .../template}/delete_account.templ | 2 +- .../template}/reset_password.templ | 2 +- .../template}/sign_in_or_up.templ | 14 +- .../template}/verify.templ | 2 +- .../template}/verify_response.templ | 2 +- internal/core/auth.go | 22 ++- internal/{service => core}/clock.go | 2 +- internal/core/default.go | 8 +- internal/core/error.go | 13 ++ internal/{service => core}/mail.go | 2 +- .../{service => core}/random_generator.go | 9 +- internal/core/render.go | 11 +- internal/db/error.go | 15 +- internal/db/migration.go | 8 +- internal/default.go | 17 +- internal/handler/dashboard.go | 2 +- internal/handler/middleware/authenticate.go | 6 +- .../middleware/cross_site_request_forgery.go | 4 +- internal/handler/middleware/default.go | 14 -- internal/handler/root_and_404.go | 5 +- internal/handler/transaction.go | 10 +- internal/handler/transaction_recurring.go | 3 +- internal/service/dashboard.go | 14 +- internal/service/default.go | 5 +- internal/service/error.go | 8 - internal/service/transaction.go | 80 +++++---- internal/service/transaction_recurring.go | 98 +++++----- internal/service/treasure_chest.go | 64 +++---- internal/template/auth/default.go | 1 - internal/types/types.go | 10 -- mocks/default.go | 1 - test/auth_it_test.go | 66 +++---- test/auth_test.go | 45 ++--- test/it_test.go | 11 +- test/main_it_test.go | 70 ++++---- 43 files changed, 559 insertions(+), 558 deletions(-) rename internal/{types/auth.go => auth_types/types.go} (98%) rename internal/{db/auth.go => authentication/db.go} (66%) rename internal/{handler/auth.go => authentication/handler.go} (81%) rename internal/{service/auth.go => authentication/service.go} (65%) rename internal/{template/auth => authentication/template}/change_password.templ (98%) create mode 100644 internal/authentication/template/default.go rename internal/{template/auth => authentication/template}/delete_account.templ (97%) rename internal/{template/auth => authentication/template}/reset_password.templ (97%) rename internal/{template/auth => authentication/template}/sign_in_or_up.templ (94%) rename internal/{template/auth => authentication/template}/verify.templ (97%) rename internal/{template/auth => authentication/template}/verify_response.templ (97%) rename internal/{service => core}/clock.go (92%) create mode 100644 internal/core/error.go rename internal/{service => core}/mail.go (98%) rename internal/{service => core}/random_generator.go (87%) delete mode 100644 internal/service/error.go delete mode 100644 internal/template/auth/default.go delete mode 100644 internal/types/types.go delete mode 100644 mocks/default.go diff --git a/.mockery.yaml b/.mockery.yaml index 1fd18b8..cb5c70c 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -3,11 +3,11 @@ dir: mocks/ outpkg: mocks issue-845-fix: True packages: - spend-sparrow/internal/service: + spend-sparrow/internal/core: interfaces: Random: Clock: Mail: - spend-sparrow/internal/db: + spend-sparrow/internal/authentication: interfaces: - Auth: + Db: diff --git a/internal/account/service.go b/internal/account/service.go index 30feb26..f0cc2e6 100644 --- a/internal/account/service.go +++ b/internal/account/service.go @@ -4,29 +4,31 @@ import ( "context" "errors" "fmt" - "github.com/google/uuid" - "github.com/jmoiron/sqlx" "log/slog" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/core" "spend-sparrow/internal/db" "spend-sparrow/internal/service" - "spend-sparrow/internal/types" + + "github.com/google/uuid" + "github.com/jmoiron/sqlx" ) type Service interface { - Add(ctx context.Context, user *types.User, name string) (*Account, error) - UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error) - Get(ctx context.Context, user *types.User, id string) (*Account, error) - GetAll(ctx context.Context, user *types.User) ([]*Account, error) - Delete(ctx context.Context, user *types.User, id string) error + Add(ctx context.Context, user *auth_types.User, name string) (*Account, error) + UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error) + Get(ctx context.Context, user *auth_types.User, id string) (*Account, error) + GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error) + Delete(ctx context.Context, user *auth_types.User, id string) error } type ServiceImpl struct { db *sqlx.DB - clock service.Clock - random service.Random + clock core.Clock + random core.Random } -func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Service { +func NewServiceImpl(db *sqlx.DB, random core.Random, clock core.Clock) Service { return ServiceImpl{ db: db, clock: clock, @@ -34,14 +36,14 @@ func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Ser } } -func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*Account, error) { +func (s ServiceImpl) Add(ctx context.Context, user *auth_types.User, name string) (*Account, error) { if user == nil { - return nil, types.ErrUnauthorized + return nil, core.ErrUnauthorized } newId, err := s.random.UUID(ctx) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } err = service.ValidateString(name, "name") @@ -76,9 +78,9 @@ func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*A return account, nil } -func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error) { +func (s ServiceImpl) UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error) { if user == nil { - return nil, types.ErrUnauthorized + return nil, core.ErrUnauthorized } err := service.ValidateString(name, "name") if err != nil { @@ -87,7 +89,7 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string uuid, err := uuid.Parse(id) if err != nil { slog.ErrorContext(ctx, "account update", "err", err) - return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) + return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -103,10 +105,10 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError(ctx, "account Update", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { - return nil, fmt.Errorf("account %v not found: %w", id, service.ErrBadRequest) + if errors.Is(err, core.ErrNotFound) { + return nil, fmt.Errorf("account %v not found: %w", id, core.ErrBadRequest) } - return nil, types.ErrInternal + return nil, core.ErrInternal } timestamp := s.clock.Now() @@ -136,14 +138,14 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string return &account, nil } -func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Account, error) { +func (s ServiceImpl) Get(ctx context.Context, user *auth_types.User, id string) (*Account, error) { if user == nil { - return nil, service.ErrUnauthorized + return nil, core.ErrUnauthorized } uuid, err := uuid.Parse(id) if err != nil { slog.ErrorContext(ctx, "account get", "err", err) - return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) + return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } var account Account @@ -158,9 +160,9 @@ func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Acc return &account, nil } -func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account, error) { +func (s ServiceImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error) { if user == nil { - return nil, service.ErrUnauthorized + return nil, core.ErrUnauthorized } accounts := make([]*Account, 0) @@ -174,14 +176,14 @@ func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account, return accounts, nil } -func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) error { +func (s ServiceImpl) Delete(ctx context.Context, user *auth_types.User, id string) error { if user == nil { - return service.ErrUnauthorized + return core.ErrUnauthorized } uuid, err := uuid.Parse(id) if err != nil { slog.ErrorContext(ctx, "account delete", "err", err) - return fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) + return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -200,7 +202,7 @@ func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) er return err } if transactionsCount > 0 { - return fmt.Errorf("account has transactions, cannot delete: %w", service.ErrBadRequest) + return fmt.Errorf("account has transactions, cannot delete: %w", core.ErrBadRequest) } res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id) diff --git a/internal/types/auth.go b/internal/auth_types/types.go similarity index 98% rename from internal/types/auth.go rename to internal/auth_types/types.go index b91e27f..fef0477 100644 --- a/internal/types/auth.go +++ b/internal/auth_types/types.go @@ -1,4 +1,4 @@ -package types +package auth_types import ( "time" diff --git a/internal/db/auth.go b/internal/authentication/db.go similarity index 66% rename from internal/db/auth.go rename to internal/authentication/db.go index f4eb361..7dec9ee 100644 --- a/internal/db/auth.go +++ b/internal/authentication/db.go @@ -1,11 +1,12 @@ -package db +package authentication import ( "context" "database/sql" "errors" "log/slog" - "spend-sparrow/internal/types" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/core" "strings" "time" @@ -13,36 +14,36 @@ import ( "github.com/jmoiron/sqlx" ) -type Auth interface { - InsertUser(ctx context.Context, user *types.User) error - UpdateUser(ctx context.Context, user *types.User) error - GetUserByEmail(ctx context.Context, email string) (*types.User, error) - GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) +type Db interface { + InsertUser(ctx context.Context, user *auth_types.User) error + UpdateUser(ctx context.Context, user *auth_types.User) error + GetUserByEmail(ctx context.Context, email string) (*auth_types.User, error) + GetUser(ctx context.Context, userId uuid.UUID) (*auth_types.User, error) DeleteUser(ctx context.Context, userId uuid.UUID) error - InsertToken(ctx context.Context, token *types.Token) error - GetToken(ctx context.Context, token string) (*types.Token, error) - GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) - GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) + InsertToken(ctx context.Context, token *auth_types.Token) error + GetToken(ctx context.Context, token string) (*auth_types.Token, error) + GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType auth_types.TokenType) ([]*auth_types.Token, error) + GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error) DeleteToken(ctx context.Context, token string) error - InsertSession(ctx context.Context, session *types.Session) error - GetSession(ctx context.Context, sessionId string) (*types.Session, error) - GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) + InsertSession(ctx context.Context, session *auth_types.Session) error + GetSession(ctx context.Context, sessionId string) (*auth_types.Session, error) + GetSessions(ctx context.Context, userId uuid.UUID) ([]*auth_types.Session, error) DeleteSession(ctx context.Context, sessionId string) error DeleteOldSessions(ctx context.Context) error DeleteOldTokens(ctx context.Context) error } -type AuthSqlite struct { +type DbSqlite struct { db *sqlx.DB } -func NewAuthSqlite(db *sqlx.DB) *AuthSqlite { - return &AuthSqlite{db: db} +func NewDbSqlite(db *sqlx.DB) *DbSqlite { + return &DbSqlite{db: db} } -func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error { +func (db DbSqlite) InsertUser(ctx context.Context, user *auth_types.User) error { _, err := db.db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, @@ -50,17 +51,17 @@ func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error { if err != nil { if strings.Contains(err.Error(), "email") { - return ErrAlreadyExists + return core.ErrAlreadyExists } slog.ErrorContext(ctx, "SQL error InsertUser", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil } -func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error { +func (db DbSqlite) UpdateUser(ctx context.Context, user *auth_types.User) error { _, err := db.db.ExecContext(ctx, ` UPDATE user SET email_verified = ?, email_verified_at = ?, password = ? @@ -69,13 +70,13 @@ func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error { if err != nil { slog.ErrorContext(ctx, "SQL error UpdateUser", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil } -func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.User, error) { +func (db DbSqlite) GetUserByEmail(ctx context.Context, email string) (*auth_types.User, error) { var ( userId uuid.UUID emailVerified bool @@ -92,17 +93,17 @@ func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.U WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNotFound + return nil, core.ErrNotFound } else { slog.ErrorContext(ctx, "SQL error GetUser", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } } - return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil + return auth_types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil } -func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) { +func (db DbSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*auth_types.User, error) { var ( email string emailVerified bool @@ -119,92 +120,92 @@ func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { - return nil, ErrNotFound + return nil, core.ErrNotFound } else { slog.ErrorContext(ctx, "SQL error GetUser", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } } - return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil + return auth_types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil } -func (db AuthSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error { +func (db DbSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error { tx, err := db.db.BeginTx(ctx, nil) if err != nil { slog.ErrorContext(ctx, "Could not start transaction", "err", err) - return types.ErrInternal + return core.ErrInternal } _, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.ErrorContext(ctx, "Could not delete accounts", "err", err) - return types.ErrInternal + return core.ErrInternal } _, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.ErrorContext(ctx, "Could not delete user tokens", "err", err) - return types.ErrInternal + return core.ErrInternal } _, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.ErrorContext(ctx, "Could not delete sessions", "err", err) - return types.ErrInternal + return core.ErrInternal } _, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.ErrorContext(ctx, "Could not delete user", "err", err) - return types.ErrInternal + return core.ErrInternal } _, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.ErrorContext(ctx, "Could not delete user", "err", err) - return types.ErrInternal + return core.ErrInternal } _, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() slog.ErrorContext(ctx, "Could not delete user", "err", err) - return types.ErrInternal + return core.ErrInternal } err = tx.Commit() if err != nil { slog.ErrorContext(ctx, "Could not commit transaction", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil } -func (db AuthSqlite) InsertToken(ctx context.Context, token *types.Token) error { +func (db DbSqlite) InsertToken(ctx context.Context, token *auth_types.Token) error { _, err := db.db.ExecContext(ctx, ` INSERT INTO token (user_id, session_id, type, token, created_at, expires_at) VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt) if err != nil { slog.ErrorContext(ctx, "Could not insert token", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil } -func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, error) { +func (db DbSqlite) GetToken(ctx context.Context, token string) (*auth_types.Token, error) { var ( userId uuid.UUID sessionId string - tokenType types.TokenType + tokenType auth_types.TokenType createdAtStr string expiresAtStr string createdAt time.Time @@ -219,29 +220,29 @@ func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, if err != nil { if errors.Is(err, sql.ErrNoRows) { slog.InfoContext(ctx, "Token not found", "token", token) - return nil, ErrNotFound + return nil, core.ErrNotFound } else { slog.ErrorContext(ctx, "Could not get token", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } } createdAt, err = time.Parse(time.RFC3339, createdAtStr) if err != nil { slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } expiresAt, err = time.Parse(time.RFC3339, expiresAtStr) if err != nil { slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } - return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil + return auth_types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil } -func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) { +func (db DbSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType auth_types.TokenType) ([]*auth_types.Token, error) { query, err := db.db.QueryContext(ctx, ` SELECT token, created_at, expires_at FROM token @@ -250,13 +251,13 @@ func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.U if err != nil { slog.ErrorContext(ctx, "Could not get token", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } return getTokensFromQuery(ctx, query, userId, "", tokenType) } -func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) { +func (db DbSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error) { query, err := db.db.QueryContext(ctx, ` SELECT token, created_at, expires_at FROM token @@ -265,14 +266,14 @@ func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId if err != nil { slog.ErrorContext(ctx, "Could not get token", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } return getTokensFromQuery(ctx, query, uuid.Nil, sessionId, tokenType) } -func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID, sessionId string, tokenType types.TokenType) ([]*types.Token, error) { - var tokens []*types.Token +func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error) { + var tokens []*auth_types.Token hasRows := false for query.Next() { @@ -289,54 +290,54 @@ func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID, err := query.Scan(&token, &createdAtStr, &expiresAtStr) if err != nil { slog.ErrorContext(ctx, "Could not scan token", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } createdAt, err = time.Parse(time.RFC3339, createdAtStr) if err != nil { slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } expiresAt, err = time.Parse(time.RFC3339, expiresAtStr) if err != nil { slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } - tokens = append(tokens, types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt)) + tokens = append(tokens, auth_types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt)) } if !hasRows { - return nil, ErrNotFound + return nil, core.ErrNotFound } return tokens, nil } -func (db AuthSqlite) DeleteToken(ctx context.Context, token string) error { +func (db DbSqlite) DeleteToken(ctx context.Context, token string) error { _, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token) if err != nil { slog.ErrorContext(ctx, "Could not delete token", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil } -func (db AuthSqlite) InsertSession(ctx context.Context, session *types.Session) error { +func (db DbSqlite) InsertSession(ctx context.Context, session *auth_types.Session) error { _, err := db.db.ExecContext(ctx, ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt) if err != nil { slog.ErrorContext(ctx, "Could not insert new session", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil } -func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.Session, error) { +func (db DbSqlite) GetSession(ctx context.Context, sessionId string) (*auth_types.Session, error) { var ( userId uuid.UUID createdAt time.Time @@ -350,56 +351,56 @@ func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.S if err != nil { slog.WarnContext(ctx, "Session not found", "session-id", sessionId, "err", err) - return nil, ErrNotFound + return nil, core.ErrNotFound } - return types.NewSession(sessionId, userId, createdAt, expiresAt), nil + return auth_types.NewSession(sessionId, userId, createdAt, expiresAt), nil } -func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) { - var sessions []*types.Session +func (db DbSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*auth_types.Session, error) { + var sessions []*auth_types.Session err := db.db.SelectContext(ctx, &sessions, ` SELECT * FROM session WHERE user_id = ?`, userId) if err != nil { slog.ErrorContext(ctx, "Could not get sessions", "err", err) - return nil, types.ErrInternal + return nil, core.ErrInternal } return sessions, nil } -func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error { +func (db DbSqlite) DeleteSession(ctx context.Context, sessionId string) error { if sessionId != "" { _, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId) if err != nil { slog.ErrorContext(ctx, "Could not delete session", "err", err) - return types.ErrInternal + return core.ErrInternal } } return nil } -func (db AuthSqlite) DeleteOldSessions(ctx context.Context) error { +func (db DbSqlite) DeleteOldSessions(ctx context.Context) error { _, err := db.db.ExecContext(ctx, ` DELETE FROM session WHERE expires_at < datetime('now')`) if err != nil { slog.ErrorContext(ctx, "Could not delete old sessions", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil } -func (db AuthSqlite) DeleteOldTokens(ctx context.Context) error { +func (db DbSqlite) DeleteOldTokens(ctx context.Context) error { _, err := db.db.ExecContext(ctx, ` DELETE FROM token WHERE expires_at < datetime('now')`) if err != nil { slog.ErrorContext(ctx, "Could not delete old tokens", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil } diff --git a/internal/handler/auth.go b/internal/authentication/handler.go similarity index 81% rename from internal/handler/auth.go rename to internal/authentication/handler.go index 7370b98..6cf32f6 100644 --- a/internal/handler/auth.go +++ b/internal/authentication/handler.go @@ -1,36 +1,34 @@ -package handler +package authentication import ( "errors" "log/slog" "net/http" "net/url" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/authentication/template" "spend-sparrow/internal/core" - "spend-sparrow/internal/handler/middleware" - "spend-sparrow/internal/service" - "spend-sparrow/internal/template/auth" - "spend-sparrow/internal/types" "spend-sparrow/internal/utils" "time" ) -type Auth interface { +type Handler interface { Handle(router *http.ServeMux) } -type AuthImpl struct { - service service.Auth +type HandlerImpl struct { + service Service render *core.Render } -func NewAuth(service service.Auth, render *core.Render) Auth { - return AuthImpl{ +func NewHandler(service Service, render *core.Render) Handler { + return HandlerImpl{ service: service, render: render, } } -func (handler AuthImpl) Handle(router *http.ServeMux) { +func (handler HandlerImpl) Handle(router *http.ServeMux) { router.Handle("GET /auth/signin", handler.handleSignInPage()) router.Handle("POST /api/auth/signin", handler.handleSignIn()) @@ -57,7 +55,7 @@ var ( securityWaitDuration = 250 * time.Millisecond ) -func (handler AuthImpl) handleSignInPage() http.HandlerFunc { +func (handler HandlerImpl) handleSignInPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -71,17 +69,17 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc { return } - comp := auth.SignInOrUpComp(true) + comp := template.SignInOrUpComp(true) handler.render.RenderLayout(r, w, comp, nil) } } -func (handler AuthImpl) handleSignIn() http.HandlerFunc { +func (handler HandlerImpl) handleSignIn() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) - user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) { + user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*auth_types.User, error) { session := core.GetSession(r) email := r.FormValue("email") password := r.FormValue("password") @@ -91,14 +89,14 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc { return nil, err } - cookie := middleware.CreateSessionCookie(session.Id) + cookie := core.CreateSessionCookie(session.Id) http.SetCookie(w, &cookie) return user, nil }) if err != nil { - if errors.Is(err, service.ErrInvalidCredentials) { + if errors.Is(err, ErrInvalidCredentials) { utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Invalid email or password", http.StatusUnauthorized) } else { utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError) @@ -114,7 +112,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc { } } -func (handler AuthImpl) handleSignUpPage() http.HandlerFunc { +func (handler HandlerImpl) handleSignUpPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -129,12 +127,12 @@ func (handler AuthImpl) handleSignUpPage() http.HandlerFunc { return } - signUpComp := auth.SignInOrUpComp(false) + signUpComp := template.SignInOrUpComp(false) handler.render.RenderLayout(r, w, signUpComp, nil) } } -func (handler AuthImpl) handleSignUpVerifyPage() http.HandlerFunc { +func (handler HandlerImpl) handleSignUpVerifyPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -149,12 +147,12 @@ func (handler AuthImpl) handleSignUpVerifyPage() http.HandlerFunc { return } - signIn := auth.VerifyComp() + signIn := template.VerifyComp() handler.render.RenderLayout(r, w, signIn, user) } } -func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc { +func (handler HandlerImpl) handleVerifyResendComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -173,7 +171,7 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc { } } -func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc { +func (handler HandlerImpl) handleSignUpVerifyResponsePage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -182,7 +180,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc { err := handler.service.VerifyUserEmail(r.Context(), token) isVerified := err == nil - comp := auth.VerifyResponseComp(isVerified) + comp := template.VerifyResponseComp(isVerified) var status int if isVerified { @@ -195,7 +193,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc { } } -func (handler AuthImpl) handleSignUp() http.HandlerFunc { +func (handler HandlerImpl) handleSignUp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -216,14 +214,14 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc { if err != nil { switch { - case errors.Is(err, types.ErrInternal): + case errors.Is(err, core.ErrInternal): utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError) return - case errors.Is(err, service.ErrInvalidEmail): + case errors.Is(err, ErrInvalidEmail): utils.TriggerToastWithStatus(r.Context(), w, r, "error", "The email provided is invalid", http.StatusBadRequest) return - case errors.Is(err, service.ErrInvalidPassword): - utils.TriggerToastWithStatus(r.Context(), w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest) + case errors.Is(err, ErrInvalidPassword): + utils.TriggerToastWithStatus(r.Context(), w, r, "error", ErrInvalidPassword.Error(), http.StatusBadRequest) return } // If err is "service.ErrAccountExists", then just continue @@ -233,7 +231,7 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc { } } -func (handler AuthImpl) handleSignOut() http.HandlerFunc { +func (handler HandlerImpl) handleSignOut() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -262,7 +260,7 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc { } } -func (handler AuthImpl) handleDeleteAccountPage() http.HandlerFunc { +func (handler HandlerImpl) handleDeleteAccountPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -272,12 +270,12 @@ func (handler AuthImpl) handleDeleteAccountPage() http.HandlerFunc { return } - comp := auth.DeleteAccountComp() + comp := template.DeleteAccountComp() handler.render.RenderLayout(r, w, comp, user) } } -func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { +func (handler HandlerImpl) handleDeleteAccountComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -291,7 +289,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { err := handler.service.DeleteAccount(r.Context(), user, password) if err != nil { - if errors.Is(err, service.ErrInvalidCredentials) { + if errors.Is(err, ErrInvalidCredentials) { utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Password not correct", http.StatusBadRequest) } else { utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Internal Server Error", http.StatusInternalServerError) @@ -303,7 +301,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { } } -func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc { +func (handler HandlerImpl) handleChangePasswordPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -316,12 +314,12 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc { return } - comp := auth.ChangePasswordComp(isPasswordReset) + comp := template.ChangePasswordComp(isPasswordReset) handler.render.RenderLayout(r, w, comp, user) } } -func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { +func (handler HandlerImpl) handleChangePasswordComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -345,7 +343,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { } } -func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc { +func (handler HandlerImpl) handleForgotPasswordPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -355,12 +353,12 @@ func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc { return } - comp := auth.ResetPasswordComp() + comp := template.ResetPasswordComp() handler.render.RenderLayout(r, w, comp, user) } } -func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { +func (handler HandlerImpl) handleForgotPasswordComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) @@ -383,7 +381,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { } } -func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { +func (handler HandlerImpl) handleForgotPasswordResponseComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { core.UpdateSpan(r) diff --git a/internal/service/auth.go b/internal/authentication/service.go similarity index 65% rename from internal/service/auth.go rename to internal/authentication/service.go index 06e782a..64424fd 100644 --- a/internal/service/auth.go +++ b/internal/authentication/service.go @@ -1,4 +1,4 @@ -package service +package authentication import ( "context" @@ -6,7 +6,8 @@ import ( "errors" "log/slog" "net/mail" - "spend-sparrow/internal/db" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/core" mailTemplate "spend-sparrow/internal/template/mail" "spend-sparrow/internal/types" "strings" @@ -25,39 +26,39 @@ var ( ErrTokenInvalid = errors.New("token is invalid") ) -type Auth interface { - SignUp(ctx context.Context, email string, password string) (*types.User, error) +type Service interface { + SignUp(ctx context.Context, email string, password string) (*auth_types.User, error) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) VerifyUserEmail(ctx context.Context, token string) error - SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) - SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) - SignInAnonymous(ctx context.Context) (*types.Session, error) + SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_types.Session, *auth_types.User, error) + SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error) + SignInAnonymous(ctx context.Context) (*auth_types.Session, error) SignOut(ctx context.Context, sessionId string) error - DeleteAccount(ctx context.Context, user *types.User, currPass string) error + DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error - ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error + ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error SendForgotPasswordMail(ctx context.Context, email string) error ForgotPassword(ctx context.Context, token string, newPass string) error IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool - GetCsrfToken(ctx context.Context, session *types.Session) (string, error) + GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error) CleanupSessionsAndTokens(ctx context.Context) error } -type AuthImpl struct { - db db.Auth - random Random - clock Clock - mail Mail +type ServiceImpl struct { + db Db + random core.Random + clock core.Clock + mail core.Mail serverSettings *types.Settings } -func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *types.Settings) *AuthImpl { - return &AuthImpl{ +func NewService(db Db, random core.Random, clock core.Clock, mail core.Mail, serverSettings *types.Settings) *ServiceImpl { + return &ServiceImpl{ db: db, random: random, clock: clock, @@ -66,13 +67,13 @@ func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings * } } -func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) { +func (service ServiceImpl) SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_types.Session, *auth_types.User, error) { user, err := service.db.GetUserByEmail(ctx, email) if err != nil { - if errors.Is(err, db.ErrNotFound) { + if errors.Is(err, core.ErrNotFound) { return nil, nil, ErrInvalidCredentials } else { - return nil, nil, types.ErrInternal + return nil, nil, core.ErrInternal } } @@ -84,36 +85,36 @@ func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, emai newSession, err := service.createSession(ctx, user.Id) if err != nil { - return nil, nil, types.ErrInternal + return nil, nil, core.ErrInternal } err = service.db.DeleteSession(ctx, session.Id) if err != nil { - return nil, nil, types.ErrInternal + return nil, nil, core.ErrInternal } - tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) + tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf) if err != nil { - return nil, nil, types.ErrInternal + return nil, nil, core.ErrInternal } for _, token := range tokens { err = service.db.DeleteToken(ctx, token.Token) if err != nil { - return nil, nil, types.ErrInternal + return nil, nil, core.ErrInternal } } return newSession, user, nil } -func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) { +func (service ServiceImpl) SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error) { if sessionId == "" { return nil, nil, ErrSessionIdInvalid } session, err := service.db.GetSession(ctx, sessionId) if err != nil { - return nil, nil, types.ErrInternal + return nil, nil, core.ErrInternal } if session.ExpiresAt.Before(service.clock.Now()) { _ = service.db.DeleteSession(ctx, sessionId) @@ -126,16 +127,16 @@ func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*t user, err := service.db.GetUser(ctx, session.UserId) if err != nil { - return nil, nil, types.ErrInternal + return nil, nil, core.ErrInternal } return session, user, nil } -func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, error) { +func (service ServiceImpl) SignInAnonymous(ctx context.Context) (*auth_types.Session, error) { session, err := service.createSession(ctx, uuid.Nil) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } slog.InfoContext(ctx, "anonymous session created", "session-id", session.Id) @@ -143,7 +144,7 @@ func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, er return session, nil } -func (service AuthImpl) SignUp(ctx context.Context, email string, password string) (*types.User, error) { +func (service ServiceImpl) SignUp(ctx context.Context, email string, password string) (*auth_types.User, error) { _, err := mail.ParseAddress(email) if err != nil { return nil, ErrInvalidEmail @@ -155,37 +156,37 @@ func (service AuthImpl) SignUp(ctx context.Context, email string, password strin userId, err := service.random.UUID(ctx) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } salt, err := service.random.Bytes(ctx, 16) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } hash := GetHashPassword(password, salt) - user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) + user := auth_types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) err = service.db.InsertUser(ctx, user) if err != nil { - if errors.Is(err, db.ErrAlreadyExists) { + if errors.Is(err, core.ErrAlreadyExists) { return nil, ErrAccountExists } else { - return nil, types.ErrInternal + return nil, core.ErrInternal } } return user, nil } -func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) { - tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify) - if err != nil && !errors.Is(err, db.ErrNotFound) { +func (service ServiceImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) { + tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, auth_types.TokenTypeEmailVerify) + if err != nil && !errors.Is(err, core.ErrNotFound) { return } - var token *types.Token + var token *auth_types.Token if len(tokens) > 0 { token = tokens[0] @@ -197,11 +198,11 @@ func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UU return } - token = types.NewToken( + token = auth_types.NewToken( userId, "", newTokenStr, - types.TokenTypeEmailVerify, + auth_types.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) @@ -221,29 +222,29 @@ func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UU service.mail.SendMail(ctx, email, "Welcome to spend-sparrow", w.String()) } -func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error { +func (service ServiceImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error { if tokenStr == "" { - return types.ErrInternal + return core.ErrInternal } token, err := service.db.GetToken(ctx, tokenStr) if err != nil { - return types.ErrInternal + return core.ErrInternal } user, err := service.db.GetUser(ctx, token.UserId) if err != nil { - return types.ErrInternal + return core.ErrInternal } - if token.Type != types.TokenTypeEmailVerify { - return types.ErrInternal + if token.Type != auth_types.TokenTypeEmailVerify { + return core.ErrInternal } now := service.clock.Now() if token.ExpiresAt.Before(now) { - return types.ErrInternal + return core.ErrInternal } user.EmailVerified = true @@ -251,21 +252,21 @@ func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) er err = service.db.UpdateUser(ctx, user) if err != nil { - return types.ErrInternal + return core.ErrInternal } _ = service.db.DeleteToken(ctx, token.Token) return nil } -func (service AuthImpl) SignOut(ctx context.Context, sessionId string) error { +func (service ServiceImpl) SignOut(ctx context.Context, sessionId string) error { return service.db.DeleteSession(ctx, sessionId) } -func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, currPass string) error { +func (service ServiceImpl) DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error { userDb, err := service.db.GetUser(ctx, user.Id) if err != nil { - return types.ErrInternal + return core.ErrInternal } currHash := GetHashPassword(currPass, userDb.Salt) @@ -283,7 +284,7 @@ func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, cur return nil } -func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error { +func (service ServiceImpl) ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error { if !isPasswordValid(newPass) { return ErrInvalidPassword } @@ -308,13 +309,13 @@ func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, se sessions, err := service.db.GetSessions(ctx, user.Id) if err != nil { - return types.ErrInternal + return core.ErrInternal } for _, s := range sessions { if s.Id != sessionId { err = service.db.DeleteSession(ctx, s.Id) if err != nil { - return types.ErrInternal + return core.ErrInternal } } } @@ -322,7 +323,7 @@ func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, se return nil } -func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string) error { +func (service ServiceImpl) SendForgotPasswordMail(ctx context.Context, email string) error { tokenStr, err := service.random.String(ctx, 32) if err != nil { return err @@ -330,38 +331,38 @@ func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string user, err := service.db.GetUserByEmail(ctx, email) if err != nil { - if errors.Is(err, db.ErrNotFound) { + if errors.Is(err, core.ErrNotFound) { return nil } else { - return types.ErrInternal + return core.ErrInternal } } - token := types.NewToken( + token := auth_types.NewToken( user.Id, "", tokenStr, - types.TokenTypePasswordReset, + auth_types.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute)) err = service.db.InsertToken(ctx, token) if err != nil { - return types.ErrInternal + return core.ErrInternal } var mail strings.Builder err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail) if err != nil { slog.ErrorContext(ctx, "Could not render reset password email", "err", err) - return types.ErrInternal + return core.ErrInternal } service.mail.SendMail(ctx, email, "Reset Password", mail.String()) return nil } -func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error { +func (service ServiceImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error { if !isPasswordValid(newPass) { return ErrInvalidPassword } @@ -376,7 +377,7 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new return err } - if token.Type != types.TokenTypePasswordReset || + if token.Type != auth_types.TokenTypePasswordReset || token.ExpiresAt.Before(service.clock.Now()) { return ErrTokenInvalid } @@ -384,7 +385,7 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new user, err := service.db.GetUser(ctx, token.UserId) if err != nil { slog.ErrorContext(ctx, "Could not get user from token", "err", err) - return types.ErrInternal + return core.ErrInternal } passHash := GetHashPassword(newPass, user.Salt) @@ -397,26 +398,26 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new sessions, err := service.db.GetSessions(ctx, user.Id) if err != nil { - return types.ErrInternal + return core.ErrInternal } for _, session := range sessions { err = service.db.DeleteSession(ctx, session.Id) if err != nil { - return types.ErrInternal + return core.ErrInternal } } return nil } -func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool { +func (service ServiceImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool { token, err := service.db.GetToken(ctx, tokenStr) if err != nil { return false } - if token.Type != types.TokenTypeCsrf || + if token.Type != auth_types.TokenTypeCsrf || token.SessionId != sessionId || token.ExpiresAt.Before(service.clock.Now()) { return false @@ -425,12 +426,12 @@ func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, s return true } -func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session) (string, error) { +func (service ServiceImpl) GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error) { if session == nil { - return "", types.ErrInternal + return "", core.ErrInternal } - tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) + tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf) if len(tokens) > 0 { return tokens[0].Token, nil @@ -438,19 +439,19 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session tokenStr, err := service.random.String(ctx, 32) if err != nil { - return "", types.ErrInternal + return "", core.ErrInternal } - token := types.NewToken( + token := auth_types.NewToken( session.UserId, session.Id, tokenStr, - types.TokenTypeCsrf, + auth_types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(8*time.Hour)) err = service.db.InsertToken(ctx, token) if err != nil { - return "", types.ErrInternal + return "", core.ErrInternal } slog.InfoContext(ctx, "CSRF-Token created", "token", tokenStr) @@ -458,34 +459,34 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session return tokenStr, nil } -func (service AuthImpl) CleanupSessionsAndTokens(ctx context.Context) error { +func (service ServiceImpl) CleanupSessionsAndTokens(ctx context.Context) error { err := service.db.DeleteOldSessions(ctx) if err != nil { - return types.ErrInternal + return core.ErrInternal } err = service.db.DeleteOldTokens(ctx) if err != nil { - return types.ErrInternal + return core.ErrInternal } return nil } -func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*types.Session, error) { +func (service ServiceImpl) createSession(ctx context.Context, userId uuid.UUID) (*auth_types.Session, error) { sessionId, err := service.random.String(ctx, 32) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } createAt := service.clock.Now() expiresAt := createAt.Add(24 * time.Hour) - session := types.NewSession(sessionId, userId, createAt, expiresAt) + session := auth_types.NewSession(sessionId, userId, createAt, expiresAt) err = service.db.InsertSession(ctx, session) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } return session, nil diff --git a/internal/template/auth/change_password.templ b/internal/authentication/template/change_password.templ similarity index 98% rename from internal/template/auth/change_password.templ rename to internal/authentication/template/change_password.templ index 9e761fb..db3351d 100644 --- a/internal/template/auth/change_password.templ +++ b/internal/authentication/template/change_password.templ @@ -1,4 +1,4 @@ -package auth +package template templ ChangePasswordComp(isPasswordReset bool) {
diff --git a/internal/template/auth/verify_response.templ b/internal/authentication/template/verify_response.templ similarity index 97% rename from internal/template/auth/verify_response.templ rename to internal/authentication/template/verify_response.templ index 1b96e55..0a112f4 100644 --- a/internal/template/auth/verify_response.templ +++ b/internal/authentication/template/verify_response.templ @@ -1,4 +1,4 @@ -package auth +package template templ VerifyResponseComp(isVerified bool) {
diff --git a/internal/core/auth.go b/internal/core/auth.go index f51574d..151e9aa 100644 --- a/internal/core/auth.go +++ b/internal/core/auth.go @@ -2,7 +2,7 @@ package core import ( "net/http" - "spend-sparrow/internal/types" + "spend-sparrow/internal/auth_types" ) type ContextKey string @@ -10,13 +10,13 @@ type ContextKey string var SessionKey ContextKey = "session" var UserKey ContextKey = "user" -func GetUser(r *http.Request) *types.User { +func GetUser(r *http.Request) *auth_types.User { obj := r.Context().Value(UserKey) if obj == nil { return nil } - user, ok := obj.(*types.User) + user, ok := obj.(*auth_types.User) if !ok { return nil } @@ -24,16 +24,28 @@ func GetUser(r *http.Request) *types.User { return user } -func GetSession(r *http.Request) *types.Session { +func GetSession(r *http.Request) *auth_types.Session { obj := r.Context().Value(SessionKey) if obj == nil { return nil } - session, ok := obj.(*types.Session) + session, ok := obj.(*auth_types.Session) if !ok { return nil } return session } + +func CreateSessionCookie(sessionId string) http.Cookie { + return http.Cookie{ + Name: "id", + Value: sessionId, + MaxAge: 60 * 60 * 8, // 8 hours + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Path: "/", + } +} diff --git a/internal/service/clock.go b/internal/core/clock.go similarity index 92% rename from internal/service/clock.go rename to internal/core/clock.go index 8944dc7..4e4ea67 100644 --- a/internal/service/clock.go +++ b/internal/core/clock.go @@ -1,4 +1,4 @@ -package service +package core import "time" diff --git a/internal/core/default.go b/internal/core/default.go index 5b8a1ab..2830102 100644 --- a/internal/core/default.go +++ b/internal/core/default.go @@ -3,8 +3,6 @@ package core import ( "errors" "net/http" - "spend-sparrow/internal/db" - "spend-sparrow/internal/service" "spend-sparrow/internal/utils" "strings" @@ -14,13 +12,13 @@ import ( func HandleError(w http.ResponseWriter, r *http.Request, err error) { switch { - case errors.Is(err, service.ErrUnauthorized): + case errors.Is(err, ErrUnauthorized): utils.TriggerToastWithStatus(r.Context(), w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized) return - case errors.Is(err, service.ErrBadRequest): + case errors.Is(err, ErrBadRequest): utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusBadRequest) return - case errors.Is(err, db.ErrNotFound): + case errors.Is(err, ErrNotFound): utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusNotFound) return } diff --git a/internal/core/error.go b/internal/core/error.go new file mode 100644 index 0000000..854e3ae --- /dev/null +++ b/internal/core/error.go @@ -0,0 +1,13 @@ +package core + +import "errors" + +var ( + ErrNotFound = errors.New("the value does not exist") + ErrAlreadyExists = errors.New("row already exists") + + ErrInternal = errors.New("internal server error") + ErrUnauthorized = errors.New("you are not authorized to perform this action") + + ErrBadRequest = errors.New("bad request") +) diff --git a/internal/service/mail.go b/internal/core/mail.go similarity index 98% rename from internal/service/mail.go rename to internal/core/mail.go index 9a0edd7..9bed681 100644 --- a/internal/service/mail.go +++ b/internal/core/mail.go @@ -1,4 +1,4 @@ -package service +package core import ( "context" diff --git a/internal/service/random_generator.go b/internal/core/random_generator.go similarity index 87% rename from internal/service/random_generator.go rename to internal/core/random_generator.go index 331afdc..143912f 100644 --- a/internal/service/random_generator.go +++ b/internal/core/random_generator.go @@ -1,11 +1,10 @@ -package service +package core import ( "context" "crypto/rand" "encoding/base64" "log/slog" - "spend-sparrow/internal/types" "github.com/google/uuid" ) @@ -28,7 +27,7 @@ func (r *RandomImpl) Bytes(ctx context.Context, tsize int) ([]byte, error) { _, err := rand.Read(b) if err != nil { slog.ErrorContext(ctx, "Error generating random bytes", "err", err) - return []byte{}, types.ErrInternal + return []byte{}, ErrInternal } return b, nil @@ -38,7 +37,7 @@ func (r *RandomImpl) String(ctx context.Context, size int) (string, error) { bytes, err := r.Bytes(ctx, size) if err != nil { slog.ErrorContext(ctx, "Error generating random string", "err", err) - return "", types.ErrInternal + return "", ErrInternal } return base64.StdEncoding.EncodeToString(bytes), nil @@ -48,7 +47,7 @@ func (r *RandomImpl) UUID(ctx context.Context) (uuid.UUID, error) { id, err := uuid.NewRandom() if err != nil { slog.ErrorContext(ctx, "Error generating random UUID", "err", err) - return uuid.Nil, types.ErrInternal + return uuid.Nil, ErrInternal } return id, nil diff --git a/internal/core/render.go b/internal/core/render.go index a380a49..5fab9f2 100644 --- a/internal/core/render.go +++ b/internal/core/render.go @@ -1,10 +1,11 @@ package core import ( - "github.com/a-h/templ" "log/slog" "net/http" - "spend-sparrow/internal/types" + "spend-sparrow/internal/auth_types" + + "github.com/a-h/templ" ) type Render struct { @@ -28,18 +29,18 @@ func (render *Render) Render(r *http.Request, w http.ResponseWriter, comp templ. render.RenderWithStatus(r, w, comp, http.StatusOK) } -func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *types.User) { +func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *auth_types.User) { render.RenderLayoutWithStatus(r, w, slot, user, http.StatusOK) } -func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWriter, slot templ.Component, user *types.User, status int) { +func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWriter, slot templ.Component, user *auth_types.User, status int) { userComp := render.getUserComp(user) layout := Layout(slot, userComp, user != nil, r.URL.Path) render.RenderWithStatus(r, w, layout, status) } -func (render *Render) getUserComp(user *types.User) templ.Component { +func (render *Render) getUserComp(user *auth_types.User) templ.Component { if user != nil { return UserComp(user.Email) } else { diff --git a/internal/db/error.go b/internal/db/error.go index f81fe7e..3bc2a0d 100644 --- a/internal/db/error.go +++ b/internal/db/error.go @@ -5,33 +5,28 @@ import ( "database/sql" "errors" "log/slog" - "spend-sparrow/internal/types" -) - -var ( - ErrNotFound = errors.New("the value does not exist") - ErrAlreadyExists = errors.New("row already exists") + "spend-sparrow/internal/core" ) func TransformAndLogDbError(ctx context.Context, module string, r sql.Result, err error) error { if err != nil { if errors.Is(err, sql.ErrNoRows) { - return ErrNotFound + return core.ErrNotFound } slog.ErrorContext(ctx, "database sql", "module", module, "err", err) - return types.ErrInternal + return core.ErrInternal } if r != nil { rows, err := r.RowsAffected() if err != nil { slog.ErrorContext(ctx, "database rows affected", "module", module, "err", err) - return types.ErrInternal + return core.ErrInternal } if rows == 0 { slog.InfoContext(ctx, "row not found", "module", module) - return ErrNotFound + return core.ErrNotFound } } diff --git a/internal/db/migration.go b/internal/db/migration.go index 76ed2d9..3040b93 100644 --- a/internal/db/migration.go +++ b/internal/db/migration.go @@ -4,7 +4,7 @@ import ( "context" "errors" "log/slog" - "spend-sparrow/internal/types" + "spend-sparrow/internal/core" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/sqlite3" @@ -25,7 +25,7 @@ func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error { driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) if err != nil { slog.ErrorContext(ctx, "Could not create Migration instance", "err", err) - return types.ErrInternal + return core.ErrInternal } m, err := migrate.NewWithDatabaseInstance( @@ -34,14 +34,14 @@ func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error { driver) if err != nil { slog.ErrorContext(ctx, "Could not create migrations instance", "err", err) - return types.ErrInternal + return core.ErrInternal } m.Log = migrationLogger{} if err = m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { slog.ErrorContext(ctx, "Could not run migrations", "err", err) - return types.ErrInternal + return core.ErrInternal } return nil diff --git a/internal/default.go b/internal/default.go index 277c75d..4cd63d1 100644 --- a/internal/default.go +++ b/internal/default.go @@ -8,6 +8,7 @@ import ( "net/http" "os/signal" "spend-sparrow/internal/account" + "spend-sparrow/internal/authentication" "spend-sparrow/internal/core" "spend-sparrow/internal/db" "spend-sparrow/internal/handler" @@ -107,13 +108,13 @@ func shutdownServer(ctx context.Context, s *http.Server, wg *sync.WaitGroup) { func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler { var router = http.NewServeMux() - authDb := db.NewAuthSqlite(d) + authDb := authentication.NewDbSqlite(d) - randomService := service.NewRandom() - clockService := service.NewClock() - mailService := service.NewMail(serverSettings) + randomService := core.NewRandom() + clockService := core.NewClock() + mailService := core.NewMail(serverSettings) - authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings) + authService := authentication.NewService(authDb, randomService, clockService, mailService, serverSettings) accountService := account.NewServiceImpl(d, randomService, clockService) treasureChestService := service.NewTreasureChest(d, randomService, clockService) transactionService := service.NewTransaction(d, randomService, clockService) @@ -123,7 +124,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings * render := core.NewRender() indexHandler := handler.NewIndex(render, clockService) dashboardHandler := handler.NewDashboard(render, dashboardService, treasureChestService) - authHandler := handler.NewAuth(authService, render) + authHandler := authentication.NewHandler(authService, render) accountHandler := account.NewHandler(accountService, render) treasureChestHandler := handler.NewTreasureChest(treasureChestService, transactionRecurringService, render) transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render) @@ -157,7 +158,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings * return wrapper } -func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) { +func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) { runDailyTasks(ctx, transactionRecurring, auth) ticker := time.NewTicker(24 * time.Hour) defer ticker.Stop() @@ -172,7 +173,7 @@ func dailyTaskTimer(ctx context.Context, transactionRecurring service.Transactio } } -func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) { +func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) { slog.InfoContext(ctx, "Running daily tasks") _ = transactionRecurring.GenerateTransactions(ctx) _ = auth.CleanupSessionsAndTokens(ctx) diff --git a/internal/handler/dashboard.go b/internal/handler/dashboard.go index 8bc4525..1e14b7b 100644 --- a/internal/handler/dashboard.go +++ b/internal/handler/dashboard.go @@ -193,7 +193,7 @@ func (handler DashboardImpl) handleDashboardTreasureChest() http.HandlerFunc { if treasureChestStr != "" { id, err := uuid.Parse(treasureChestStr) if err != nil { - core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", service.ErrBadRequest)) + core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", core.ErrBadRequest)) return } diff --git a/internal/handler/middleware/authenticate.go b/internal/handler/middleware/authenticate.go index d20a0b7..4ba51d9 100644 --- a/internal/handler/middleware/authenticate.go +++ b/internal/handler/middleware/authenticate.go @@ -3,12 +3,12 @@ package middleware import ( "context" "net/http" + "spend-sparrow/internal/authentication" "spend-sparrow/internal/core" - "spend-sparrow/internal/service" "strings" ) -func Authenticate(service service.Auth) func(http.Handler) http.Handler { +func Authenticate(service authentication.Service) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -31,7 +31,7 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler { return } - cookie := CreateSessionCookie(session.Id) + cookie := core.CreateSessionCookie(session.Id) http.SetCookie(w, &cookie) } diff --git a/internal/handler/middleware/cross_site_request_forgery.go b/internal/handler/middleware/cross_site_request_forgery.go index 4601b56..407ab2c 100644 --- a/internal/handler/middleware/cross_site_request_forgery.go +++ b/internal/handler/middleware/cross_site_request_forgery.go @@ -3,8 +3,8 @@ package middleware import ( "log/slog" "net/http" + "spend-sparrow/internal/authentication" "spend-sparrow/internal/core" - "spend-sparrow/internal/service" "spend-sparrow/internal/utils" "strings" ) @@ -31,7 +31,7 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) { return rr.ResponseWriter.Write([]byte(dataStr)) } -func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { +func CrossSiteRequestForgery(auth authentication.Service) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/internal/handler/middleware/default.go b/internal/handler/middleware/default.go index 0146fb5..c870d7c 100644 --- a/internal/handler/middleware/default.go +++ b/internal/handler/middleware/default.go @@ -1,15 +1 @@ package middleware - -import "net/http" - -func CreateSessionCookie(sessionId string) http.Cookie { - return http.Cookie{ - Name: "id", - Value: sessionId, - MaxAge: 60 * 60 * 8, // 8 hours - Secure: true, - HttpOnly: true, - SameSite: http.SameSiteStrictMode, - Path: "/", - } -} diff --git a/internal/handler/root_and_404.go b/internal/handler/root_and_404.go index 19ad051..e09b6e4 100644 --- a/internal/handler/root_and_404.go +++ b/internal/handler/root_and_404.go @@ -3,7 +3,6 @@ package handler import ( "net/http" "spend-sparrow/internal/core" - "spend-sparrow/internal/service" "spend-sparrow/internal/template" "spend-sparrow/internal/utils" @@ -16,10 +15,10 @@ type Index interface { type IndexImpl struct { r *core.Render - c service.Clock + c core.Clock } -func NewIndex(r *core.Render, c service.Clock) Index { +func NewIndex(r *core.Render, c core.Clock) Index { return IndexImpl{ r: r, c: c, diff --git a/internal/handler/transaction.go b/internal/handler/transaction.go index bae53ca..5e330fa 100644 --- a/internal/handler/transaction.go +++ b/internal/handler/transaction.go @@ -157,7 +157,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc { if idStr != "new" { id, err = uuid.Parse(idStr) if err != nil { - core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)) + core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)) return } } @@ -167,7 +167,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc { if accountIdStr != "" { i, err := uuid.Parse(accountIdStr) if err != nil { - core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", service.ErrBadRequest)) + core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", core.ErrBadRequest)) return } accountId = &i @@ -178,7 +178,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc { if treasureChestIdStr != "" { i, err := uuid.Parse(treasureChestIdStr) if err != nil { - core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", service.ErrBadRequest)) + core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", core.ErrBadRequest)) return } treasureChestId = &i @@ -186,14 +186,14 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc { valueF, err := strconv.ParseFloat(r.FormValue("value"), 64) if err != nil { - core.HandleError(w, r, fmt.Errorf("could not parse value: %w", service.ErrBadRequest)) + core.HandleError(w, r, fmt.Errorf("could not parse value: %w", core.ErrBadRequest)) return } value := int64(math.Round(valueF * service.DECIMALS_MULTIPLIER)) timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp")) if err != nil { - core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", service.ErrBadRequest)) + core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest)) return } diff --git a/internal/handler/transaction_recurring.go b/internal/handler/transaction_recurring.go index 33f8a39..615d443 100644 --- a/internal/handler/transaction_recurring.go +++ b/internal/handler/transaction_recurring.go @@ -2,6 +2,7 @@ package handler import ( "net/http" + "spend-sparrow/internal/auth_types" "spend-sparrow/internal/core" "spend-sparrow/internal/service" t "spend-sparrow/internal/template/transaction_recurring" @@ -111,7 +112,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle } } -func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) { +func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *auth_types.User, id, accountId, treasureChestId string) { var transactionsRecurring []*types.TransactionRecurring var err error if accountId == "" && treasureChestId == "" { diff --git a/internal/service/dashboard.go b/internal/service/dashboard.go index 13235d4..187efa9 100644 --- a/internal/service/dashboard.go +++ b/internal/service/dashboard.go @@ -2,6 +2,8 @@ package service import ( "context" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/core" "spend-sparrow/internal/db" "spend-sparrow/internal/types" "time" @@ -22,10 +24,10 @@ func NewDashboard(db *sqlx.DB) *Dashboard { func (s Dashboard) MainChart( ctx context.Context, - user *types.User, + user *auth_types.User, ) ([]types.DashboardMainChartEntry, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } transactions := make([]types.Transaction, 0) @@ -82,10 +84,10 @@ func (s Dashboard) MainChart( func (s Dashboard) TreasureChests( ctx context.Context, - user *types.User, + user *auth_types.User, ) ([]*types.DashboardTreasureChest, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } treasureChests := make([]*types.TreasureChest, 0) @@ -120,11 +122,11 @@ func (s Dashboard) TreasureChests( func (s Dashboard) TreasureChest( ctx context.Context, - user *types.User, + user *auth_types.User, treausureChestId *uuid.UUID, ) ([]types.DashboardMainChartEntry, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } transactions := make([]types.Transaction, 0) diff --git a/internal/service/default.go b/internal/service/default.go index 72d8261..b92dd39 100644 --- a/internal/service/default.go +++ b/internal/service/default.go @@ -3,6 +3,7 @@ package service import ( "fmt" "regexp" + "spend-sparrow/internal/core" ) const ( @@ -16,9 +17,9 @@ var ( func ValidateString(value string, fieldName string) error { switch { case value == "": - return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest) + return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, core.ErrBadRequest) case !safeInputRegex.MatchString(value): - return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest) + return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, core.ErrBadRequest) default: return nil } diff --git a/internal/service/error.go b/internal/service/error.go deleted file mode 100644 index 4f5da50..0000000 --- a/internal/service/error.go +++ /dev/null @@ -1,8 +0,0 @@ -package service - -import "errors" - -var ( - ErrBadRequest = errors.New("bad request") - ErrUnauthorized = errors.New("unauthorized") -) diff --git a/internal/service/transaction.go b/internal/service/transaction.go index f950447..fa0e3c4 100644 --- a/internal/service/transaction.go +++ b/internal/service/transaction.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "log/slog" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/core" "spend-sparrow/internal/db" "spend-sparrow/internal/types" "strconv" @@ -17,22 +19,22 @@ import ( const page_size = 25 type Transaction interface { - Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error) - Update(ctx context.Context, user *types.User, transaction types.Transaction) (*types.Transaction, error) - Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) - GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) - Delete(ctx context.Context, user *types.User, id string) error + Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error) + Update(ctx context.Context, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error) + Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error) + GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) + Delete(ctx context.Context, user *auth_types.User, id string) error - RecalculateBalances(ctx context.Context, user *types.User) error + RecalculateBalances(ctx context.Context, user *auth_types.User) error } type TransactionImpl struct { db *sqlx.DB - clock Clock - random Random + clock core.Clock + random core.Random } -func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction { +func NewTransaction(db *sqlx.DB, random core.Random, clock core.Clock) Transaction { return TransactionImpl{ db: db, clock: clock, @@ -40,9 +42,9 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction { } } -func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { +func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transactionInput types.Transaction) (*types.Transaction, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } var err error @@ -107,9 +109,9 @@ func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, return transaction, nil } -func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) { +func (s TransactionImpl) Update(ctx context.Context, user *auth_types.User, input types.Transaction) (*types.Transaction, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } tx, err := s.db.BeginTxx(ctx, nil) @@ -125,10 +127,10 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id) err = db.TransformAndLogDbError(ctx, "transaction Update", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { - return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest) + if errors.Is(err, core.ErrNotFound) { + return nil, fmt.Errorf("transaction %v not found: %w", input.Id, core.ErrBadRequest) } - return nil, types.ErrInternal + return nil, core.ErrInternal } if transaction.Error == nil && transaction.AccountId != nil { @@ -206,32 +208,32 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ return transaction, nil } -func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) { +func (s TransactionImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } uuid, err := uuid.Parse(id) if err != nil { slog.ErrorContext(ctx, "transaction get", "err", err) - return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } var transaction types.Transaction err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError(ctx, "transaction Get", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { - return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest) + if errors.Is(err, core.ErrNotFound) { + return nil, fmt.Errorf("transaction %v not found: %w", id, core.ErrBadRequest) } - return nil, types.ErrInternal + return nil, core.ErrInternal } return &transaction, nil } -func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) { +func (s TransactionImpl) GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } var ( @@ -277,14 +279,14 @@ func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter ty return transactions, nil } -func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error { +func (s TransactionImpl) Delete(ctx context.Context, user *auth_types.User, id string) error { if user == nil { - return ErrUnauthorized + return core.ErrUnauthorized } uuid, err := uuid.Parse(id) if err != nil { slog.ErrorContext(ctx, "transaction delete", "err", err) - return fmt.Errorf("could not parse Id: %w", ErrBadRequest) + return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -310,7 +312,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err) - if err != nil && !errors.Is(err, db.ErrNotFound) { + if err != nil && !errors.Is(err, core.ErrNotFound) { return err } } @@ -322,7 +324,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err) - if err != nil && !errors.Is(err, db.ErrNotFound) { + if err != nil && !errors.Is(err, core.ErrNotFound) { return err } } @@ -342,9 +344,9 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string return nil } -func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error { +func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *auth_types.User) error { if user == nil { - return ErrUnauthorized + return core.ErrUnauthorized } tx, err := s.db.BeginTxx(ctx, nil) @@ -361,7 +363,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us SET current_balance = 0 WHERE user_id = ?`, user.Id) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err) - if err != nil && !errors.Is(err, db.ErrNotFound) { + if err != nil && !errors.Is(err, core.ErrNotFound) { return err } @@ -370,7 +372,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us SET current_balance = 0 WHERE user_id = ?`, user.Id) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err) - if err != nil && !errors.Is(err, db.ErrNotFound) { + if err != nil && !errors.Is(err, core.ErrNotFound) { return err } @@ -379,7 +381,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us FROM "transaction" WHERE user_id = ?`, user.Id) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", nil, err) - if err != nil && !errors.Is(err, db.ErrNotFound) { + if err != nil && !errors.Is(err, core.ErrNotFound) { return err } defer func() { @@ -458,7 +460,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s if oldTransaction == nil { id, err = s.random.UUID(ctx) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } createdAt = s.clock.Now() createdBy = userId @@ -479,7 +481,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s } if rowCount == 0 { slog.ErrorContext(ctx, "transaction validate", "err", err) - return nil, fmt.Errorf("account not found: %w", ErrBadRequest) + return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest) } } @@ -488,13 +490,13 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId) err = db.TransformAndLogDbError(ctx, "transaction validate", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { - return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) + if errors.Is(err, core.ErrNotFound) { + return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest) } return nil, err } if treasureChest.ParentId == nil { - return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest) + return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest) } } diff --git a/internal/service/transaction_recurring.go b/internal/service/transaction_recurring.go index 6982ebc..9a6b31f 100644 --- a/internal/service/transaction_recurring.go +++ b/internal/service/transaction_recurring.go @@ -6,6 +6,8 @@ import ( "fmt" "log/slog" "math" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/core" "spend-sparrow/internal/db" "spend-sparrow/internal/types" "strconv" @@ -16,24 +18,24 @@ import ( ) type TransactionRecurring interface { - Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) - Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) - GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) - GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) - GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) - Delete(ctx context.Context, user *types.User, id string) error + Add(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) + Update(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) + GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error) + GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error) + GetAllByTreasureChest(ctx context.Context, user *auth_types.User, treasureChestId string) ([]*types.TransactionRecurring, error) + Delete(ctx context.Context, user *auth_types.User, id string) error GenerateTransactions(ctx context.Context) error } type TransactionRecurringImpl struct { db *sqlx.DB - clock Clock - random Random + clock core.Clock + random core.Random transaction Transaction } -func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transaction Transaction) TransactionRecurring { +func NewTransactionRecurring(db *sqlx.DB, random core.Random, clock core.Clock, transaction Transaction) TransactionRecurring { return TransactionRecurringImpl{ db: db, clock: clock, @@ -43,11 +45,11 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio } func (s TransactionRecurringImpl) Add(ctx context.Context, - user *types.User, + user *auth_types.User, transactionRecurringInput types.TransactionRecurringInput, ) (*types.TransactionRecurring, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } tx, err := s.db.BeginTxx(ctx, nil) @@ -85,16 +87,16 @@ func (s TransactionRecurringImpl) Add(ctx context.Context, } func (s TransactionRecurringImpl) Update(ctx context.Context, - user *types.User, + user *auth_types.User, input types.TransactionRecurringInput, ) (*types.TransactionRecurring, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } uuid, err := uuid.Parse(input.Id) if err != nil { slog.ErrorContext(ctx, "transactionRecurring update", "err", err) - return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -110,10 +112,10 @@ func (s TransactionRecurringImpl) Update(ctx context.Context, err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError(ctx, "transactionRecurring Update", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { - return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest) + if errors.Is(err, core.ErrNotFound) { + return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, core.ErrBadRequest) } - return nil, types.ErrInternal + return nil, core.ErrInternal } transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input) @@ -149,9 +151,9 @@ func (s TransactionRecurringImpl) Update(ctx context.Context, return transactionRecurring, nil } -func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) { +func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } transactionRecurrings := make([]*types.TransactionRecurring, 0) @@ -169,15 +171,15 @@ func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) return transactionRecurrings, nil } -func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) { +func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } accountUuid, err := uuid.Parse(accountId) if err != nil { slog.ErrorContext(ctx, "transactionRecurring GetAllByAccount", "err", err) - return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -193,10 +195,10 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id) err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByAccount", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { - return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest) + if errors.Is(err, core.ErrNotFound) { + return nil, fmt.Errorf("account %v not found: %w", accountId, core.ErrBadRequest) } - return nil, types.ErrInternal + return nil, core.ErrInternal } transactionRecurrings := make([]*types.TransactionRecurring, 0) @@ -222,17 +224,17 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ } func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context, - user *types.User, + user *auth_types.User, treasureChestId string, ) ([]*types.TransactionRecurring, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } treasureChestUuid, err := uuid.Parse(treasureChestId) if err != nil { slog.ErrorContext(ctx, "transactionRecurring GetAllByTreasureChest", "err", err) - return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -248,10 +250,10 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context, err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id) err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByTreasureChest", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { - return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest) + if errors.Is(err, core.ErrNotFound) { + return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, core.ErrBadRequest) } - return nil, types.ErrInternal + return nil, core.ErrInternal } transactionRecurrings := make([]*types.TransactionRecurring, 0) @@ -276,14 +278,14 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context, return transactionRecurrings, nil } -func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error { +func (s TransactionRecurringImpl) Delete(ctx context.Context, user *auth_types.User, id string) error { if user == nil { - return ErrUnauthorized + return core.ErrUnauthorized } uuid, err := uuid.Parse(id) if err != nil { slog.ErrorContext(ctx, "transactionRecurring delete", "err", err) - return fmt.Errorf("could not parse Id: %w", ErrBadRequest) + return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -339,7 +341,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) erro } for _, transactionRecurring := range recurringTransactions { - user := &types.User{ + user := &auth_types.User{ Id: transactionRecurring.UserId, } transaction := types.Transaction{ @@ -397,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( if oldTransactionRecurring == nil { id, err = s.random.UUID(ctx) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } createdAt = s.clock.Now() createdBy = userId @@ -416,7 +418,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( temp, err := uuid.Parse(input.AccountId) if err != nil { slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) - return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest) } accountUuid = &temp err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId) @@ -426,7 +428,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( } if rowCount == 0 { slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) - return nil, fmt.Errorf("account not found: %w", ErrBadRequest) + return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest) } hasAccount = true @@ -436,37 +438,37 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( temp, err := uuid.Parse(input.TreasureChestId) if err != nil { slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) - return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest) } treasureChestUuid = &temp var treasureChest types.TreasureChest err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = db.TransformAndLogDbError(ctx, "transactionRecurring validate", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { - return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) + if errors.Is(err, core.ErrNotFound) { + return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest) } return nil, err } if treasureChest.ParentId == nil { - return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest) + return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest) } hasTreasureChest = true } if !hasAccount && !hasTreasureChest { slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) - return nil, fmt.Errorf("either account or treasure chest is required: %w", ErrBadRequest) + return nil, fmt.Errorf("either account or treasure chest is required: %w", core.ErrBadRequest) } if hasAccount && hasTreasureChest { slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) - return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", ErrBadRequest) + return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", core.ErrBadRequest) } valueFloat, err := strconv.ParseFloat(input.Value, 64) if err != nil { slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) - return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse value: %w", core.ErrBadRequest) } value := int64(math.Round(valueFloat * DECIMALS_MULTIPLIER)) @@ -485,18 +487,18 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0) if err != nil { slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) - return nil, fmt.Errorf("could not parse intervalMonths: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse intervalMonths: %w", core.ErrBadRequest) } if intervalMonths < 1 { slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) - return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", ErrBadRequest) + return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", core.ErrBadRequest) } var nextExecution *time.Time = nil if input.NextExecution != "" { t, err := time.Parse("2006-01-02", input.NextExecution) if err != nil { slog.ErrorContext(ctx, "transaction validate", "err", err) - return nil, fmt.Errorf("could not parse timestamp: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest) } t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location()) diff --git a/internal/service/treasure_chest.go b/internal/service/treasure_chest.go index cdc60b5..65db751 100644 --- a/internal/service/treasure_chest.go +++ b/internal/service/treasure_chest.go @@ -6,6 +6,8 @@ import ( "fmt" "log/slog" "slices" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/core" "spend-sparrow/internal/db" "spend-sparrow/internal/types" @@ -14,20 +16,20 @@ import ( ) type TreasureChest interface { - Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) - Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error) - Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) - GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) - Delete(ctx context.Context, user *types.User, id string) error + Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error) + Update(ctx context.Context, user *auth_types.User, id, parentId, name string) (*types.TreasureChest, error) + Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error) + GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error) + Delete(ctx context.Context, user *auth_types.User, id string) error } type TreasureChestImpl struct { db *sqlx.DB - clock Clock - random Random + clock core.Clock + random core.Random } -func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest { +func NewTreasureChest(db *sqlx.DB, random core.Random, clock core.Clock) TreasureChest { return TreasureChestImpl{ db: db, clock: clock, @@ -35,14 +37,14 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest { } } -func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) { +func (s TreasureChestImpl) Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } newId, err := s.random.UUID(ctx) if err != nil { - return nil, types.ErrInternal + return nil, core.ErrInternal } err = ValidateString(name, "name") @@ -57,7 +59,7 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, return nil, err } if parent.ParentId != nil { - return nil, fmt.Errorf("only a depth of 1 allowed: %w", ErrBadRequest) + return nil, fmt.Errorf("only a depth of 1 allowed: %w", core.ErrBadRequest) } parentUuid = &parent.Id } @@ -88,9 +90,9 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, return treasureChest, nil } -func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) { +func (s TreasureChestImpl) Update(ctx context.Context, user *auth_types.User, idStr, parentId, name string) (*types.TreasureChest, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } err := ValidateString(name, "name") if err != nil { @@ -99,7 +101,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, id, err := uuid.Parse(idStr) if err != nil { slog.ErrorContext(ctx, "treasureChest update", "err", err) - return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -115,10 +117,10 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id) err = db.TransformAndLogDbError(ctx, "treasureChest Update", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { + if errors.Is(err, core.ErrNotFound) { return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err) } - return nil, types.ErrInternal + return nil, core.ErrInternal } var parentUuid *uuid.UUID @@ -134,7 +136,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, return nil, err } if parent.ParentId != nil || childCount > 0 { - return nil, fmt.Errorf("only one level allowed: %w", ErrBadRequest) + return nil, fmt.Errorf("only one level allowed: %w", core.ErrBadRequest) } parentUuid = &parent.Id @@ -170,32 +172,32 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, return treasureChest, nil } -func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) { +func (s TreasureChestImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } uuid, err := uuid.Parse(id) if err != nil { slog.ErrorContext(ctx, "treasureChest get", "err", err) - return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) + return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } var treasureChest types.TreasureChest err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError(ctx, "treasureChest Get", nil, err) if err != nil { - if errors.Is(err, db.ErrNotFound) { + if errors.Is(err, core.ErrNotFound) { return nil, fmt.Errorf("treasureChest %v not found: %w", id, err) } - return nil, types.ErrInternal + return nil, core.ErrInternal } return &treasureChest, nil } -func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) { +func (s TreasureChestImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error) { if user == nil { - return nil, ErrUnauthorized + return nil, core.ErrUnauthorized } treasureChests := make([]*types.TreasureChest, 0) @@ -208,14 +210,14 @@ func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*typ return sortTreasureChests(treasureChests), nil } -func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error { +func (s TreasureChestImpl) Delete(ctx context.Context, user *auth_types.User, idStr string) error { if user == nil { - return ErrUnauthorized + return core.ErrUnauthorized } id, err := uuid.Parse(idStr) if err != nil { slog.ErrorContext(ctx, "treasureChest delete", "err", err) - return fmt.Errorf("could not parse Id: %w", ErrBadRequest) + return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest) } tx, err := s.db.BeginTxx(ctx, nil) @@ -235,7 +237,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s } if childCount > 0 { - return fmt.Errorf("treasure chest has children: %w", ErrBadRequest) + return fmt.Errorf("treasure chest has children: %w", core.ErrBadRequest) } transactionsCount := 0 @@ -247,7 +249,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s return err } if transactionsCount > 0 { - return fmt.Errorf("treasure chest has transactions: %w", ErrBadRequest) + return fmt.Errorf("treasure chest has transactions: %w", core.ErrBadRequest) } recurringCount := 0 @@ -259,7 +261,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s return err } if recurringCount > 0 { - return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest) + return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", core.ErrBadRequest) } r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id) diff --git a/internal/template/auth/default.go b/internal/template/auth/default.go deleted file mode 100644 index 8832b06..0000000 --- a/internal/template/auth/default.go +++ /dev/null @@ -1 +0,0 @@ -package auth diff --git a/internal/types/types.go b/internal/types/types.go deleted file mode 100644 index 88d3af4..0000000 --- a/internal/types/types.go +++ /dev/null @@ -1,10 +0,0 @@ -package types - -import ( - "errors" -) - -var ( - ErrInternal = errors.New("internal server error") - ErrUnauthorized = errors.New("you are not authorized to perform this action") -) diff --git a/mocks/default.go b/mocks/default.go deleted file mode 100644 index f726b26..0000000 --- a/mocks/default.go +++ /dev/null @@ -1 +0,0 @@ -package mocks diff --git a/test/auth_it_test.go b/test/auth_it_test.go index 44be9d1..59d71d4 100644 --- a/test/auth_it_test.go +++ b/test/auth_it_test.go @@ -2,8 +2,10 @@ package test_test import ( "context" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/authentication" + "spend-sparrow/internal/core" "spend-sparrow/internal/db" - "spend-sparrow/internal/types" "testing" "time" @@ -42,11 +44,11 @@ func TestUser(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + expected := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(context.Background(), expected) require.NoError(t, err) @@ -63,38 +65,38 @@ func TestUser(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) _, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail") - assert.Equal(t, db.ErrNotFound, err) + assert.Equal(t, core.ErrNotFound, err) }) t.Run("should return ErrUserExist", func(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + user := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(context.Background(), user) require.NoError(t, err) err = underTest.InsertUser(context.Background(), user) - assert.Equal(t, db.ErrAlreadyExists, err) + assert.Equal(t, core.ErrAlreadyExists, err) }) t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) + user := auth_types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) err := underTest.InsertUser(context.Background(), user) - assert.Equal(t, types.ErrInternal, err) + assert.Equal(t, core.ErrInternal, err) }) } @@ -105,11 +107,11 @@ func TestToken(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) expiresAt := createAt.Add(24 * time.Hour) - expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt) + expected := auth_types.NewToken(uuid.New(), "sessionId", "token", auth_types.TokenTypeCsrf, createAt, expiresAt) err := underTest.InsertToken(context.Background(), expected) require.NoError(t, err) @@ -121,25 +123,25 @@ func TestToken(t *testing.T) { expected.SessionId = "" actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type) require.NoError(t, err) - assert.Equal(t, []*types.Token{expected}, actuals) + assert.Equal(t, []*auth_types.Token{expected}, actuals) expected.SessionId = "sessionId" expected.UserId = uuid.Nil actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type) require.NoError(t, err) - assert.Equal(t, []*types.Token{expected}, actuals) + assert.Equal(t, []*auth_types.Token{expected}, actuals) }) t.Run("should insert and return multiple tokens", func(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) expiresAt := createAt.Add(24 * time.Hour) userId := uuid.New() - expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt) - expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt) + expected1 := auth_types.NewToken(userId, "sessionId", "token1", auth_types.TokenTypeCsrf, createAt, expiresAt) + expected2 := auth_types.NewToken(userId, "sessionId", "token2", auth_types.TokenTypeCsrf, createAt, expiresAt) err := underTest.InsertToken(context.Background(), expected1) require.NoError(t, err) @@ -150,7 +152,7 @@ func TestToken(t *testing.T) { expected2.UserId = uuid.Nil actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type) require.NoError(t, err) - assert.Equal(t, []*types.Token{expected1, expected2}, actuals) + assert.Equal(t, []*auth_types.Token{expected1, expected2}, actuals) expected1.SessionId = "" expected2.SessionId = "" @@ -158,49 +160,49 @@ func TestToken(t *testing.T) { expected2.UserId = userId actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type) require.NoError(t, err) - assert.Equal(t, []*types.Token{expected1, expected2}, actuals) + assert.Equal(t, []*auth_types.Token{expected1, expected2}, actuals) }) t.Run("should return ErrNotFound", func(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) _, err := underTest.GetToken(context.Background(), "nonExistent") - assert.Equal(t, db.ErrNotFound, err) + assert.Equal(t, core.ErrNotFound, err) - _, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), types.TokenTypeEmailVerify) - assert.Equal(t, db.ErrNotFound, err) + _, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), auth_types.TokenTypeEmailVerify) + assert.Equal(t, core.ErrNotFound, err) - _, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", types.TokenTypeEmailVerify) - assert.Equal(t, db.ErrNotFound, err) + _, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", auth_types.TokenTypeEmailVerify) + assert.Equal(t, core.ErrNotFound, err) }) t.Run("should return ErrAlreadyExists", func(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + user := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(context.Background(), user) require.NoError(t, err) err = underTest.InsertUser(context.Background(), user) - assert.Equal(t, db.ErrAlreadyExists, err) + assert.Equal(t, core.ErrAlreadyExists, err) }) t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Parallel() d := setupDb(t) - underTest := db.NewAuthSqlite(d) + underTest := authentication.NewDbSqlite(d) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) + user := auth_types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) err := underTest.InsertUser(context.Background(), user) - assert.Equal(t, types.ErrInternal, err) + assert.Equal(t, core.ErrInternal, err) }) } diff --git a/test/auth_test.go b/test/auth_test.go index f8a1ea5..1b3a5c7 100644 --- a/test/auth_test.go +++ b/test/auth_test.go @@ -2,8 +2,9 @@ package test_test import ( "context" - "spend-sparrow/internal/db" - "spend-sparrow/internal/service" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/authentication" + "spend-sparrow/internal/core" "spend-sparrow/internal/types" "spend-sparrow/mocks" "strings" @@ -30,26 +31,26 @@ func TestSignUp(t *testing.T) { t.Run("should check for correct email address", func(t *testing.T) { t.Parallel() - mockAuthDb := mocks.NewMockAuth(t) + mockAuthDb := mocks.NewMockDb(t) mockRandom := mocks.NewMockRandom(t) mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) - underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) + underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings) _, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!") - assert.Equal(t, service.ErrInvalidEmail, err) + assert.Equal(t, authentication.ErrInvalidEmail, err) }) t.Run("should check for password complexity", func(t *testing.T) { t.Parallel() - mockAuthDb := mocks.NewMockAuth(t) + mockAuthDb := mocks.NewMockDb(t) mockRandom := mocks.NewMockRandom(t) mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) - underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) + underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings) weakPasswords := []string{ "123!ab", // too short @@ -60,13 +61,13 @@ func TestSignUp(t *testing.T) { for _, password := range weakPasswords { _, err := underTest.SignUp(context.Background(), "some@valid.email", password) - assert.Equal(t, service.ErrInvalidPassword, err) + assert.Equal(t, authentication.ErrInvalidPassword, err) } }) t.Run("should signup correctly", func(t *testing.T) { t.Parallel() - mockAuthDb := mocks.NewMockAuth(t) + mockAuthDb := mocks.NewMockDb(t) mockRandom := mocks.NewMockRandom(t) mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) @@ -77,7 +78,7 @@ func TestSignUp(t *testing.T) { salt := []byte("salt") createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) - expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime) + expected := auth_types.NewUser(userId, email, false, nil, false, authentication.GetHashPassword(password, salt), salt, createTime) ctx := context.Background() @@ -86,7 +87,7 @@ func TestSignUp(t *testing.T) { mockClock.EXPECT().Now().Return(createTime) mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil) - underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) + underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings) actual, err := underTest.SignUp(context.Background(), email, password) require.NoError(t, err) @@ -96,7 +97,7 @@ func TestSignUp(t *testing.T) { t.Run("should return ErrAccountExists", func(t *testing.T) { t.Parallel() - mockAuthDb := mocks.NewMockAuth(t) + mockAuthDb := mocks.NewMockDb(t) mockRandom := mocks.NewMockRandom(t) mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) @@ -106,19 +107,19 @@ func TestSignUp(t *testing.T) { createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) password := "SomeStrongPassword123!" salt := []byte("salt") - user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime) + user := auth_types.NewUser(userId, email, false, nil, false, authentication.GetHashPassword(password, salt), salt, createTime) ctx := context.Background() mockRandom.EXPECT().UUID(ctx).Return(user.Id, nil) mockRandom.EXPECT().Bytes(ctx, 16).Return(salt, nil) mockClock.EXPECT().Now().Return(createTime) - mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(db.ErrAlreadyExists) + mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(core.ErrAlreadyExists) - underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) + underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings) _, err := underTest.SignUp(context.Background(), user.Email, password) - assert.Equal(t, service.ErrAccountExists, err) + assert.Equal(t, authentication.ErrAccountExists, err) }) } @@ -127,30 +128,30 @@ func TestSendVerificationMail(t *testing.T) { t.Run("should use stored token and send mail", func(t *testing.T) { t.Parallel() - token := types.NewToken( + token := auth_types.NewToken( uuid.New(), "sessionId", "someRandomTokenToUse", - types.TokenTypeEmailVerify, + auth_types.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) - tokens := []*types.Token{token} + tokens := []*auth_types.Token{token} email := "some@email.de" userId := uuid.New() - mockAuthDb := mocks.NewMockAuth(t) + mockAuthDb := mocks.NewMockDb(t) mockRandom := mocks.NewMockRandom(t) mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) ctx := context.Background() - mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, types.TokenTypeEmailVerify).Return(tokens, nil) + mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, auth_types.TokenTypeEmailVerify).Return(tokens, nil) mockMail.EXPECT().SendMail(ctx, email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool { return strings.Contains(message, token.Token) })).Return() - underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) + underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest.SendVerificationMail(context.Background(), userId, email) }) diff --git a/test/it_test.go b/test/it_test.go index 186ac94..9a4cf3f 100644 --- a/test/it_test.go +++ b/test/it_test.go @@ -7,8 +7,9 @@ import ( "net/http" "net/url" "spend-sparrow/internal" - "spend-sparrow/internal/service" - "spend-sparrow/internal/types" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/authentication" + "spend-sparrow/internal/core" "strconv" "strings" "sync/atomic" @@ -117,7 +118,7 @@ func waitForReady( default: if time.Since(startTime) >= timeout { t.Fatal("timeout reached while waiting for endpoint") - return types.ErrInternal + return core.ErrInternal } // wait a little while between checks time.Sleep(250 * time.Millisecond) @@ -178,7 +179,7 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s t.Helper() userId := uuid.New() sessionId := "session-id" + add - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) csrfToken := "my-verifying-token" + add email := add + "mail@mail.de" @@ -193,7 +194,7 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s _, err = db.ExecContext(context.Background(), ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) - VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf) + VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, auth_types.TokenTypeCsrf) require.NoError(t, err) return userId, csrfToken, sessionId diff --git a/test/main_it_test.go b/test/main_it_test.go index c4a6b50..08ed048 100644 --- a/test/main_it_test.go +++ b/test/main_it_test.go @@ -3,8 +3,8 @@ package test_test import ( "net/http" "net/url" - "spend-sparrow/internal/service" - "spend-sparrow/internal/types" + "spend-sparrow/internal/auth_types" + "spend-sparrow/internal/authentication" "strings" "testing" "time" @@ -110,7 +110,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() sessionId := "session-id" - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) @@ -136,7 +136,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) @@ -163,7 +163,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) @@ -206,7 +206,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) @@ -247,7 +247,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) @@ -295,7 +295,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) @@ -414,7 +414,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) @@ -467,7 +467,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() sessionId := "session-id" - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) @@ -550,7 +550,7 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) - VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt")) + VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), authentication.GetHashPassword("password", []byte("salt")), []byte("salt")) require.NoError(t, err) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil) @@ -631,7 +631,7 @@ func TestIntegrationAuth(t *testing.T) { require.NoError(t, err) assert.Equal(t, 1, rows) var token string - err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token) + err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", auth_types.TokenTypeEmailVerify).Scan(&token) require.NoError(t, err) assert.NotEmpty(t, token) }) @@ -676,7 +676,7 @@ func TestIntegrationAuth(t *testing.T) { require.NoError(t, err) _, err = db.ExecContext(ctx, ` INSERT INTO token (token, user_id, type, created_at, expires_at) - VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify) + VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, auth_types.TokenTypeEmailVerify) require.NoError(t, err) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil) @@ -706,7 +706,7 @@ func TestIntegrationAuth(t *testing.T) { require.NoError(t, err) _, err = db.ExecContext(ctx, ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) - VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify) + VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, auth_types.TokenTypeEmailVerify) require.NoError(t, err) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil) @@ -746,7 +746,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() sessionId := "session-id" - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -765,7 +765,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var csrfToken string - err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken) + err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypeCsrf).Scan(&csrfToken) require.NoError(t, err) req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil) @@ -824,7 +824,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -870,7 +870,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1039,7 +1039,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1078,7 +1078,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) @@ -1128,7 +1128,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) @@ -1180,7 +1180,7 @@ func TestIntegrationAuth(t *testing.T) { userId := uuid.New() userIdOther := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1230,7 +1230,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) - pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt")) + pass = authentication.GetHashPassword("MyNewSecurePassword1!", []byte("salt")) var rows int err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) require.NoError(t, err) @@ -1259,7 +1259,7 @@ func TestIntegrationAuth(t *testing.T) { d, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1287,7 +1287,7 @@ func TestIntegrationAuth(t *testing.T) { d, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1317,7 +1317,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int - err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) + err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypePasswordReset).Scan(&rows) require.NoError(t, err) assert.Equal(t, 0, rows) }) @@ -1362,7 +1362,7 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := db.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1399,7 +1399,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg) var rows int - err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypePasswordReset).Scan(&rows) require.NoError(t, err) assert.Equal(t, 1, rows) }) @@ -1412,7 +1412,7 @@ func TestIntegrationAuth(t *testing.T) { d, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1455,7 +1455,7 @@ func TestIntegrationAuth(t *testing.T) { d, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1475,7 +1475,7 @@ func TestIntegrationAuth(t *testing.T) { token := "password-reset-token" _, err = d.ExecContext(ctx, ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) - VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset) + VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", auth_types.TokenTypePasswordReset) require.NoError(t, err) formData := url.Values{ @@ -1504,7 +1504,7 @@ func TestIntegrationAuth(t *testing.T) { d, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1524,7 +1524,7 @@ func TestIntegrationAuth(t *testing.T) { token := "password-reset-token" _, err = d.ExecContext(ctx, ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) - VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset) + VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", auth_types.TokenTypePasswordReset) require.NoError(t, err) formData := url.Values{ @@ -1553,7 +1553,7 @@ func TestIntegrationAuth(t *testing.T) { d, basePath, ctx := setupIntegrationTest(t) userId := uuid.New() - pass := service.GetHashPassword("password", []byte("salt")) + pass := authentication.GetHashPassword("password", []byte("salt")) _, err := d.ExecContext(ctx, ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) @@ -1590,7 +1590,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var token string - err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token) + err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", auth_types.TokenTypePasswordReset).Scan(&token) require.NoError(t, err) formData = url.Values{