feat: extract authentication to domain package #393
@@ -3,11 +3,11 @@ dir: mocks/
|
|||||||
outpkg: mocks
|
outpkg: mocks
|
||||||
issue-845-fix: True
|
issue-845-fix: True
|
||||||
packages:
|
packages:
|
||||||
spend-sparrow/internal/service:
|
spend-sparrow/internal/core:
|
||||||
interfaces:
|
interfaces:
|
||||||
Random:
|
Random:
|
||||||
Clock:
|
Clock:
|
||||||
Mail:
|
Mail:
|
||||||
spend-sparrow/internal/db:
|
spend-sparrow/internal/authentication:
|
||||||
interfaces:
|
interfaces:
|
||||||
Auth:
|
Db:
|
||||||
|
|||||||
@@ -4,29 +4,31 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/service"
|
"spend-sparrow/internal/service"
|
||||||
"spend-sparrow/internal/types"
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Service interface {
|
type Service interface {
|
||||||
Add(ctx context.Context, user *types.User, name string) (*Account, error)
|
Add(ctx context.Context, user *auth_types.User, name string) (*Account, error)
|
||||||
UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error)
|
UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error)
|
||||||
Get(ctx context.Context, user *types.User, id string) (*Account, error)
|
Get(ctx context.Context, user *auth_types.User, id string) (*Account, error)
|
||||||
GetAll(ctx context.Context, user *types.User) ([]*Account, error)
|
GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error)
|
||||||
Delete(ctx context.Context, user *types.User, id string) error
|
Delete(ctx context.Context, user *auth_types.User, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServiceImpl struct {
|
type ServiceImpl struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
clock service.Clock
|
clock core.Clock
|
||||||
random service.Random
|
random core.Random
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Service {
|
func NewServiceImpl(db *sqlx.DB, random core.Random, clock core.Clock) Service {
|
||||||
return ServiceImpl{
|
return ServiceImpl{
|
||||||
db: db,
|
db: db,
|
||||||
clock: clock,
|
clock: clock,
|
||||||
@@ -34,14 +36,14 @@ func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Ser
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*Account, error) {
|
func (s ServiceImpl) Add(ctx context.Context, user *auth_types.User, name string) (*Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, types.ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
newId, err := s.random.UUID(ctx)
|
newId, err := s.random.UUID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.ValidateString(name, "name")
|
err = service.ValidateString(name, "name")
|
||||||
@@ -76,9 +78,9 @@ func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*A
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error) {
|
func (s ServiceImpl) UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, types.ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
err := service.ValidateString(name, "name")
|
err := service.ValidateString(name, "name")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -87,7 +89,7 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
|
|||||||
uuid, err := uuid.Parse(id)
|
uuid, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "account update", "err", err)
|
slog.ErrorContext(ctx, "account update", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -103,10 +105,10 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
|
|||||||
err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError(ctx, "account Update", nil, err)
|
err = db.TransformAndLogDbError(ctx, "account Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("account %v not found: %w", id, service.ErrBadRequest)
|
return nil, fmt.Errorf("account %v not found: %w", id, core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
timestamp := s.clock.Now()
|
timestamp := s.clock.Now()
|
||||||
@@ -136,14 +138,14 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
|
|||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Account, error) {
|
func (s ServiceImpl) Get(ctx context.Context, user *auth_types.User, id string) (*Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, service.ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
uuid, err := uuid.Parse(id)
|
uuid, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "account get", "err", err)
|
slog.ErrorContext(ctx, "account get", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
var account Account
|
var account Account
|
||||||
@@ -158,9 +160,9 @@ func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Acc
|
|||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account, error) {
|
func (s ServiceImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, service.ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts := make([]*Account, 0)
|
accounts := make([]*Account, 0)
|
||||||
@@ -174,14 +176,14 @@ func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account,
|
|||||||
return accounts, nil
|
return accounts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
func (s ServiceImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return service.ErrUnauthorized
|
return core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
uuid, err := uuid.Parse(id)
|
uuid, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "account delete", "err", err)
|
slog.ErrorContext(ctx, "account delete", "err", err)
|
||||||
return fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)
|
return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -200,7 +202,7 @@ func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) er
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if transactionsCount > 0 {
|
if transactionsCount > 0 {
|
||||||
return fmt.Errorf("account has transactions, cannot delete: %w", service.ErrBadRequest)
|
return fmt.Errorf("account has transactions, cannot delete: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
|
res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package types
|
package auth_types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
package db
|
package authentication
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -13,36 +14,36 @@ import (
|
|||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Auth interface {
|
type Db interface {
|
||||||
InsertUser(ctx context.Context, user *types.User) error
|
InsertUser(ctx context.Context, user *auth_types.User) error
|
||||||
UpdateUser(ctx context.Context, user *types.User) error
|
UpdateUser(ctx context.Context, user *auth_types.User) error
|
||||||
GetUserByEmail(ctx context.Context, email string) (*types.User, error)
|
GetUserByEmail(ctx context.Context, email string) (*auth_types.User, error)
|
||||||
GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error)
|
GetUser(ctx context.Context, userId uuid.UUID) (*auth_types.User, error)
|
||||||
DeleteUser(ctx context.Context, userId uuid.UUID) error
|
DeleteUser(ctx context.Context, userId uuid.UUID) error
|
||||||
|
|
||||||
InsertToken(ctx context.Context, token *types.Token) error
|
InsertToken(ctx context.Context, token *auth_types.Token) error
|
||||||
GetToken(ctx context.Context, token string) (*types.Token, error)
|
GetToken(ctx context.Context, token string) (*auth_types.Token, error)
|
||||||
GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
|
GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType auth_types.TokenType) ([]*auth_types.Token, error)
|
||||||
GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error)
|
GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error)
|
||||||
DeleteToken(ctx context.Context, token string) error
|
DeleteToken(ctx context.Context, token string) error
|
||||||
|
|
||||||
InsertSession(ctx context.Context, session *types.Session) error
|
InsertSession(ctx context.Context, session *auth_types.Session) error
|
||||||
GetSession(ctx context.Context, sessionId string) (*types.Session, error)
|
GetSession(ctx context.Context, sessionId string) (*auth_types.Session, error)
|
||||||
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
|
GetSessions(ctx context.Context, userId uuid.UUID) ([]*auth_types.Session, error)
|
||||||
DeleteSession(ctx context.Context, sessionId string) error
|
DeleteSession(ctx context.Context, sessionId string) error
|
||||||
DeleteOldSessions(ctx context.Context) error
|
DeleteOldSessions(ctx context.Context) error
|
||||||
DeleteOldTokens(ctx context.Context) error
|
DeleteOldTokens(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthSqlite struct {
|
type DbSqlite struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthSqlite(db *sqlx.DB) *AuthSqlite {
|
func NewDbSqlite(db *sqlx.DB) *DbSqlite {
|
||||||
return &AuthSqlite{db: db}
|
return &DbSqlite{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error {
|
func (db DbSqlite) InsertUser(ctx context.Context, user *auth_types.User) error {
|
||||||
_, err := db.db.ExecContext(ctx, `
|
_, err := db.db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
@@ -50,17 +51,17 @@ func (db AuthSqlite) InsertUser(ctx context.Context, user *types.User) error {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "email") {
|
if strings.Contains(err.Error(), "email") {
|
||||||
return ErrAlreadyExists
|
return core.ErrAlreadyExists
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.ErrorContext(ctx, "SQL error InsertUser", "err", err)
|
slog.ErrorContext(ctx, "SQL error InsertUser", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error {
|
func (db DbSqlite) UpdateUser(ctx context.Context, user *auth_types.User) error {
|
||||||
_, err := db.db.ExecContext(ctx, `
|
_, err := db.db.ExecContext(ctx, `
|
||||||
UPDATE user
|
UPDATE user
|
||||||
SET email_verified = ?, email_verified_at = ?, password = ?
|
SET email_verified = ?, email_verified_at = ?, password = ?
|
||||||
@@ -69,13 +70,13 @@ func (db AuthSqlite) UpdateUser(ctx context.Context, user *types.User) error {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "SQL error UpdateUser", "err", err)
|
slog.ErrorContext(ctx, "SQL error UpdateUser", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
func (db DbSqlite) GetUserByEmail(ctx context.Context, email string) (*auth_types.User, error) {
|
||||||
var (
|
var (
|
||||||
userId uuid.UUID
|
userId uuid.UUID
|
||||||
emailVerified bool
|
emailVerified bool
|
||||||
@@ -92,17 +93,17 @@ func (db AuthSqlite) GetUserByEmail(ctx context.Context, email string) (*types.U
|
|||||||
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, core.ErrNotFound
|
||||||
} else {
|
} else {
|
||||||
slog.ErrorContext(ctx, "SQL error GetUser", "err", err)
|
slog.ErrorContext(ctx, "SQL error GetUser", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
return auth_types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User, error) {
|
func (db DbSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*auth_types.User, error) {
|
||||||
var (
|
var (
|
||||||
email string
|
email string
|
||||||
emailVerified bool
|
emailVerified bool
|
||||||
@@ -119,92 +120,92 @@ func (db AuthSqlite) GetUser(ctx context.Context, userId uuid.UUID) (*types.User
|
|||||||
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, core.ErrNotFound
|
||||||
} else {
|
} else {
|
||||||
slog.ErrorContext(ctx, "SQL error GetUser", "err", err)
|
slog.ErrorContext(ctx, "SQL error GetUser", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
return auth_types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error {
|
func (db DbSqlite) DeleteUser(ctx context.Context, userId uuid.UUID) error {
|
||||||
tx, err := db.db.BeginTx(ctx, nil)
|
tx, err := db.db.BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not start transaction", "err", err)
|
slog.ErrorContext(ctx, "Could not start transaction", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM account WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.ErrorContext(ctx, "Could not delete accounts", "err", err)
|
slog.ErrorContext(ctx, "Could not delete accounts", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM token WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.ErrorContext(ctx, "Could not delete user tokens", "err", err)
|
slog.ErrorContext(ctx, "Could not delete user tokens", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM session WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.ErrorContext(ctx, "Could not delete sessions", "err", err)
|
slog.ErrorContext(ctx, "Could not delete sessions", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM user WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.ErrorContext(ctx, "Could not delete user", "err", err)
|
slog.ErrorContext(ctx, "Could not delete user", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM treasure_chest WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.ErrorContext(ctx, "Could not delete user", "err", err)
|
slog.ErrorContext(ctx, "Could not delete user", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId)
|
_, err = tx.ExecContext(ctx, "DELETE FROM \"transaction\" WHERE user_id = ?", userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tx.Rollback()
|
_ = tx.Rollback()
|
||||||
slog.ErrorContext(ctx, "Could not delete user", "err", err)
|
slog.ErrorContext(ctx, "Could not delete user", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit()
|
err = tx.Commit()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not commit transaction", "err", err)
|
slog.ErrorContext(ctx, "Could not commit transaction", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) InsertToken(ctx context.Context, token *types.Token) error {
|
func (db DbSqlite) InsertToken(ctx context.Context, token *auth_types.Token) error {
|
||||||
_, err := db.db.ExecContext(ctx, `
|
_, err := db.db.ExecContext(ctx, `
|
||||||
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
|
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
|
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not insert token", "err", err)
|
slog.ErrorContext(ctx, "Could not insert token", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token, error) {
|
func (db DbSqlite) GetToken(ctx context.Context, token string) (*auth_types.Token, error) {
|
||||||
var (
|
var (
|
||||||
userId uuid.UUID
|
userId uuid.UUID
|
||||||
sessionId string
|
sessionId string
|
||||||
tokenType types.TokenType
|
tokenType auth_types.TokenType
|
||||||
createdAtStr string
|
createdAtStr string
|
||||||
expiresAtStr string
|
expiresAtStr string
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
@@ -219,29 +220,29 @@ func (db AuthSqlite) GetToken(ctx context.Context, token string) (*types.Token,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
slog.InfoContext(ctx, "Token not found", "token", token)
|
slog.InfoContext(ctx, "Token not found", "token", token)
|
||||||
return nil, ErrNotFound
|
return nil, core.ErrNotFound
|
||||||
} else {
|
} else {
|
||||||
slog.ErrorContext(ctx, "Could not get token", "err", err)
|
slog.ErrorContext(ctx, "Could not get token", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
createdAt, err = time.Parse(time.RFC3339, createdAtStr)
|
createdAt, err = time.Parse(time.RFC3339, createdAtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err)
|
slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
expiresAt, err = time.Parse(time.RFC3339, expiresAtStr)
|
expiresAt, err = time.Parse(time.RFC3339, expiresAtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err)
|
slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
|
return auth_types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
|
func (db DbSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.UUID, tokenType auth_types.TokenType) ([]*auth_types.Token, error) {
|
||||||
query, err := db.db.QueryContext(ctx, `
|
query, err := db.db.QueryContext(ctx, `
|
||||||
SELECT token, created_at, expires_at
|
SELECT token, created_at, expires_at
|
||||||
FROM token
|
FROM token
|
||||||
@@ -250,13 +251,13 @@ func (db AuthSqlite) GetTokensByUserIdAndType(ctx context.Context, userId uuid.U
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not get token", "err", err)
|
slog.ErrorContext(ctx, "Could not get token", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return getTokensFromQuery(ctx, query, userId, "", tokenType)
|
return getTokensFromQuery(ctx, query, userId, "", tokenType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
|
func (db DbSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error) {
|
||||||
query, err := db.db.QueryContext(ctx, `
|
query, err := db.db.QueryContext(ctx, `
|
||||||
SELECT token, created_at, expires_at
|
SELECT token, created_at, expires_at
|
||||||
FROM token
|
FROM token
|
||||||
@@ -265,14 +266,14 @@ func (db AuthSqlite) GetTokensBySessionIdAndType(ctx context.Context, sessionId
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not get token", "err", err)
|
slog.ErrorContext(ctx, "Could not get token", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return getTokensFromQuery(ctx, query, uuid.Nil, sessionId, tokenType)
|
return getTokensFromQuery(ctx, query, uuid.Nil, sessionId, tokenType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID, sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
|
func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID, sessionId string, tokenType auth_types.TokenType) ([]*auth_types.Token, error) {
|
||||||
var tokens []*types.Token
|
var tokens []*auth_types.Token
|
||||||
|
|
||||||
hasRows := false
|
hasRows := false
|
||||||
for query.Next() {
|
for query.Next() {
|
||||||
@@ -289,54 +290,54 @@ func getTokensFromQuery(ctx context.Context, query *sql.Rows, userId uuid.UUID,
|
|||||||
err := query.Scan(&token, &createdAtStr, &expiresAtStr)
|
err := query.Scan(&token, &createdAtStr, &expiresAtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not scan token", "err", err)
|
slog.ErrorContext(ctx, "Could not scan token", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
createdAt, err = time.Parse(time.RFC3339, createdAtStr)
|
createdAt, err = time.Parse(time.RFC3339, createdAtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err)
|
slog.ErrorContext(ctx, "Could not parse token.created_at", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
expiresAt, err = time.Parse(time.RFC3339, expiresAtStr)
|
expiresAt, err = time.Parse(time.RFC3339, expiresAtStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err)
|
slog.ErrorContext(ctx, "Could not parse token.expires_at", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens = append(tokens, types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt))
|
tokens = append(tokens, auth_types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasRows {
|
if !hasRows {
|
||||||
return nil, ErrNotFound
|
return nil, core.ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return tokens, nil
|
return tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteToken(ctx context.Context, token string) error {
|
func (db DbSqlite) DeleteToken(ctx context.Context, token string) error {
|
||||||
_, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token)
|
_, err := db.db.ExecContext(ctx, "DELETE FROM token WHERE token = ?", token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not delete token", "err", err)
|
slog.ErrorContext(ctx, "Could not delete token", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) InsertSession(ctx context.Context, session *types.Session) error {
|
func (db DbSqlite) InsertSession(ctx context.Context, session *auth_types.Session) error {
|
||||||
_, err := db.db.ExecContext(ctx, `
|
_, err := db.db.ExecContext(ctx, `
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
|
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not insert new session", "err", err)
|
slog.ErrorContext(ctx, "Could not insert new session", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.Session, error) {
|
func (db DbSqlite) GetSession(ctx context.Context, sessionId string) (*auth_types.Session, error) {
|
||||||
var (
|
var (
|
||||||
userId uuid.UUID
|
userId uuid.UUID
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
@@ -350,56 +351,56 @@ func (db AuthSqlite) GetSession(ctx context.Context, sessionId string) (*types.S
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.WarnContext(ctx, "Session not found", "session-id", sessionId, "err", err)
|
slog.WarnContext(ctx, "Session not found", "session-id", sessionId, "err", err)
|
||||||
return nil, ErrNotFound
|
return nil, core.ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
return types.NewSession(sessionId, userId, createdAt, expiresAt), nil
|
return auth_types.NewSession(sessionId, userId, createdAt, expiresAt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) {
|
func (db DbSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*auth_types.Session, error) {
|
||||||
var sessions []*types.Session
|
var sessions []*auth_types.Session
|
||||||
err := db.db.SelectContext(ctx, &sessions, `
|
err := db.db.SelectContext(ctx, &sessions, `
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM session
|
FROM session
|
||||||
WHERE user_id = ?`, userId)
|
WHERE user_id = ?`, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not get sessions", "err", err)
|
slog.ErrorContext(ctx, "Could not get sessions", "err", err)
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return sessions, nil
|
return sessions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
|
func (db DbSqlite) DeleteSession(ctx context.Context, sessionId string) error {
|
||||||
if sessionId != "" {
|
if sessionId != "" {
|
||||||
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
|
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not delete session", "err", err)
|
slog.ErrorContext(ctx, "Could not delete session", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteOldSessions(ctx context.Context) error {
|
func (db DbSqlite) DeleteOldSessions(ctx context.Context) error {
|
||||||
_, err := db.db.ExecContext(ctx, `
|
_, err := db.db.ExecContext(ctx, `
|
||||||
DELETE FROM session
|
DELETE FROM session
|
||||||
WHERE expires_at < datetime('now')`)
|
WHERE expires_at < datetime('now')`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not delete old sessions", "err", err)
|
slog.ErrorContext(ctx, "Could not delete old sessions", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteOldTokens(ctx context.Context) error {
|
func (db DbSqlite) DeleteOldTokens(ctx context.Context) error {
|
||||||
_, err := db.db.ExecContext(ctx, `
|
_, err := db.db.ExecContext(ctx, `
|
||||||
DELETE FROM token
|
DELETE FROM token
|
||||||
WHERE expires_at < datetime('now')`)
|
WHERE expires_at < datetime('now')`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not delete old tokens", "err", err)
|
slog.ErrorContext(ctx, "Could not delete old tokens", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1,36 +1,34 @@
|
|||||||
package handler
|
package authentication
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/authentication/template"
|
||||||
"spend-sparrow/internal/core"
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/handler/middleware"
|
|
||||||
"spend-sparrow/internal/service"
|
|
||||||
"spend-sparrow/internal/template/auth"
|
|
||||||
"spend-sparrow/internal/types"
|
|
||||||
"spend-sparrow/internal/utils"
|
"spend-sparrow/internal/utils"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Auth interface {
|
type Handler interface {
|
||||||
Handle(router *http.ServeMux)
|
Handle(router *http.ServeMux)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthImpl struct {
|
type HandlerImpl struct {
|
||||||
service service.Auth
|
service Service
|
||||||
render *core.Render
|
render *core.Render
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuth(service service.Auth, render *core.Render) Auth {
|
func NewHandler(service Service, render *core.Render) Handler {
|
||||||
return AuthImpl{
|
return HandlerImpl{
|
||||||
service: service,
|
service: service,
|
||||||
render: render,
|
render: render,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) Handle(router *http.ServeMux) {
|
func (handler HandlerImpl) Handle(router *http.ServeMux) {
|
||||||
router.Handle("GET /auth/signin", handler.handleSignInPage())
|
router.Handle("GET /auth/signin", handler.handleSignInPage())
|
||||||
router.Handle("POST /api/auth/signin", handler.handleSignIn())
|
router.Handle("POST /api/auth/signin", handler.handleSignIn())
|
||||||
|
|
||||||
@@ -57,7 +55,7 @@ var (
|
|||||||
securityWaitDuration = 250 * time.Millisecond
|
securityWaitDuration = 250 * time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
|
func (handler HandlerImpl) handleSignInPage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -71,17 +69,17 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
comp := auth.SignInOrUpComp(true)
|
comp := template.SignInOrUpComp(true)
|
||||||
|
|
||||||
handler.render.RenderLayout(r, w, comp, nil)
|
handler.render.RenderLayout(r, w, comp, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
func (handler HandlerImpl) handleSignIn() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) {
|
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*auth_types.User, error) {
|
||||||
session := core.GetSession(r)
|
session := core.GetSession(r)
|
||||||
email := r.FormValue("email")
|
email := r.FormValue("email")
|
||||||
password := r.FormValue("password")
|
password := r.FormValue("password")
|
||||||
@@ -91,14 +89,14 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie := middleware.CreateSessionCookie(session.Id)
|
cookie := core.CreateSessionCookie(session.Id)
|
||||||
http.SetCookie(w, &cookie)
|
http.SetCookie(w, &cookie)
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrInvalidCredentials) {
|
if errors.Is(err, ErrInvalidCredentials) {
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Invalid email or password", http.StatusUnauthorized)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Invalid email or password", http.StatusUnauthorized)
|
||||||
} else {
|
} else {
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
|
||||||
@@ -114,7 +112,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleSignUpPage() http.HandlerFunc {
|
func (handler HandlerImpl) handleSignUpPage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -129,12 +127,12 @@ func (handler AuthImpl) handleSignUpPage() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
signUpComp := auth.SignInOrUpComp(false)
|
signUpComp := template.SignInOrUpComp(false)
|
||||||
handler.render.RenderLayout(r, w, signUpComp, nil)
|
handler.render.RenderLayout(r, w, signUpComp, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleSignUpVerifyPage() http.HandlerFunc {
|
func (handler HandlerImpl) handleSignUpVerifyPage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -149,12 +147,12 @@ func (handler AuthImpl) handleSignUpVerifyPage() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
signIn := auth.VerifyComp()
|
signIn := template.VerifyComp()
|
||||||
handler.render.RenderLayout(r, w, signIn, user)
|
handler.render.RenderLayout(r, w, signIn, user)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
|
func (handler HandlerImpl) handleVerifyResendComp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -173,7 +171,7 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
func (handler HandlerImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -182,7 +180,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
|||||||
err := handler.service.VerifyUserEmail(r.Context(), token)
|
err := handler.service.VerifyUserEmail(r.Context(), token)
|
||||||
|
|
||||||
isVerified := err == nil
|
isVerified := err == nil
|
||||||
comp := auth.VerifyResponseComp(isVerified)
|
comp := template.VerifyResponseComp(isVerified)
|
||||||
|
|
||||||
var status int
|
var status int
|
||||||
if isVerified {
|
if isVerified {
|
||||||
@@ -195,7 +193,7 @@ func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleSignUp() http.HandlerFunc {
|
func (handler HandlerImpl) handleSignUp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -216,14 +214,14 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, types.ErrInternal):
|
case errors.Is(err, core.ErrInternal):
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
case errors.Is(err, service.ErrInvalidEmail):
|
case errors.Is(err, ErrInvalidEmail):
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "The email provided is invalid", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "The email provided is invalid", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
case errors.Is(err, service.ErrInvalidPassword):
|
case errors.Is(err, ErrInvalidPassword):
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", ErrInvalidPassword.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// If err is "service.ErrAccountExists", then just continue
|
// If err is "service.ErrAccountExists", then just continue
|
||||||
@@ -233,7 +231,7 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleSignOut() http.HandlerFunc {
|
func (handler HandlerImpl) handleSignOut() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -262,7 +260,7 @@ func (handler AuthImpl) handleSignOut() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleDeleteAccountPage() http.HandlerFunc {
|
func (handler HandlerImpl) handleDeleteAccountPage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -272,12 +270,12 @@ func (handler AuthImpl) handleDeleteAccountPage() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
comp := auth.DeleteAccountComp()
|
comp := template.DeleteAccountComp()
|
||||||
handler.render.RenderLayout(r, w, comp, user)
|
handler.render.RenderLayout(r, w, comp, user)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
|
func (handler HandlerImpl) handleDeleteAccountComp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -291,7 +289,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
|
|||||||
|
|
||||||
err := handler.service.DeleteAccount(r.Context(), user, password)
|
err := handler.service.DeleteAccount(r.Context(), user, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, service.ErrInvalidCredentials) {
|
if errors.Is(err, ErrInvalidCredentials) {
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Password not correct", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Password not correct", http.StatusBadRequest)
|
||||||
} else {
|
} else {
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
||||||
@@ -303,7 +301,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
|
func (handler HandlerImpl) handleChangePasswordPage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -316,12 +314,12 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
comp := auth.ChangePasswordComp(isPasswordReset)
|
comp := template.ChangePasswordComp(isPasswordReset)
|
||||||
handler.render.RenderLayout(r, w, comp, user)
|
handler.render.RenderLayout(r, w, comp, user)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
|
func (handler HandlerImpl) handleChangePasswordComp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -345,7 +343,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
|
func (handler HandlerImpl) handleForgotPasswordPage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -355,12 +353,12 @@ func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
comp := auth.ResetPasswordComp()
|
comp := template.ResetPasswordComp()
|
||||||
handler.render.RenderLayout(r, w, comp, user)
|
handler.render.RenderLayout(r, w, comp, user)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
|
func (handler HandlerImpl) handleForgotPasswordComp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -383,7 +381,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
|
func (handler HandlerImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
core.UpdateSpan(r)
|
core.UpdateSpan(r)
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package authentication
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -6,7 +6,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
mailTemplate "spend-sparrow/internal/template/mail"
|
mailTemplate "spend-sparrow/internal/template/mail"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -25,39 +26,39 @@ var (
|
|||||||
ErrTokenInvalid = errors.New("token is invalid")
|
ErrTokenInvalid = errors.New("token is invalid")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Auth interface {
|
type Service interface {
|
||||||
SignUp(ctx context.Context, email string, password string) (*types.User, error)
|
SignUp(ctx context.Context, email string, password string) (*auth_types.User, error)
|
||||||
SendVerificationMail(ctx context.Context, userId uuid.UUID, email string)
|
SendVerificationMail(ctx context.Context, userId uuid.UUID, email string)
|
||||||
VerifyUserEmail(ctx context.Context, token string) error
|
VerifyUserEmail(ctx context.Context, token string) error
|
||||||
|
|
||||||
SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error)
|
SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_types.Session, *auth_types.User, error)
|
||||||
SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error)
|
SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error)
|
||||||
SignInAnonymous(ctx context.Context) (*types.Session, error)
|
SignInAnonymous(ctx context.Context) (*auth_types.Session, error)
|
||||||
SignOut(ctx context.Context, sessionId string) error
|
SignOut(ctx context.Context, sessionId string) error
|
||||||
|
|
||||||
DeleteAccount(ctx context.Context, user *types.User, currPass string) error
|
DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error
|
||||||
|
|
||||||
ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error
|
ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error
|
||||||
|
|
||||||
SendForgotPasswordMail(ctx context.Context, email string) error
|
SendForgotPasswordMail(ctx context.Context, email string) error
|
||||||
ForgotPassword(ctx context.Context, token string, newPass string) error
|
ForgotPassword(ctx context.Context, token string, newPass string) error
|
||||||
|
|
||||||
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
|
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
|
||||||
GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
|
GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error)
|
||||||
|
|
||||||
CleanupSessionsAndTokens(ctx context.Context) error
|
CleanupSessionsAndTokens(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthImpl struct {
|
type ServiceImpl struct {
|
||||||
db db.Auth
|
db Db
|
||||||
random Random
|
random core.Random
|
||||||
clock Clock
|
clock core.Clock
|
||||||
mail Mail
|
mail core.Mail
|
||||||
serverSettings *types.Settings
|
serverSettings *types.Settings
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *types.Settings) *AuthImpl {
|
func NewService(db Db, random core.Random, clock core.Clock, mail core.Mail, serverSettings *types.Settings) *ServiceImpl {
|
||||||
return &AuthImpl{
|
return &ServiceImpl{
|
||||||
db: db,
|
db: db,
|
||||||
random: random,
|
random: random,
|
||||||
clock: clock,
|
clock: clock,
|
||||||
@@ -66,13 +67,13 @@ func NewAuth(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, email string, password string) (*types.Session, *types.User, error) {
|
func (service ServiceImpl) SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_types.Session, *auth_types.User, error) {
|
||||||
user, err := service.db.GetUserByEmail(ctx, email)
|
user, err := service.db.GetUserByEmail(ctx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, nil, ErrInvalidCredentials
|
return nil, nil, ErrInvalidCredentials
|
||||||
} else {
|
} else {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,36 +85,36 @@ func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, emai
|
|||||||
|
|
||||||
newSession, err := service.createSession(ctx, user.Id)
|
newSession, err := service.createSession(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.DeleteSession(ctx, session.Id)
|
err = service.db.DeleteSession(ctx, session.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
for _, token := range tokens {
|
for _, token := range tokens {
|
||||||
err = service.db.DeleteToken(ctx, token.Token)
|
err = service.db.DeleteToken(ctx, token.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return newSession, user, nil
|
return newSession, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
|
func (service ServiceImpl) SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error) {
|
||||||
if sessionId == "" {
|
if sessionId == "" {
|
||||||
return nil, nil, ErrSessionIdInvalid
|
return nil, nil, ErrSessionIdInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := service.db.GetSession(ctx, sessionId)
|
session, err := service.db.GetSession(ctx, sessionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
if session.ExpiresAt.Before(service.clock.Now()) {
|
if session.ExpiresAt.Before(service.clock.Now()) {
|
||||||
_ = service.db.DeleteSession(ctx, sessionId)
|
_ = service.db.DeleteSession(ctx, sessionId)
|
||||||
@@ -126,16 +127,16 @@ func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*t
|
|||||||
|
|
||||||
user, err := service.db.GetUser(ctx, session.UserId)
|
user, err := service.db.GetUser(ctx, session.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return session, user, nil
|
return session, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, error) {
|
func (service ServiceImpl) SignInAnonymous(ctx context.Context) (*auth_types.Session, error) {
|
||||||
session, err := service.createSession(ctx, uuid.Nil)
|
session, err := service.createSession(ctx, uuid.Nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.InfoContext(ctx, "anonymous session created", "session-id", session.Id)
|
slog.InfoContext(ctx, "anonymous session created", "session-id", session.Id)
|
||||||
@@ -143,7 +144,7 @@ func (service AuthImpl) SignInAnonymous(ctx context.Context) (*types.Session, er
|
|||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignUp(ctx context.Context, email string, password string) (*types.User, error) {
|
func (service ServiceImpl) SignUp(ctx context.Context, email string, password string) (*auth_types.User, error) {
|
||||||
_, err := mail.ParseAddress(email)
|
_, err := mail.ParseAddress(email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrInvalidEmail
|
return nil, ErrInvalidEmail
|
||||||
@@ -155,37 +156,37 @@ func (service AuthImpl) SignUp(ctx context.Context, email string, password strin
|
|||||||
|
|
||||||
userId, err := service.random.UUID(ctx)
|
userId, err := service.random.UUID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
salt, err := service.random.Bytes(ctx, 16)
|
salt, err := service.random.Bytes(ctx, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
hash := GetHashPassword(password, salt)
|
hash := GetHashPassword(password, salt)
|
||||||
|
|
||||||
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
|
user := auth_types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
|
||||||
|
|
||||||
err = service.db.InsertUser(ctx, user)
|
err = service.db.InsertUser(ctx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrAlreadyExists) {
|
if errors.Is(err, core.ErrAlreadyExists) {
|
||||||
return nil, ErrAccountExists
|
return nil, ErrAccountExists
|
||||||
} else {
|
} else {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
|
func (service ServiceImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
|
||||||
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify)
|
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, auth_types.TokenTypeEmailVerify)
|
||||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var token *types.Token
|
var token *auth_types.Token
|
||||||
|
|
||||||
if len(tokens) > 0 {
|
if len(tokens) > 0 {
|
||||||
token = tokens[0]
|
token = tokens[0]
|
||||||
@@ -197,11 +198,11 @@ func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UU
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token = types.NewToken(
|
token = auth_types.NewToken(
|
||||||
userId,
|
userId,
|
||||||
"",
|
"",
|
||||||
newTokenStr,
|
newTokenStr,
|
||||||
types.TokenTypeEmailVerify,
|
auth_types.TokenTypeEmailVerify,
|
||||||
service.clock.Now(),
|
service.clock.Now(),
|
||||||
service.clock.Now().Add(24*time.Hour))
|
service.clock.Now().Add(24*time.Hour))
|
||||||
|
|
||||||
@@ -221,29 +222,29 @@ func (service AuthImpl) SendVerificationMail(ctx context.Context, userId uuid.UU
|
|||||||
service.mail.SendMail(ctx, email, "Welcome to spend-sparrow", w.String())
|
service.mail.SendMail(ctx, email, "Welcome to spend-sparrow", w.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
|
func (service ServiceImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
|
||||||
if tokenStr == "" {
|
if tokenStr == "" {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := service.db.GetToken(ctx, tokenStr)
|
token, err := service.db.GetToken(ctx, tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := service.db.GetUser(ctx, token.UserId)
|
user, err := service.db.GetUser(ctx, token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
if token.Type != types.TokenTypeEmailVerify {
|
if token.Type != auth_types.TokenTypeEmailVerify {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
now := service.clock.Now()
|
now := service.clock.Now()
|
||||||
|
|
||||||
if token.ExpiresAt.Before(now) {
|
if token.ExpiresAt.Before(now) {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
user.EmailVerified = true
|
user.EmailVerified = true
|
||||||
@@ -251,21 +252,21 @@ func (service AuthImpl) VerifyUserEmail(ctx context.Context, tokenStr string) er
|
|||||||
|
|
||||||
err = service.db.UpdateUser(ctx, user)
|
err = service.db.UpdateUser(ctx, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = service.db.DeleteToken(ctx, token.Token)
|
_ = service.db.DeleteToken(ctx, token.Token)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignOut(ctx context.Context, sessionId string) error {
|
func (service ServiceImpl) SignOut(ctx context.Context, sessionId string) error {
|
||||||
return service.db.DeleteSession(ctx, sessionId)
|
return service.db.DeleteSession(ctx, sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, currPass string) error {
|
func (service ServiceImpl) DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error {
|
||||||
userDb, err := service.db.GetUser(ctx, user.Id)
|
userDb, err := service.db.GetUser(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
currHash := GetHashPassword(currPass, userDb.Salt)
|
currHash := GetHashPassword(currPass, userDb.Salt)
|
||||||
@@ -283,7 +284,7 @@ func (service AuthImpl) DeleteAccount(ctx context.Context, user *types.User, cur
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, sessionId string, currPass, newPass string) error {
|
func (service ServiceImpl) ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error {
|
||||||
if !isPasswordValid(newPass) {
|
if !isPasswordValid(newPass) {
|
||||||
return ErrInvalidPassword
|
return ErrInvalidPassword
|
||||||
}
|
}
|
||||||
@@ -308,13 +309,13 @@ func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, se
|
|||||||
|
|
||||||
sessions, err := service.db.GetSessions(ctx, user.Id)
|
sessions, err := service.db.GetSessions(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
for _, s := range sessions {
|
for _, s := range sessions {
|
||||||
if s.Id != sessionId {
|
if s.Id != sessionId {
|
||||||
err = service.db.DeleteSession(ctx, s.Id)
|
err = service.db.DeleteSession(ctx, s.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -322,7 +323,7 @@ func (service AuthImpl) ChangePassword(ctx context.Context, user *types.User, se
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
|
func (service ServiceImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
|
||||||
tokenStr, err := service.random.String(ctx, 32)
|
tokenStr, err := service.random.String(ctx, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -330,38 +331,38 @@ func (service AuthImpl) SendForgotPasswordMail(ctx context.Context, email string
|
|||||||
|
|
||||||
user, err := service.db.GetUserByEmail(ctx, email)
|
user, err := service.db.GetUserByEmail(ctx, email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
token := types.NewToken(
|
token := auth_types.NewToken(
|
||||||
user.Id,
|
user.Id,
|
||||||
"",
|
"",
|
||||||
tokenStr,
|
tokenStr,
|
||||||
types.TokenTypePasswordReset,
|
auth_types.TokenTypePasswordReset,
|
||||||
service.clock.Now(),
|
service.clock.Now(),
|
||||||
service.clock.Now().Add(15*time.Minute))
|
service.clock.Now().Add(15*time.Minute))
|
||||||
|
|
||||||
err = service.db.InsertToken(ctx, token)
|
err = service.db.InsertToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
var mail strings.Builder
|
var mail strings.Builder
|
||||||
err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail)
|
err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not render reset password email", "err", err)
|
slog.ErrorContext(ctx, "Could not render reset password email", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
service.mail.SendMail(ctx, email, "Reset Password", mail.String())
|
service.mail.SendMail(ctx, email, "Reset Password", mail.String())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
|
func (service ServiceImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
|
||||||
if !isPasswordValid(newPass) {
|
if !isPasswordValid(newPass) {
|
||||||
return ErrInvalidPassword
|
return ErrInvalidPassword
|
||||||
}
|
}
|
||||||
@@ -376,7 +377,7 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if token.Type != types.TokenTypePasswordReset ||
|
if token.Type != auth_types.TokenTypePasswordReset ||
|
||||||
token.ExpiresAt.Before(service.clock.Now()) {
|
token.ExpiresAt.Before(service.clock.Now()) {
|
||||||
return ErrTokenInvalid
|
return ErrTokenInvalid
|
||||||
}
|
}
|
||||||
@@ -384,7 +385,7 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new
|
|||||||
user, err := service.db.GetUser(ctx, token.UserId)
|
user, err := service.db.GetUser(ctx, token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not get user from token", "err", err)
|
slog.ErrorContext(ctx, "Could not get user from token", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
passHash := GetHashPassword(newPass, user.Salt)
|
passHash := GetHashPassword(newPass, user.Salt)
|
||||||
@@ -397,26 +398,26 @@ func (service AuthImpl) ForgotPassword(ctx context.Context, tokenStr string, new
|
|||||||
|
|
||||||
sessions, err := service.db.GetSessions(ctx, user.Id)
|
sessions, err := service.db.GetSessions(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, session := range sessions {
|
for _, session := range sessions {
|
||||||
err = service.db.DeleteSession(ctx, session.Id)
|
err = service.db.DeleteSession(ctx, session.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
|
func (service ServiceImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
|
||||||
token, err := service.db.GetToken(ctx, tokenStr)
|
token, err := service.db.GetToken(ctx, tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if token.Type != types.TokenTypeCsrf ||
|
if token.Type != auth_types.TokenTypeCsrf ||
|
||||||
token.SessionId != sessionId ||
|
token.SessionId != sessionId ||
|
||||||
token.ExpiresAt.Before(service.clock.Now()) {
|
token.ExpiresAt.Before(service.clock.Now()) {
|
||||||
return false
|
return false
|
||||||
@@ -425,12 +426,12 @@ func (service AuthImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, s
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session) (string, error) {
|
func (service ServiceImpl) GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error) {
|
||||||
if session == nil {
|
if session == nil {
|
||||||
return "", types.ErrInternal
|
return "", core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf)
|
||||||
|
|
||||||
if len(tokens) > 0 {
|
if len(tokens) > 0 {
|
||||||
return tokens[0].Token, nil
|
return tokens[0].Token, nil
|
||||||
@@ -438,19 +439,19 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session
|
|||||||
|
|
||||||
tokenStr, err := service.random.String(ctx, 32)
|
tokenStr, err := service.random.String(ctx, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", types.ErrInternal
|
return "", core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
token := types.NewToken(
|
token := auth_types.NewToken(
|
||||||
session.UserId,
|
session.UserId,
|
||||||
session.Id,
|
session.Id,
|
||||||
tokenStr,
|
tokenStr,
|
||||||
types.TokenTypeCsrf,
|
auth_types.TokenTypeCsrf,
|
||||||
service.clock.Now(),
|
service.clock.Now(),
|
||||||
service.clock.Now().Add(8*time.Hour))
|
service.clock.Now().Add(8*time.Hour))
|
||||||
err = service.db.InsertToken(ctx, token)
|
err = service.db.InsertToken(ctx, token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", types.ErrInternal
|
return "", core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.InfoContext(ctx, "CSRF-Token created", "token", tokenStr)
|
slog.InfoContext(ctx, "CSRF-Token created", "token", tokenStr)
|
||||||
@@ -458,34 +459,34 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session
|
|||||||
return tokenStr, nil
|
return tokenStr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) CleanupSessionsAndTokens(ctx context.Context) error {
|
func (service ServiceImpl) CleanupSessionsAndTokens(ctx context.Context) error {
|
||||||
err := service.db.DeleteOldSessions(ctx)
|
err := service.db.DeleteOldSessions(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.DeleteOldTokens(ctx)
|
err = service.db.DeleteOldTokens(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*types.Session, error) {
|
func (service ServiceImpl) createSession(ctx context.Context, userId uuid.UUID) (*auth_types.Session, error) {
|
||||||
sessionId, err := service.random.String(ctx, 32)
|
sessionId, err := service.random.String(ctx, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
createAt := service.clock.Now()
|
createAt := service.clock.Now()
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
|
|
||||||
session := types.NewSession(sessionId, userId, createAt, expiresAt)
|
session := auth_types.NewSession(sessionId, userId, createAt, expiresAt)
|
||||||
|
|
||||||
err = service.db.InsertSession(ctx, session)
|
err = service.db.InsertSession(ctx, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return session, nil
|
return session, nil
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package auth
|
package template
|
||||||
|
|
||||||
templ ChangePasswordComp(isPasswordReset bool) {
|
templ ChangePasswordComp(isPasswordReset bool) {
|
||||||
<form
|
<form
|
||||||
1
internal/authentication/template/default.go
Normal file
1
internal/authentication/template/default.go
Normal file
@@ -0,0 +1 @@
|
|||||||
|
package template
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package auth
|
package template
|
||||||
|
|
||||||
templ DeleteAccountComp() {
|
templ DeleteAccountComp() {
|
||||||
<form
|
<form
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package auth
|
package template
|
||||||
|
|
||||||
templ ResetPasswordComp() {
|
templ ResetPasswordComp() {
|
||||||
<form
|
<form
|
||||||
@@ -1,13 +1,13 @@
|
|||||||
package auth
|
package template
|
||||||
|
|
||||||
templ SignInOrUpComp(isSignIn bool) {
|
templ SignInOrUpComp(isSignIn bool) {
|
||||||
{{
|
{{
|
||||||
var postUrl string
|
var postUrl string
|
||||||
if isSignIn {
|
if isSignIn {
|
||||||
postUrl = "/api/auth/signin"
|
postUrl = "/api/auth/signin"
|
||||||
} else {
|
} else {
|
||||||
postUrl = "/api/auth/signup"
|
postUrl = "/api/auth/signup"
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
<form
|
<form
|
||||||
class="max-w-xl px-2 mx-auto flex flex-col gap-4 h-full justify-center"
|
class="max-w-xl px-2 mx-auto flex flex-col gap-4 h-full justify-center"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package auth
|
package template
|
||||||
|
|
||||||
templ VerifyComp() {
|
templ VerifyComp() {
|
||||||
<main class="h-full">
|
<main class="h-full">
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package auth
|
package template
|
||||||
|
|
||||||
templ VerifyResponseComp(isVerified bool) {
|
templ VerifyResponseComp(isVerified bool) {
|
||||||
<main>
|
<main>
|
||||||
@@ -2,7 +2,7 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/auth_types"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextKey string
|
type ContextKey string
|
||||||
@@ -10,13 +10,13 @@ type ContextKey string
|
|||||||
var SessionKey ContextKey = "session"
|
var SessionKey ContextKey = "session"
|
||||||
var UserKey ContextKey = "user"
|
var UserKey ContextKey = "user"
|
||||||
|
|
||||||
func GetUser(r *http.Request) *types.User {
|
func GetUser(r *http.Request) *auth_types.User {
|
||||||
obj := r.Context().Value(UserKey)
|
obj := r.Context().Value(UserKey)
|
||||||
if obj == nil {
|
if obj == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := obj.(*types.User)
|
user, ok := obj.(*auth_types.User)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -24,16 +24,28 @@ func GetUser(r *http.Request) *types.User {
|
|||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSession(r *http.Request) *types.Session {
|
func GetSession(r *http.Request) *auth_types.Session {
|
||||||
obj := r.Context().Value(SessionKey)
|
obj := r.Context().Value(SessionKey)
|
||||||
if obj == nil {
|
if obj == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
session, ok := obj.(*types.Session)
|
session, ok := obj.(*auth_types.Session)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateSessionCookie(sessionId string) http.Cookie {
|
||||||
|
return http.Cookie{
|
||||||
|
Name: "id",
|
||||||
|
Value: sessionId,
|
||||||
|
MaxAge: 60 * 60 * 8, // 8 hours
|
||||||
|
Secure: true,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
Path: "/",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package core
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
@@ -3,8 +3,6 @@ package core
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"spend-sparrow/internal/db"
|
|
||||||
"spend-sparrow/internal/service"
|
|
||||||
"spend-sparrow/internal/utils"
|
"spend-sparrow/internal/utils"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -14,13 +12,13 @@ import (
|
|||||||
|
|
||||||
func HandleError(w http.ResponseWriter, r *http.Request, err error) {
|
func HandleError(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, service.ErrUnauthorized):
|
case errors.Is(err, ErrUnauthorized):
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
case errors.Is(err, service.ErrBadRequest):
|
case errors.Is(err, ErrBadRequest):
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
case errors.Is(err, db.ErrNotFound):
|
case errors.Is(err, ErrNotFound):
|
||||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusNotFound)
|
utils.TriggerToastWithStatus(r.Context(), w, r, "error", extractErrorMessage(err), http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
13
internal/core/error.go
Normal file
13
internal/core/error.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package core
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotFound = errors.New("the value does not exist")
|
||||||
|
ErrAlreadyExists = errors.New("row already exists")
|
||||||
|
|
||||||
|
ErrInternal = errors.New("internal server error")
|
||||||
|
ErrUnauthorized = errors.New("you are not authorized to perform this action")
|
||||||
|
|
||||||
|
ErrBadRequest = errors.New("bad request")
|
||||||
|
)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
package service
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"spend-sparrow/internal/types"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@@ -28,7 +27,7 @@ func (r *RandomImpl) Bytes(ctx context.Context, tsize int) ([]byte, error) {
|
|||||||
_, err := rand.Read(b)
|
_, err := rand.Read(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Error generating random bytes", "err", err)
|
slog.ErrorContext(ctx, "Error generating random bytes", "err", err)
|
||||||
return []byte{}, types.ErrInternal
|
return []byte{}, ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return b, nil
|
return b, nil
|
||||||
@@ -38,7 +37,7 @@ func (r *RandomImpl) String(ctx context.Context, size int) (string, error) {
|
|||||||
bytes, err := r.Bytes(ctx, size)
|
bytes, err := r.Bytes(ctx, size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Error generating random string", "err", err)
|
slog.ErrorContext(ctx, "Error generating random string", "err", err)
|
||||||
return "", types.ErrInternal
|
return "", ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return base64.StdEncoding.EncodeToString(bytes), nil
|
return base64.StdEncoding.EncodeToString(bytes), nil
|
||||||
@@ -48,7 +47,7 @@ func (r *RandomImpl) UUID(ctx context.Context) (uuid.UUID, error) {
|
|||||||
id, err := uuid.NewRandom()
|
id, err := uuid.NewRandom()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Error generating random UUID", "err", err)
|
slog.ErrorContext(ctx, "Error generating random UUID", "err", err)
|
||||||
return uuid.Nil, types.ErrInternal
|
return uuid.Nil, ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return id, nil
|
return id, nil
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/a-h/templ"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/auth_types"
|
||||||
|
|
||||||
|
"github.com/a-h/templ"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Render struct {
|
type Render struct {
|
||||||
@@ -28,18 +29,18 @@ func (render *Render) Render(r *http.Request, w http.ResponseWriter, comp templ.
|
|||||||
render.RenderWithStatus(r, w, comp, http.StatusOK)
|
render.RenderWithStatus(r, w, comp, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *types.User) {
|
func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *auth_types.User) {
|
||||||
render.RenderLayoutWithStatus(r, w, slot, user, http.StatusOK)
|
render.RenderLayoutWithStatus(r, w, slot, user, http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWriter, slot templ.Component, user *types.User, status int) {
|
func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWriter, slot templ.Component, user *auth_types.User, status int) {
|
||||||
userComp := render.getUserComp(user)
|
userComp := render.getUserComp(user)
|
||||||
layout := Layout(slot, userComp, user != nil, r.URL.Path)
|
layout := Layout(slot, userComp, user != nil, r.URL.Path)
|
||||||
|
|
||||||
render.RenderWithStatus(r, w, layout, status)
|
render.RenderWithStatus(r, w, layout, status)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (render *Render) getUserComp(user *types.User) templ.Component {
|
func (render *Render) getUserComp(user *auth_types.User) templ.Component {
|
||||||
if user != nil {
|
if user != nil {
|
||||||
return UserComp(user.Email)
|
return UserComp(user.Email)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -5,33 +5,28 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/core"
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrNotFound = errors.New("the value does not exist")
|
|
||||||
ErrAlreadyExists = errors.New("row already exists")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TransformAndLogDbError(ctx context.Context, module string, r sql.Result, err error) error {
|
func TransformAndLogDbError(ctx context.Context, module string, r sql.Result, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return ErrNotFound
|
return core.ErrNotFound
|
||||||
}
|
}
|
||||||
slog.ErrorContext(ctx, "database sql", "module", module, "err", err)
|
slog.ErrorContext(ctx, "database sql", "module", module, "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
if r != nil {
|
if r != nil {
|
||||||
rows, err := r.RowsAffected()
|
rows, err := r.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "database rows affected", "module", module, "err", err)
|
slog.ErrorContext(ctx, "database rows affected", "module", module, "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows == 0 {
|
if rows == 0 {
|
||||||
slog.InfoContext(ctx, "row not found", "module", module)
|
slog.InfoContext(ctx, "row not found", "module", module)
|
||||||
return ErrNotFound
|
return core.ErrNotFound
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/core"
|
||||||
|
|
||||||
"github.com/golang-migrate/migrate/v4"
|
"github.com/golang-migrate/migrate/v4"
|
||||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||||
@@ -25,7 +25,7 @@ func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error {
|
|||||||
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not create Migration instance", "err", err)
|
slog.ErrorContext(ctx, "Could not create Migration instance", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := migrate.NewWithDatabaseInstance(
|
m, err := migrate.NewWithDatabaseInstance(
|
||||||
@@ -34,14 +34,14 @@ func RunMigrations(ctx context.Context, db *sqlx.DB, pathPrefix string) error {
|
|||||||
driver)
|
driver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "Could not create migrations instance", "err", err)
|
slog.ErrorContext(ctx, "Could not create migrations instance", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Log = migrationLogger{}
|
m.Log = migrationLogger{}
|
||||||
|
|
||||||
if err = m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
if err = m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||||
slog.ErrorContext(ctx, "Could not run migrations", "err", err)
|
slog.ErrorContext(ctx, "Could not run migrations", "err", err)
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"spend-sparrow/internal/account"
|
"spend-sparrow/internal/account"
|
||||||
|
"spend-sparrow/internal/authentication"
|
||||||
"spend-sparrow/internal/core"
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/handler"
|
"spend-sparrow/internal/handler"
|
||||||
@@ -107,13 +108,13 @@ func shutdownServer(ctx context.Context, s *http.Server, wg *sync.WaitGroup) {
|
|||||||
func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||||
var router = http.NewServeMux()
|
var router = http.NewServeMux()
|
||||||
|
|
||||||
authDb := db.NewAuthSqlite(d)
|
authDb := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
randomService := service.NewRandom()
|
randomService := core.NewRandom()
|
||||||
clockService := service.NewClock()
|
clockService := core.NewClock()
|
||||||
mailService := service.NewMail(serverSettings)
|
mailService := core.NewMail(serverSettings)
|
||||||
|
|
||||||
authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings)
|
authService := authentication.NewService(authDb, randomService, clockService, mailService, serverSettings)
|
||||||
accountService := account.NewServiceImpl(d, randomService, clockService)
|
accountService := account.NewServiceImpl(d, randomService, clockService)
|
||||||
treasureChestService := service.NewTreasureChest(d, randomService, clockService)
|
treasureChestService := service.NewTreasureChest(d, randomService, clockService)
|
||||||
transactionService := service.NewTransaction(d, randomService, clockService)
|
transactionService := service.NewTransaction(d, randomService, clockService)
|
||||||
@@ -123,7 +124,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *
|
|||||||
render := core.NewRender()
|
render := core.NewRender()
|
||||||
indexHandler := handler.NewIndex(render, clockService)
|
indexHandler := handler.NewIndex(render, clockService)
|
||||||
dashboardHandler := handler.NewDashboard(render, dashboardService, treasureChestService)
|
dashboardHandler := handler.NewDashboard(render, dashboardService, treasureChestService)
|
||||||
authHandler := handler.NewAuth(authService, render)
|
authHandler := authentication.NewHandler(authService, render)
|
||||||
accountHandler := account.NewHandler(accountService, render)
|
accountHandler := account.NewHandler(accountService, render)
|
||||||
treasureChestHandler := handler.NewTreasureChest(treasureChestService, transactionRecurringService, render)
|
treasureChestHandler := handler.NewTreasureChest(treasureChestService, transactionRecurringService, render)
|
||||||
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
|
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
|
||||||
@@ -157,7 +158,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *
|
|||||||
return wrapper
|
return wrapper
|
||||||
}
|
}
|
||||||
|
|
||||||
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
|
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) {
|
||||||
runDailyTasks(ctx, transactionRecurring, auth)
|
runDailyTasks(ctx, transactionRecurring, auth)
|
||||||
ticker := time.NewTicker(24 * time.Hour)
|
ticker := time.NewTicker(24 * time.Hour)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
@@ -172,7 +173,7 @@ func dailyTaskTimer(ctx context.Context, transactionRecurring service.Transactio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
|
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) {
|
||||||
slog.InfoContext(ctx, "Running daily tasks")
|
slog.InfoContext(ctx, "Running daily tasks")
|
||||||
_ = transactionRecurring.GenerateTransactions(ctx)
|
_ = transactionRecurring.GenerateTransactions(ctx)
|
||||||
_ = auth.CleanupSessionsAndTokens(ctx)
|
_ = auth.CleanupSessionsAndTokens(ctx)
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ func (handler DashboardImpl) handleDashboardTreasureChest() http.HandlerFunc {
|
|||||||
if treasureChestStr != "" {
|
if treasureChestStr != "" {
|
||||||
id, err := uuid.Parse(treasureChestStr)
|
id, err := uuid.Parse(treasureChestStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", service.ErrBadRequest))
|
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", core.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,12 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"spend-sparrow/internal/authentication"
|
||||||
"spend-sparrow/internal/core"
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/service"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
func Authenticate(service authentication.Service) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
@@ -31,7 +31,7 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie := CreateSessionCookie(session.Id)
|
cookie := core.CreateSessionCookie(session.Id)
|
||||||
http.SetCookie(w, &cookie)
|
http.SetCookie(w, &cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"spend-sparrow/internal/authentication"
|
||||||
"spend-sparrow/internal/core"
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/service"
|
|
||||||
"spend-sparrow/internal/utils"
|
"spend-sparrow/internal/utils"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -31,7 +31,7 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
|
|||||||
return rr.ResponseWriter.Write([]byte(dataStr))
|
return rr.ResponseWriter.Write([]byte(dataStr))
|
||||||
}
|
}
|
||||||
|
|
||||||
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
|
func CrossSiteRequestForgery(auth authentication.Service) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|||||||
@@ -1,15 +1 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
func CreateSessionCookie(sessionId string) http.Cookie {
|
|
||||||
return http.Cookie{
|
|
||||||
Name: "id",
|
|
||||||
Value: sessionId,
|
|
||||||
MaxAge: 60 * 60 * 8, // 8 hours
|
|
||||||
Secure: true,
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
Path: "/",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"spend-sparrow/internal/core"
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/service"
|
|
||||||
"spend-sparrow/internal/template"
|
"spend-sparrow/internal/template"
|
||||||
"spend-sparrow/internal/utils"
|
"spend-sparrow/internal/utils"
|
||||||
|
|
||||||
@@ -16,10 +15,10 @@ type Index interface {
|
|||||||
|
|
||||||
type IndexImpl struct {
|
type IndexImpl struct {
|
||||||
r *core.Render
|
r *core.Render
|
||||||
c service.Clock
|
c core.Clock
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIndex(r *core.Render, c service.Clock) Index {
|
func NewIndex(r *core.Render, c core.Clock) Index {
|
||||||
return IndexImpl{
|
return IndexImpl{
|
||||||
r: r,
|
r: r,
|
||||||
c: c,
|
c: c,
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
|||||||
if idStr != "new" {
|
if idStr != "new" {
|
||||||
id, err = uuid.Parse(idStr)
|
id, err = uuid.Parse(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest))
|
core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -167,7 +167,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
|||||||
if accountIdStr != "" {
|
if accountIdStr != "" {
|
||||||
i, err := uuid.Parse(accountIdStr)
|
i, err := uuid.Parse(accountIdStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", service.ErrBadRequest))
|
core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", core.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
accountId = &i
|
accountId = &i
|
||||||
@@ -178,7 +178,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
|||||||
if treasureChestIdStr != "" {
|
if treasureChestIdStr != "" {
|
||||||
i, err := uuid.Parse(treasureChestIdStr)
|
i, err := uuid.Parse(treasureChestIdStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", service.ErrBadRequest))
|
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", core.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
treasureChestId = &i
|
treasureChestId = &i
|
||||||
@@ -186,14 +186,14 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
|||||||
|
|
||||||
valueF, err := strconv.ParseFloat(r.FormValue("value"), 64)
|
valueF, err := strconv.ParseFloat(r.FormValue("value"), 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.HandleError(w, r, fmt.Errorf("could not parse value: %w", service.ErrBadRequest))
|
core.HandleError(w, r, fmt.Errorf("could not parse value: %w", core.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
value := int64(math.Round(valueF * service.DECIMALS_MULTIPLIER))
|
value := int64(math.Round(valueF * service.DECIMALS_MULTIPLIER))
|
||||||
|
|
||||||
timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp"))
|
timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", service.ErrBadRequest))
|
core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"spend-sparrow/internal/auth_types"
|
||||||
"spend-sparrow/internal/core"
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/service"
|
"spend-sparrow/internal/service"
|
||||||
t "spend-sparrow/internal/template/transaction_recurring"
|
t "spend-sparrow/internal/template/transaction_recurring"
|
||||||
@@ -111,7 +112,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) {
|
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *auth_types.User, id, accountId, treasureChestId string) {
|
||||||
var transactionsRecurring []*types.TransactionRecurring
|
var transactionsRecurring []*types.TransactionRecurring
|
||||||
var err error
|
var err error
|
||||||
if accountId == "" && treasureChestId == "" {
|
if accountId == "" && treasureChestId == "" {
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
"time"
|
"time"
|
||||||
@@ -22,10 +24,10 @@ func NewDashboard(db *sqlx.DB) *Dashboard {
|
|||||||
|
|
||||||
func (s Dashboard) MainChart(
|
func (s Dashboard) MainChart(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
user *types.User,
|
user *auth_types.User,
|
||||||
) ([]types.DashboardMainChartEntry, error) {
|
) ([]types.DashboardMainChartEntry, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
transactions := make([]types.Transaction, 0)
|
transactions := make([]types.Transaction, 0)
|
||||||
@@ -82,10 +84,10 @@ func (s Dashboard) MainChart(
|
|||||||
|
|
||||||
func (s Dashboard) TreasureChests(
|
func (s Dashboard) TreasureChests(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
user *types.User,
|
user *auth_types.User,
|
||||||
) ([]*types.DashboardTreasureChest, error) {
|
) ([]*types.DashboardTreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChests := make([]*types.TreasureChest, 0)
|
treasureChests := make([]*types.TreasureChest, 0)
|
||||||
@@ -120,11 +122,11 @@ func (s Dashboard) TreasureChests(
|
|||||||
|
|
||||||
func (s Dashboard) TreasureChest(
|
func (s Dashboard) TreasureChest(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
user *types.User,
|
user *auth_types.User,
|
||||||
treausureChestId *uuid.UUID,
|
treausureChestId *uuid.UUID,
|
||||||
) ([]types.DashboardMainChartEntry, error) {
|
) ([]types.DashboardMainChartEntry, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
transactions := make([]types.Transaction, 0)
|
transactions := make([]types.Transaction, 0)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -16,9 +17,9 @@ var (
|
|||||||
func ValidateString(value string, fieldName string) error {
|
func ValidateString(value string, fieldName string) error {
|
||||||
switch {
|
switch {
|
||||||
case value == "":
|
case value == "":
|
||||||
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest)
|
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, core.ErrBadRequest)
|
||||||
case !safeInputRegex.MatchString(value):
|
case !safeInputRegex.MatchString(value):
|
||||||
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest)
|
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, core.ErrBadRequest)
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrBadRequest = errors.New("bad request")
|
|
||||||
ErrUnauthorized = errors.New("unauthorized")
|
|
||||||
)
|
|
||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -17,22 +19,22 @@ import (
|
|||||||
const page_size = 25
|
const page_size = 25
|
||||||
|
|
||||||
type Transaction interface {
|
type Transaction interface {
|
||||||
Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||||
Update(ctx context.Context, user *types.User, transaction types.Transaction) (*types.Transaction, error)
|
Update(ctx context.Context, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||||
Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error)
|
Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error)
|
||||||
GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
|
GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
|
||||||
Delete(ctx context.Context, user *types.User, id string) error
|
Delete(ctx context.Context, user *auth_types.User, id string) error
|
||||||
|
|
||||||
RecalculateBalances(ctx context.Context, user *types.User) error
|
RecalculateBalances(ctx context.Context, user *auth_types.User) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransactionImpl struct {
|
type TransactionImpl struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
clock Clock
|
clock core.Clock
|
||||||
random Random
|
random core.Random
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
|
func NewTransaction(db *sqlx.DB, random core.Random, clock core.Clock) Transaction {
|
||||||
return TransactionImpl{
|
return TransactionImpl{
|
||||||
db: db,
|
db: db,
|
||||||
clock: clock,
|
clock: clock,
|
||||||
@@ -40,9 +42,9 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
|
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transactionInput types.Transaction) (*types.Transaction, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
@@ -107,9 +109,9 @@ func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User,
|
|||||||
return transaction, nil
|
return transaction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) {
|
func (s TransactionImpl) Update(ctx context.Context, user *auth_types.User, input types.Transaction) (*types.Transaction, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -125,10 +127,10 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ
|
|||||||
err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
|
err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
|
||||||
err = db.TransformAndLogDbError(ctx, "transaction Update", nil, err)
|
err = db.TransformAndLogDbError(ctx, "transaction Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest)
|
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
if transaction.Error == nil && transaction.AccountId != nil {
|
if transaction.Error == nil && transaction.AccountId != nil {
|
||||||
@@ -206,32 +208,32 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ
|
|||||||
return transaction, nil
|
return transaction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) {
|
func (s TransactionImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
uuid, err := uuid.Parse(id)
|
uuid, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transaction get", "err", err)
|
slog.ErrorContext(ctx, "transaction get", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
var transaction types.Transaction
|
var transaction types.Transaction
|
||||||
err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError(ctx, "transaction Get", nil, err)
|
err = db.TransformAndLogDbError(ctx, "transaction Get", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest)
|
return nil, fmt.Errorf("transaction %v not found: %w", id, core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return &transaction, nil
|
return &transaction, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
|
func (s TransactionImpl) GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -277,14 +279,14 @@ func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter ty
|
|||||||
return transactions, nil
|
return transactions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
func (s TransactionImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
uuid, err := uuid.Parse(id)
|
uuid, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transaction delete", "err", err)
|
slog.ErrorContext(ctx, "transaction delete", "err", err)
|
||||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -310,7 +312,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
|
|||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||||
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
|
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
|
||||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -322,7 +324,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
|
|||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||||
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
|
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
|
||||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -342,9 +344,9 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error {
|
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *auth_types.User) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -361,7 +363,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
|
|||||||
SET current_balance = 0
|
SET current_balance = 0
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
|
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
|
||||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,7 +372,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
|
|||||||
SET current_balance = 0
|
SET current_balance = 0
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
|
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
|
||||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,7 +381,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
|
|||||||
FROM "transaction"
|
FROM "transaction"
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", nil, err)
|
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", nil, err)
|
||||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -458,7 +460,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
|
|||||||
if oldTransaction == nil {
|
if oldTransaction == nil {
|
||||||
id, err = s.random.UUID(ctx)
|
id, err = s.random.UUID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
createdAt = s.clock.Now()
|
createdAt = s.clock.Now()
|
||||||
createdBy = userId
|
createdBy = userId
|
||||||
@@ -479,7 +481,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
|
|||||||
}
|
}
|
||||||
if rowCount == 0 {
|
if rowCount == 0 {
|
||||||
slog.ErrorContext(ctx, "transaction validate", "err", err)
|
slog.ErrorContext(ctx, "transaction validate", "err", err)
|
||||||
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
|
return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -488,13 +490,13 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
|
|||||||
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
|
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
|
||||||
err = db.TransformAndLogDbError(ctx, "transaction validate", nil, err)
|
err = db.TransformAndLogDbError(ctx, "transaction validate", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
|
return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if treasureChest.ParentId == nil {
|
if treasureChest.ParentId == nil {
|
||||||
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest)
|
return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -16,24 +18,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TransactionRecurring interface {
|
type TransactionRecurring interface {
|
||||||
Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
Add(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||||
Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
Update(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||||
GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error)
|
GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error)
|
||||||
GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error)
|
GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error)
|
||||||
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
GetAllByTreasureChest(ctx context.Context, user *auth_types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||||
Delete(ctx context.Context, user *types.User, id string) error
|
Delete(ctx context.Context, user *auth_types.User, id string) error
|
||||||
|
|
||||||
GenerateTransactions(ctx context.Context) error
|
GenerateTransactions(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransactionRecurringImpl struct {
|
type TransactionRecurringImpl struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
clock Clock
|
clock core.Clock
|
||||||
random Random
|
random core.Random
|
||||||
transaction Transaction
|
transaction Transaction
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transaction Transaction) TransactionRecurring {
|
func NewTransactionRecurring(db *sqlx.DB, random core.Random, clock core.Clock, transaction Transaction) TransactionRecurring {
|
||||||
return TransactionRecurringImpl{
|
return TransactionRecurringImpl{
|
||||||
db: db,
|
db: db,
|
||||||
clock: clock,
|
clock: clock,
|
||||||
@@ -43,11 +45,11 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Add(ctx context.Context,
|
func (s TransactionRecurringImpl) Add(ctx context.Context,
|
||||||
user *types.User,
|
user *auth_types.User,
|
||||||
transactionRecurringInput types.TransactionRecurringInput,
|
transactionRecurringInput types.TransactionRecurringInput,
|
||||||
) (*types.TransactionRecurring, error) {
|
) (*types.TransactionRecurring, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -85,16 +87,16 @@ func (s TransactionRecurringImpl) Add(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Update(ctx context.Context,
|
func (s TransactionRecurringImpl) Update(ctx context.Context,
|
||||||
user *types.User,
|
user *auth_types.User,
|
||||||
input types.TransactionRecurringInput,
|
input types.TransactionRecurringInput,
|
||||||
) (*types.TransactionRecurring, error) {
|
) (*types.TransactionRecurring, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
uuid, err := uuid.Parse(input.Id)
|
uuid, err := uuid.Parse(input.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring update", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring update", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -110,10 +112,10 @@ func (s TransactionRecurringImpl) Update(ctx context.Context,
|
|||||||
err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError(ctx, "transactionRecurring Update", nil, err)
|
err = db.TransformAndLogDbError(ctx, "transactionRecurring Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest)
|
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
|
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
|
||||||
@@ -149,9 +151,9 @@ func (s TransactionRecurringImpl) Update(ctx context.Context,
|
|||||||
return transactionRecurring, nil
|
return transactionRecurring, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) {
|
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||||
@@ -169,15 +171,15 @@ func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User)
|
|||||||
return transactionRecurrings, nil
|
return transactionRecurrings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
accountUuid, err := uuid.Parse(accountId)
|
accountUuid, err := uuid.Parse(accountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring GetAllByAccount", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring GetAllByAccount", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -193,10 +195,10 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ
|
|||||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
||||||
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByAccount", nil, err)
|
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByAccount", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest)
|
return nil, fmt.Errorf("account %v not found: %w", accountId, core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||||
@@ -222,17 +224,17 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
||||||
user *types.User,
|
user *auth_types.User,
|
||||||
treasureChestId string,
|
treasureChestId string,
|
||||||
) ([]*types.TransactionRecurring, error) {
|
) ([]*types.TransactionRecurring, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChestUuid, err := uuid.Parse(treasureChestId)
|
treasureChestUuid, err := uuid.Parse(treasureChestId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring GetAllByTreasureChest", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring GetAllByTreasureChest", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -248,10 +250,10 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
|||||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
||||||
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByTreasureChest", nil, err)
|
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByTreasureChest", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest)
|
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||||
@@ -276,14 +278,14 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
|||||||
return transactionRecurrings, nil
|
return transactionRecurrings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
uuid, err := uuid.Parse(id)
|
uuid, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring delete", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring delete", "err", err)
|
||||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -339,7 +341,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, transactionRecurring := range recurringTransactions {
|
for _, transactionRecurring := range recurringTransactions {
|
||||||
user := &types.User{
|
user := &auth_types.User{
|
||||||
Id: transactionRecurring.UserId,
|
Id: transactionRecurring.UserId,
|
||||||
}
|
}
|
||||||
transaction := types.Transaction{
|
transaction := types.Transaction{
|
||||||
@@ -397,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
if oldTransactionRecurring == nil {
|
if oldTransactionRecurring == nil {
|
||||||
id, err = s.random.UUID(ctx)
|
id, err = s.random.UUID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
createdAt = s.clock.Now()
|
createdAt = s.clock.Now()
|
||||||
createdBy = userId
|
createdBy = userId
|
||||||
@@ -416,7 +418,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
temp, err := uuid.Parse(input.AccountId)
|
temp, err := uuid.Parse(input.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
accountUuid = &temp
|
accountUuid = &temp
|
||||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
|
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
|
||||||
@@ -426,7 +428,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
}
|
}
|
||||||
if rowCount == 0 {
|
if rowCount == 0 {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||||
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
|
return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
hasAccount = true
|
hasAccount = true
|
||||||
@@ -436,37 +438,37 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
temp, err := uuid.Parse(input.TreasureChestId)
|
temp, err := uuid.Parse(input.TreasureChestId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
treasureChestUuid = &temp
|
treasureChestUuid = &temp
|
||||||
var treasureChest types.TreasureChest
|
var treasureChest types.TreasureChest
|
||||||
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
||||||
err = db.TransformAndLogDbError(ctx, "transactionRecurring validate", nil, err)
|
err = db.TransformAndLogDbError(ctx, "transactionRecurring validate", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
|
return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if treasureChest.ParentId == nil {
|
if treasureChest.ParentId == nil {
|
||||||
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest)
|
return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
hasTreasureChest = true
|
hasTreasureChest = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasAccount && !hasTreasureChest {
|
if !hasAccount && !hasTreasureChest {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||||
return nil, fmt.Errorf("either account or treasure chest is required: %w", ErrBadRequest)
|
return nil, fmt.Errorf("either account or treasure chest is required: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
if hasAccount && hasTreasureChest {
|
if hasAccount && hasTreasureChest {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||||
return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", ErrBadRequest)
|
return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
valueFloat, err := strconv.ParseFloat(input.Value, 64)
|
valueFloat, err := strconv.ParseFloat(input.Value, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse value: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
value := int64(math.Round(valueFloat * DECIMALS_MULTIPLIER))
|
value := int64(math.Round(valueFloat * DECIMALS_MULTIPLIER))
|
||||||
|
|
||||||
@@ -485,18 +487,18 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0)
|
intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse intervalMonths: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse intervalMonths: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
if intervalMonths < 1 {
|
if intervalMonths < 1 {
|
||||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||||
return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", ErrBadRequest)
|
return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
var nextExecution *time.Time = nil
|
var nextExecution *time.Time = nil
|
||||||
if input.NextExecution != "" {
|
if input.NextExecution != "" {
|
||||||
t, err := time.Parse("2006-01-02", input.NextExecution)
|
t, err := time.Parse("2006-01-02", input.NextExecution)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "transaction validate", "err", err)
|
slog.ErrorContext(ctx, "transaction validate", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse timestamp: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
|
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
"slices"
|
||||||
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
|
|
||||||
@@ -14,20 +16,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TreasureChest interface {
|
type TreasureChest interface {
|
||||||
Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error)
|
Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error)
|
||||||
Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error)
|
Update(ctx context.Context, user *auth_types.User, id, parentId, name string) (*types.TreasureChest, error)
|
||||||
Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error)
|
Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error)
|
||||||
GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error)
|
GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error)
|
||||||
Delete(ctx context.Context, user *types.User, id string) error
|
Delete(ctx context.Context, user *auth_types.User, id string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TreasureChestImpl struct {
|
type TreasureChestImpl struct {
|
||||||
db *sqlx.DB
|
db *sqlx.DB
|
||||||
clock Clock
|
clock core.Clock
|
||||||
random Random
|
random core.Random
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
|
func NewTreasureChest(db *sqlx.DB, random core.Random, clock core.Clock) TreasureChest {
|
||||||
return TreasureChestImpl{
|
return TreasureChestImpl{
|
||||||
db: db,
|
db: db,
|
||||||
clock: clock,
|
clock: clock,
|
||||||
@@ -35,14 +37,14 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) {
|
func (s TreasureChestImpl) Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
newId, err := s.random.UUID(ctx)
|
newId, err := s.random.UUID(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ValidateString(name, "name")
|
err = ValidateString(name, "name")
|
||||||
@@ -57,7 +59,7 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if parent.ParentId != nil {
|
if parent.ParentId != nil {
|
||||||
return nil, fmt.Errorf("only a depth of 1 allowed: %w", ErrBadRequest)
|
return nil, fmt.Errorf("only a depth of 1 allowed: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
parentUuid = &parent.Id
|
parentUuid = &parent.Id
|
||||||
}
|
}
|
||||||
@@ -88,9 +90,9 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId,
|
|||||||
return treasureChest, nil
|
return treasureChest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
|
func (s TreasureChestImpl) Update(ctx context.Context, user *auth_types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
err := ValidateString(name, "name")
|
err := ValidateString(name, "name")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -99,7 +101,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
|
|||||||
id, err := uuid.Parse(idStr)
|
id, err := uuid.Parse(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "treasureChest update", "err", err)
|
slog.ErrorContext(ctx, "treasureChest update", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -115,10 +117,10 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
|
|||||||
err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
||||||
err = db.TransformAndLogDbError(ctx, "treasureChest Update", nil, err)
|
err = db.TransformAndLogDbError(ctx, "treasureChest Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
|
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
var parentUuid *uuid.UUID
|
var parentUuid *uuid.UUID
|
||||||
@@ -134,7 +136,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if parent.ParentId != nil || childCount > 0 {
|
if parent.ParentId != nil || childCount > 0 {
|
||||||
return nil, fmt.Errorf("only one level allowed: %w", ErrBadRequest)
|
return nil, fmt.Errorf("only one level allowed: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
parentUuid = &parent.Id
|
parentUuid = &parent.Id
|
||||||
@@ -170,32 +172,32 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
|
|||||||
return treasureChest, nil
|
return treasureChest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) {
|
func (s TreasureChestImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
uuid, err := uuid.Parse(id)
|
uuid, err := uuid.Parse(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "treasureChest get", "err", err)
|
slog.ErrorContext(ctx, "treasureChest get", "err", err)
|
||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
var treasureChest types.TreasureChest
|
var treasureChest types.TreasureChest
|
||||||
err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError(ctx, "treasureChest Get", nil, err)
|
err = db.TransformAndLogDbError(ctx, "treasureChest Get", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, core.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
|
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, core.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return &treasureChest, nil
|
return &treasureChest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) {
|
func (s TreasureChestImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChests := make([]*types.TreasureChest, 0)
|
treasureChests := make([]*types.TreasureChest, 0)
|
||||||
@@ -208,14 +210,14 @@ func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*typ
|
|||||||
return sortTreasureChests(treasureChests), nil
|
return sortTreasureChests(treasureChests), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error {
|
func (s TreasureChestImpl) Delete(ctx context.Context, user *auth_types.User, idStr string) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return ErrUnauthorized
|
return core.ErrUnauthorized
|
||||||
}
|
}
|
||||||
id, err := uuid.Parse(idStr)
|
id, err := uuid.Parse(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.ErrorContext(ctx, "treasureChest delete", "err", err)
|
slog.ErrorContext(ctx, "treasureChest delete", "err", err)
|
||||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -235,7 +237,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
|
|||||||
}
|
}
|
||||||
|
|
||||||
if childCount > 0 {
|
if childCount > 0 {
|
||||||
return fmt.Errorf("treasure chest has children: %w", ErrBadRequest)
|
return fmt.Errorf("treasure chest has children: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
transactionsCount := 0
|
transactionsCount := 0
|
||||||
@@ -247,7 +249,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if transactionsCount > 0 {
|
if transactionsCount > 0 {
|
||||||
return fmt.Errorf("treasure chest has transactions: %w", ErrBadRequest)
|
return fmt.Errorf("treasure chest has transactions: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
recurringCount := 0
|
recurringCount := 0
|
||||||
@@ -259,7 +261,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if recurringCount > 0 {
|
if recurringCount > 0 {
|
||||||
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest)
|
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", core.ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
package auth
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
package types
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrInternal = errors.New("internal server error")
|
|
||||||
ErrUnauthorized = errors.New("you are not authorized to perform this action")
|
|
||||||
)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
package mocks
|
|
||||||
@@ -2,8 +2,10 @@ package test_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"spend-sparrow/internal/auth_types"
|
||||||
|
"spend-sparrow/internal/authentication"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
"spend-sparrow/internal/types"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -42,11 +44,11 @@ func TestUser(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
expected := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(context.Background(), expected)
|
err := underTest.InsertUser(context.Background(), expected)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -63,38 +65,38 @@ func TestUser(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
_, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail")
|
_, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail")
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
assert.Equal(t, core.ErrNotFound, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrUserExist", func(t *testing.T) {
|
t.Run("should return ErrUserExist", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
user := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(context.Background(), user)
|
err := underTest.InsertUser(context.Background(), user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = underTest.InsertUser(context.Background(), user)
|
err = underTest.InsertUser(context.Background(), user)
|
||||||
assert.Equal(t, db.ErrAlreadyExists, err)
|
assert.Equal(t, core.ErrAlreadyExists, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
user := auth_types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(context.Background(), user)
|
err := underTest.InsertUser(context.Background(), user)
|
||||||
assert.Equal(t, types.ErrInternal, err)
|
assert.Equal(t, core.ErrInternal, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,11 +107,11 @@ func TestToken(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
|
expected := auth_types.NewToken(uuid.New(), "sessionId", "token", auth_types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
|
||||||
err := underTest.InsertToken(context.Background(), expected)
|
err := underTest.InsertToken(context.Background(), expected)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -121,25 +123,25 @@ func TestToken(t *testing.T) {
|
|||||||
expected.SessionId = ""
|
expected.SessionId = ""
|
||||||
actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type)
|
actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected}, actuals)
|
assert.Equal(t, []*auth_types.Token{expected}, actuals)
|
||||||
|
|
||||||
expected.SessionId = "sessionId"
|
expected.SessionId = "sessionId"
|
||||||
expected.UserId = uuid.Nil
|
expected.UserId = uuid.Nil
|
||||||
actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type)
|
actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected}, actuals)
|
assert.Equal(t, []*auth_types.Token{expected}, actuals)
|
||||||
})
|
})
|
||||||
t.Run("should insert and return multiple tokens", func(t *testing.T) {
|
t.Run("should insert and return multiple tokens", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt)
|
expected1 := auth_types.NewToken(userId, "sessionId", "token1", auth_types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
|
expected2 := auth_types.NewToken(userId, "sessionId", "token2", auth_types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
|
||||||
err := underTest.InsertToken(context.Background(), expected1)
|
err := underTest.InsertToken(context.Background(), expected1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -150,7 +152,7 @@ func TestToken(t *testing.T) {
|
|||||||
expected2.UserId = uuid.Nil
|
expected2.UserId = uuid.Nil
|
||||||
actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type)
|
actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
assert.Equal(t, []*auth_types.Token{expected1, expected2}, actuals)
|
||||||
|
|
||||||
expected1.SessionId = ""
|
expected1.SessionId = ""
|
||||||
expected2.SessionId = ""
|
expected2.SessionId = ""
|
||||||
@@ -158,49 +160,49 @@ func TestToken(t *testing.T) {
|
|||||||
expected2.UserId = userId
|
expected2.UserId = userId
|
||||||
actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type)
|
actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
assert.Equal(t, []*auth_types.Token{expected1, expected2}, actuals)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrNotFound", func(t *testing.T) {
|
t.Run("should return ErrNotFound", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
_, err := underTest.GetToken(context.Background(), "nonExistent")
|
_, err := underTest.GetToken(context.Background(), "nonExistent")
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
assert.Equal(t, core.ErrNotFound, err)
|
||||||
|
|
||||||
_, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), types.TokenTypeEmailVerify)
|
_, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), auth_types.TokenTypeEmailVerify)
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
assert.Equal(t, core.ErrNotFound, err)
|
||||||
|
|
||||||
_, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", types.TokenTypeEmailVerify)
|
_, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", auth_types.TokenTypeEmailVerify)
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
assert.Equal(t, core.ErrNotFound, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
|
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
user := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(context.Background(), user)
|
err := underTest.InsertUser(context.Background(), user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = underTest.InsertUser(context.Background(), user)
|
err = underTest.InsertUser(context.Background(), user)
|
||||||
assert.Equal(t, db.ErrAlreadyExists, err)
|
assert.Equal(t, core.ErrAlreadyExists, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := authentication.NewDbSqlite(d)
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
user := auth_types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(context.Background(), user)
|
err := underTest.InsertUser(context.Background(), user)
|
||||||
assert.Equal(t, types.ErrInternal, err)
|
assert.Equal(t, core.ErrInternal, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ package test_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/auth_types"
|
||||||
"spend-sparrow/internal/service"
|
"spend-sparrow/internal/authentication"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
"spend-sparrow/mocks"
|
"spend-sparrow/mocks"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -30,26 +31,26 @@ func TestSignUp(t *testing.T) {
|
|||||||
t.Run("should check for correct email address", func(t *testing.T) {
|
t.Run("should check for correct email address", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
mockAuthDb := mocks.NewMockDb(t)
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
mockClock := mocks.NewMockClock(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
mockMail := mocks.NewMockMail(t)
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
_, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!")
|
_, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!")
|
||||||
|
|
||||||
assert.Equal(t, service.ErrInvalidEmail, err)
|
assert.Equal(t, authentication.ErrInvalidEmail, err)
|
||||||
})
|
})
|
||||||
t.Run("should check for password complexity", func(t *testing.T) {
|
t.Run("should check for password complexity", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
mockAuthDb := mocks.NewMockDb(t)
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
mockClock := mocks.NewMockClock(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
mockMail := mocks.NewMockMail(t)
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
weakPasswords := []string{
|
weakPasswords := []string{
|
||||||
"123!ab", // too short
|
"123!ab", // too short
|
||||||
@@ -60,13 +61,13 @@ func TestSignUp(t *testing.T) {
|
|||||||
|
|
||||||
for _, password := range weakPasswords {
|
for _, password := range weakPasswords {
|
||||||
_, err := underTest.SignUp(context.Background(), "some@valid.email", password)
|
_, err := underTest.SignUp(context.Background(), "some@valid.email", password)
|
||||||
assert.Equal(t, service.ErrInvalidPassword, err)
|
assert.Equal(t, authentication.ErrInvalidPassword, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("should signup correctly", func(t *testing.T) {
|
t.Run("should signup correctly", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
mockAuthDb := mocks.NewMockDb(t)
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
mockClock := mocks.NewMockClock(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
mockMail := mocks.NewMockMail(t)
|
mockMail := mocks.NewMockMail(t)
|
||||||
@@ -77,7 +78,7 @@ func TestSignUp(t *testing.T) {
|
|||||||
salt := []byte("salt")
|
salt := []byte("salt")
|
||||||
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
|
expected := auth_types.NewUser(userId, email, false, nil, false, authentication.GetHashPassword(password, salt), salt, createTime)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
@@ -86,7 +87,7 @@ func TestSignUp(t *testing.T) {
|
|||||||
mockClock.EXPECT().Now().Return(createTime)
|
mockClock.EXPECT().Now().Return(createTime)
|
||||||
mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil)
|
mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil)
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
actual, err := underTest.SignUp(context.Background(), email, password)
|
actual, err := underTest.SignUp(context.Background(), email, password)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -96,7 +97,7 @@ func TestSignUp(t *testing.T) {
|
|||||||
t.Run("should return ErrAccountExists", func(t *testing.T) {
|
t.Run("should return ErrAccountExists", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
mockAuthDb := mocks.NewMockDb(t)
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
mockClock := mocks.NewMockClock(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
mockMail := mocks.NewMockMail(t)
|
mockMail := mocks.NewMockMail(t)
|
||||||
@@ -106,19 +107,19 @@ func TestSignUp(t *testing.T) {
|
|||||||
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
password := "SomeStrongPassword123!"
|
password := "SomeStrongPassword123!"
|
||||||
salt := []byte("salt")
|
salt := []byte("salt")
|
||||||
user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
|
user := auth_types.NewUser(userId, email, false, nil, false, authentication.GetHashPassword(password, salt), salt, createTime)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
mockRandom.EXPECT().UUID(ctx).Return(user.Id, nil)
|
mockRandom.EXPECT().UUID(ctx).Return(user.Id, nil)
|
||||||
mockRandom.EXPECT().Bytes(ctx, 16).Return(salt, nil)
|
mockRandom.EXPECT().Bytes(ctx, 16).Return(salt, nil)
|
||||||
mockClock.EXPECT().Now().Return(createTime)
|
mockClock.EXPECT().Now().Return(createTime)
|
||||||
|
|
||||||
mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(db.ErrAlreadyExists)
|
mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(core.ErrAlreadyExists)
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
_, err := underTest.SignUp(context.Background(), user.Email, password)
|
_, err := underTest.SignUp(context.Background(), user.Email, password)
|
||||||
assert.Equal(t, service.ErrAccountExists, err)
|
assert.Equal(t, authentication.ErrAccountExists, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,30 +128,30 @@ func TestSendVerificationMail(t *testing.T) {
|
|||||||
t.Run("should use stored token and send mail", func(t *testing.T) {
|
t.Run("should use stored token and send mail", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
token := types.NewToken(
|
token := auth_types.NewToken(
|
||||||
uuid.New(),
|
uuid.New(),
|
||||||
"sessionId",
|
"sessionId",
|
||||||
"someRandomTokenToUse",
|
"someRandomTokenToUse",
|
||||||
types.TokenTypeEmailVerify,
|
auth_types.TokenTypeEmailVerify,
|
||||||
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
|
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||||
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
|
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
|
||||||
tokens := []*types.Token{token}
|
tokens := []*auth_types.Token{token}
|
||||||
|
|
||||||
email := "some@email.de"
|
email := "some@email.de"
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
mockAuthDb := mocks.NewMockDb(t)
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
mockClock := mocks.NewMockClock(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
mockMail := mocks.NewMockMail(t)
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, types.TokenTypeEmailVerify).Return(tokens, nil)
|
mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, auth_types.TokenTypeEmailVerify).Return(tokens, nil)
|
||||||
mockMail.EXPECT().SendMail(ctx, email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
|
mockMail.EXPECT().SendMail(ctx, email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
|
||||||
return strings.Contains(message, token.Token)
|
return strings.Contains(message, token.Token)
|
||||||
})).Return()
|
})).Return()
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
underTest.SendVerificationMail(context.Background(), userId, email)
|
underTest.SendVerificationMail(context.Background(), userId, email)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"spend-sparrow/internal"
|
"spend-sparrow/internal"
|
||||||
"spend-sparrow/internal/service"
|
"spend-sparrow/internal/auth_types"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/authentication"
|
||||||
|
"spend-sparrow/internal/core"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -117,7 +118,7 @@ func waitForReady(
|
|||||||
default:
|
default:
|
||||||
if time.Since(startTime) >= timeout {
|
if time.Since(startTime) >= timeout {
|
||||||
t.Fatal("timeout reached while waiting for endpoint")
|
t.Fatal("timeout reached while waiting for endpoint")
|
||||||
return types.ErrInternal
|
return core.ErrInternal
|
||||||
}
|
}
|
||||||
// wait a little while between checks
|
// wait a little while between checks
|
||||||
time.Sleep(250 * time.Millisecond)
|
time.Sleep(250 * time.Millisecond)
|
||||||
@@ -178,7 +179,7 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
sessionId := "session-id" + add
|
sessionId := "session-id" + add
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
csrfToken := "my-verifying-token" + add
|
csrfToken := "my-verifying-token" + add
|
||||||
email := add + "mail@mail.de"
|
email := add + "mail@mail.de"
|
||||||
|
|
||||||
@@ -193,7 +194,7 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s
|
|||||||
|
|
||||||
_, err = db.ExecContext(context.Background(), `
|
_, err = db.ExecContext(context.Background(), `
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf)
|
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, auth_types.TokenTypeCsrf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return userId, csrfToken, sessionId
|
return userId, csrfToken, sessionId
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ package test_test
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"spend-sparrow/internal/service"
|
"spend-sparrow/internal/auth_types"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/authentication"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -110,7 +110,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
@@ -136,7 +136,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
@@ -163,7 +163,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
@@ -206,7 +206,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
@@ -247,7 +247,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
||||||
@@ -295,7 +295,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
||||||
@@ -414,7 +414,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
||||||
@@ -467,7 +467,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
@@ -550,7 +550,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), authentication.GetHashPassword("password", []byte("salt")), []byte("salt"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil)
|
||||||
@@ -631,7 +631,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
var token string
|
var token string
|
||||||
err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token)
|
err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", auth_types.TokenTypeEmailVerify).Scan(&token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, token)
|
assert.NotEmpty(t, token)
|
||||||
})
|
})
|
||||||
@@ -676,7 +676,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.ExecContext(ctx, `
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO token (token, user_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify)
|
VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, auth_types.TokenTypeEmailVerify)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil)
|
||||||
@@ -706,7 +706,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.ExecContext(ctx, `
|
_, err = db.ExecContext(ctx, `
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify)
|
VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, auth_types.TokenTypeEmailVerify)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil)
|
||||||
@@ -746,7 +746,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
sessionId := "session-id"
|
sessionId := "session-id"
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -765,7 +765,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var csrfToken string
|
var csrfToken string
|
||||||
err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken)
|
err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypeCsrf).Scan(&csrfToken)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil)
|
req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil)
|
||||||
@@ -824,7 +824,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -870,7 +870,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1039,7 +1039,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1078,7 +1078,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
@@ -1128,7 +1128,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
|
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
@@ -1180,7 +1180,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
userIdOther := uuid.New()
|
userIdOther := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1230,7 +1230,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt"))
|
pass = authentication.GetHashPassword("MyNewSecurePassword1!", []byte("salt"))
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1259,7 +1259,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
d, basePath, ctx := setupIntegrationTest(t)
|
d, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.ExecContext(ctx, `
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1287,7 +1287,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
d, basePath, ctx := setupIntegrationTest(t)
|
d, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.ExecContext(ctx, `
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1317,7 +1317,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows)
|
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypePasswordReset).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 0, rows)
|
assert.Equal(t, 0, rows)
|
||||||
})
|
})
|
||||||
@@ -1362,7 +1362,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
db, basePath, ctx := setupIntegrationTest(t)
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := db.ExecContext(ctx, `
|
_, err := db.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1399,7 +1399,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg)
|
assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg)
|
||||||
|
|
||||||
var rows int
|
var rows int
|
||||||
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows)
|
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypePasswordReset).Scan(&rows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, rows)
|
assert.Equal(t, 1, rows)
|
||||||
})
|
})
|
||||||
@@ -1412,7 +1412,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
d, basePath, ctx := setupIntegrationTest(t)
|
d, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.ExecContext(ctx, `
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1455,7 +1455,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
d, basePath, ctx := setupIntegrationTest(t)
|
d, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.ExecContext(ctx, `
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1475,7 +1475,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
token := "password-reset-token"
|
token := "password-reset-token"
|
||||||
_, err = d.ExecContext(ctx, `
|
_, err = d.ExecContext(ctx, `
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset)
|
VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", auth_types.TokenTypePasswordReset)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
formData := url.Values{
|
formData := url.Values{
|
||||||
@@ -1504,7 +1504,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
d, basePath, ctx := setupIntegrationTest(t)
|
d, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.ExecContext(ctx, `
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1524,7 +1524,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
token := "password-reset-token"
|
token := "password-reset-token"
|
||||||
_, err = d.ExecContext(ctx, `
|
_, err = d.ExecContext(ctx, `
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset)
|
VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", auth_types.TokenTypePasswordReset)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
formData := url.Values{
|
formData := url.Values{
|
||||||
@@ -1553,7 +1553,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
d, basePath, ctx := setupIntegrationTest(t)
|
d, basePath, ctx := setupIntegrationTest(t)
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
pass := authentication.GetHashPassword("password", []byte("salt"))
|
||||||
_, err := d.ExecContext(ctx, `
|
_, err := d.ExecContext(ctx, `
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
|
||||||
@@ -1590,7 +1590,7 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
var token string
|
var token string
|
||||||
err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token)
|
err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", auth_types.TokenTypePasswordReset).Scan(&token)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
formData = url.Values{
|
formData = url.Values{
|
||||||
|
|||||||
Reference in New Issue
Block a user