wip
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 1m8s

This commit is contained in:
2025-12-25 07:36:58 +01:00
parent 1c091dc924
commit aaddb84144
23 changed files with 379 additions and 366 deletions

View File

@@ -3,11 +3,11 @@ dir: mocks/
outpkg: mocks outpkg: mocks
issue-845-fix: True issue-845-fix: True
packages: packages:
spend-sparrow/internal/service: spend-sparrow/internal/core:
interfaces: interfaces:
Random: Random:
Clock: Clock:
Mail: Mail:
spend-sparrow/internal/db: spend-sparrow/internal/authentication:
interfaces: interfaces:
Auth: Db:

View File

@@ -4,29 +4,31 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"log/slog" "log/slog"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/service" "spend-sparrow/internal/service"
"spend-sparrow/internal/types"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
) )
type Service interface { type Service interface {
Add(ctx context.Context, user *types.User, name string) (*Account, error) Add(ctx context.Context, user *auth_types.User, name string) (*Account, error)
UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error) UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error)
Get(ctx context.Context, user *types.User, id string) (*Account, error) Get(ctx context.Context, user *auth_types.User, id string) (*Account, error)
GetAll(ctx context.Context, user *types.User) ([]*Account, error) GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error)
Delete(ctx context.Context, user *types.User, id string) error Delete(ctx context.Context, user *auth_types.User, id string) error
} }
type ServiceImpl struct { type ServiceImpl struct {
db *sqlx.DB db *sqlx.DB
clock service.Clock clock core.Clock
random service.Random random core.Random
} }
func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Service { func NewServiceImpl(db *sqlx.DB, random core.Random, clock core.Clock) Service {
return ServiceImpl{ return ServiceImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -34,14 +36,14 @@ func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Ser
} }
} }
func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*Account, error) { func (s ServiceImpl) Add(ctx context.Context, user *auth_types.User, name string) (*Account, error) {
if user == nil { if user == nil {
return nil, types.ErrUnauthorized return nil, core.ErrUnauthorized
} }
newId, err := s.random.UUID(ctx) newId, err := s.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
err = service.ValidateString(name, "name") err = service.ValidateString(name, "name")
@@ -76,9 +78,9 @@ func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*A
return account, nil return account, nil
} }
func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error) { func (s ServiceImpl) UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error) {
if user == nil { if user == nil {
return nil, types.ErrUnauthorized return nil, core.ErrUnauthorized
} }
err := service.ValidateString(name, "name") err := service.ValidateString(name, "name")
if err != nil { if err != nil {
@@ -87,7 +89,7 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "account update", "err", err) slog.ErrorContext(ctx, "account update", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -103,10 +105,10 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError(ctx, "account Update", nil, err) err = db.TransformAndLogDbError(ctx, "account Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", id, service.ErrBadRequest) return nil, fmt.Errorf("account %v not found: %w", id, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
timestamp := s.clock.Now() timestamp := s.clock.Now()
@@ -136,14 +138,14 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
return &account, nil return &account, nil
} }
func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Account, error) { func (s ServiceImpl) Get(ctx context.Context, user *auth_types.User, id string) (*Account, error) {
if user == nil { if user == nil {
return nil, service.ErrUnauthorized return nil, core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "account get", "err", err) slog.ErrorContext(ctx, "account get", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
var account Account var account Account
@@ -158,9 +160,9 @@ func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Acc
return &account, nil return &account, nil
} }
func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account, error) { func (s ServiceImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error) {
if user == nil { if user == nil {
return nil, service.ErrUnauthorized return nil, core.ErrUnauthorized
} }
accounts := make([]*Account, 0) accounts := make([]*Account, 0)
@@ -174,14 +176,14 @@ func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account,
return accounts, nil return accounts, nil
} }
func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) error { func (s ServiceImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
if user == nil { if user == nil {
return service.ErrUnauthorized return core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "account delete", "err", err) slog.ErrorContext(ctx, "account delete", "err", err)
return fmt.Errorf("could not parse Id: %w", service.ErrBadRequest) return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -200,7 +202,7 @@ func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) er
return err return err
} }
if transactionsCount > 0 { if transactionsCount > 0 {
return fmt.Errorf("account has transactions, cannot delete: %w", service.ErrBadRequest) return fmt.Errorf("account has transactions, cannot delete: %w", core.ErrBadRequest)
} }
res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id) res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)

View File

@@ -5,9 +5,9 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"net/url" "net/url"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/authentication/template" "spend-sparrow/internal/authentication/template"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/types"
"spend-sparrow/internal/utils" "spend-sparrow/internal/utils"
"time" "time"
) )
@@ -79,7 +79,7 @@ func (handler HandlerImpl) handleSignIn() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
core.UpdateSpan(r) core.UpdateSpan(r)
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*User, error) { user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*auth_types.User, error) {
session := core.GetSession(r) session := core.GetSession(r)
email := r.FormValue("email") email := r.FormValue("email")
password := r.FormValue("password") password := r.FormValue("password")
@@ -89,14 +89,14 @@ func (handler HandlerImpl) handleSignIn() http.HandlerFunc {
return nil, err return nil, err
} }
cookie := middleware.CreateSessionCookie(session.Id) cookie := core.CreateSessionCookie(session.Id)
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
return user, nil return user, nil
}) })
if err != nil { if err != nil {
if errors.Is(err, service.ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) {
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Invalid email or password", http.StatusUnauthorized) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Invalid email or password", http.StatusUnauthorized)
} else { } else {
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
@@ -127,7 +127,7 @@ func (handler HandlerImpl) handleSignUpPage() http.HandlerFunc {
return return
} }
signUpComp := auth.SignInOrUpComp(false) signUpComp := template.SignInOrUpComp(false)
handler.render.RenderLayout(r, w, signUpComp, nil) handler.render.RenderLayout(r, w, signUpComp, nil)
} }
} }
@@ -147,7 +147,7 @@ func (handler HandlerImpl) handleSignUpVerifyPage() http.HandlerFunc {
return return
} }
signIn := auth.VerifyComp() signIn := template.VerifyComp()
handler.render.RenderLayout(r, w, signIn, user) handler.render.RenderLayout(r, w, signIn, user)
} }
} }
@@ -180,7 +180,7 @@ func (handler HandlerImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
err := handler.service.VerifyUserEmail(r.Context(), token) err := handler.service.VerifyUserEmail(r.Context(), token)
isVerified := err == nil isVerified := err == nil
comp := auth.VerifyResponseComp(isVerified) comp := template.VerifyResponseComp(isVerified)
var status int var status int
if isVerified { if isVerified {
@@ -214,14 +214,14 @@ func (handler HandlerImpl) handleSignUp() http.HandlerFunc {
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, types.ErrInternal): case errors.Is(err, core.ErrInternal):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
return return
case errors.Is(err, service.ErrInvalidEmail): case errors.Is(err, ErrInvalidEmail):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "The email provided is invalid", http.StatusBadRequest) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "The email provided is invalid", http.StatusBadRequest)
return return
case errors.Is(err, service.ErrInvalidPassword): case errors.Is(err, ErrInvalidPassword):
utils.TriggerToastWithStatus(r.Context(), w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest) utils.TriggerToastWithStatus(r.Context(), w, r, "error", ErrInvalidPassword.Error(), http.StatusBadRequest)
return return
} }
// If err is "service.ErrAccountExists", then just continue // If err is "service.ErrAccountExists", then just continue
@@ -270,7 +270,7 @@ func (handler HandlerImpl) handleDeleteAccountPage() http.HandlerFunc {
return return
} }
comp := auth.DeleteAccountComp() comp := template.DeleteAccountComp()
handler.render.RenderLayout(r, w, comp, user) handler.render.RenderLayout(r, w, comp, user)
} }
} }
@@ -289,7 +289,7 @@ func (handler HandlerImpl) handleDeleteAccountComp() http.HandlerFunc {
err := handler.service.DeleteAccount(r.Context(), user, password) err := handler.service.DeleteAccount(r.Context(), user, password)
if err != nil { if err != nil {
if errors.Is(err, service.ErrInvalidCredentials) { if errors.Is(err, ErrInvalidCredentials) {
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Password not correct", http.StatusBadRequest) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Password not correct", http.StatusBadRequest)
} else { } else {
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Internal Server Error", http.StatusInternalServerError) utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Internal Server Error", http.StatusInternalServerError)
@@ -314,7 +314,7 @@ func (handler HandlerImpl) handleChangePasswordPage() http.HandlerFunc {
return return
} }
comp := auth.ChangePasswordComp(isPasswordReset) comp := template.ChangePasswordComp(isPasswordReset)
handler.render.RenderLayout(r, w, comp, user) handler.render.RenderLayout(r, w, comp, user)
} }
} }
@@ -353,7 +353,7 @@ func (handler HandlerImpl) handleForgotPasswordPage() http.HandlerFunc {
return return
} }
comp := auth.ResetPasswordComp() comp := template.ResetPasswordComp()
handler.render.RenderLayout(r, w, comp, user) handler.render.RenderLayout(r, w, comp, user)
} }
} }

View File

@@ -8,7 +8,6 @@ import (
"net/mail" "net/mail"
"spend-sparrow/internal/auth_types" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/db"
mailTemplate "spend-sparrow/internal/template/mail" mailTemplate "spend-sparrow/internal/template/mail"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"strings" "strings"
@@ -53,13 +52,13 @@ type Service interface {
type ServiceImpl struct { type ServiceImpl struct {
db Db db Db
random core.Random random core.Random
clock Clock clock core.Clock
mail Mail mail core.Mail
serverSettings *types.Settings serverSettings *types.Settings
} }
func NewService(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *types.Settings) *HandlerImpl { func NewService(db Db, random core.Random, clock core.Clock, mail core.Mail, serverSettings *types.Settings) *ServiceImpl {
return &HandlerImpl{ return &ServiceImpl{
db: db, db: db,
random: random, random: random,
clock: clock, clock: clock,
@@ -68,13 +67,13 @@ func NewService(db db.Auth, random Random, clock Clock, mail Mail, serverSetting
} }
} }
func (service HandlerImpl) SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_type.Session, *auth_types.User, error) { func (service ServiceImpl) SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_types.Session, *auth_types.User, error) {
user, err := service.db.GetUserByEmail(ctx, email) user, err := service.db.GetUserByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, nil, ErrInvalidCredentials return nil, nil, ErrInvalidCredentials
} else { } else {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
} }
@@ -86,36 +85,36 @@ func (service HandlerImpl) SignIn(ctx context.Context, session *auth_types.Sessi
newSession, err := service.createSession(ctx, user.Id) newSession, err := service.createSession(ctx, user.Id)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
err = service.db.DeleteSession(ctx, session.Id) err = service.db.DeleteSession(ctx, session.Id)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
for _, token := range tokens { for _, token := range tokens {
err = service.db.DeleteToken(ctx, token.Token) err = service.db.DeleteToken(ctx, token.Token)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
} }
return newSession, user, nil return newSession, user, nil
} }
func (service HandlerImpl) SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error) { func (service ServiceImpl) SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error) {
if sessionId == "" { if sessionId == "" {
return nil, nil, ErrSessionIdInvalid return nil, nil, ErrSessionIdInvalid
} }
session, err := service.db.GetSession(ctx, sessionId) session, err := service.db.GetSession(ctx, sessionId)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
if session.ExpiresAt.Before(service.clock.Now()) { if session.ExpiresAt.Before(service.clock.Now()) {
_ = service.db.DeleteSession(ctx, sessionId) _ = service.db.DeleteSession(ctx, sessionId)
@@ -128,16 +127,16 @@ func (service HandlerImpl) SignInSession(ctx context.Context, sessionId string)
user, err := service.db.GetUser(ctx, session.UserId) user, err := service.db.GetUser(ctx, session.UserId)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, core.ErrInternal
} }
return session, user, nil return session, user, nil
} }
func (service HandlerImpl) SignInAnonymous(ctx context.Context) (*auth_types.Session, error) { func (service ServiceImpl) SignInAnonymous(ctx context.Context) (*auth_types.Session, error) {
session, err := service.createSession(ctx, uuid.Nil) session, err := service.createSession(ctx, uuid.Nil)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
slog.InfoContext(ctx, "anonymous session created", "session-id", session.Id) slog.InfoContext(ctx, "anonymous session created", "session-id", session.Id)
@@ -145,7 +144,7 @@ func (service HandlerImpl) SignInAnonymous(ctx context.Context) (*auth_types.Ses
return session, nil return session, nil
} }
func (service HandlerImpl) SignUp(ctx context.Context, email string, password string) (*auth_types.User, error) { func (service ServiceImpl) SignUp(ctx context.Context, email string, password string) (*auth_types.User, error) {
_, err := mail.ParseAddress(email) _, err := mail.ParseAddress(email)
if err != nil { if err != nil {
return nil, ErrInvalidEmail return nil, ErrInvalidEmail
@@ -157,37 +156,37 @@ func (service HandlerImpl) SignUp(ctx context.Context, email string, password st
userId, err := service.random.UUID(ctx) userId, err := service.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
salt, err := service.random.Bytes(ctx, 16) salt, err := service.random.Bytes(ctx, 16)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
hash := GetHashPassword(password, salt) hash := GetHashPassword(password, salt)
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) user := auth_types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
err = service.db.InsertUser(ctx, user) err = service.db.InsertUser(ctx, user)
if err != nil { if err != nil {
if errors.Is(err, db.ErrAlreadyExists) { if errors.Is(err, core.ErrAlreadyExists) {
return nil, ErrAccountExists return nil, ErrAccountExists
} else { } else {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
} }
return user, nil return user, nil
} }
func (service HandlerImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) { func (service ServiceImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify) tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, auth_types.TokenTypeEmailVerify)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return return
} }
var token *types.Token var token *auth_types.Token
if len(tokens) > 0 { if len(tokens) > 0 {
token = tokens[0] token = tokens[0]
@@ -199,11 +198,11 @@ func (service HandlerImpl) SendVerificationMail(ctx context.Context, userId uuid
return return
} }
token = types.NewToken( token = auth_types.NewToken(
userId, userId,
"", "",
newTokenStr, newTokenStr,
types.TokenTypeEmailVerify, auth_types.TokenTypeEmailVerify,
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(24*time.Hour)) service.clock.Now().Add(24*time.Hour))
@@ -223,29 +222,29 @@ func (service HandlerImpl) SendVerificationMail(ctx context.Context, userId uuid
service.mail.SendMail(ctx, email, "Welcome to spend-sparrow", w.String()) service.mail.SendMail(ctx, email, "Welcome to spend-sparrow", w.String())
} }
func (service HandlerImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error { func (service ServiceImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
if tokenStr == "" { if tokenStr == "" {
return types.ErrInternal return core.ErrInternal
} }
token, err := service.db.GetToken(ctx, tokenStr) token, err := service.db.GetToken(ctx, tokenStr)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
user, err := service.db.GetUser(ctx, token.UserId) user, err := service.db.GetUser(ctx, token.UserId)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
if token.Type != types.TokenTypeEmailVerify { if token.Type != auth_types.TokenTypeEmailVerify {
return types.ErrInternal return core.ErrInternal
} }
now := service.clock.Now() now := service.clock.Now()
if token.ExpiresAt.Before(now) { if token.ExpiresAt.Before(now) {
return types.ErrInternal return core.ErrInternal
} }
user.EmailVerified = true user.EmailVerified = true
@@ -253,21 +252,21 @@ func (service HandlerImpl) VerifyUserEmail(ctx context.Context, tokenStr string)
err = service.db.UpdateUser(ctx, user) err = service.db.UpdateUser(ctx, user)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
_ = service.db.DeleteToken(ctx, token.Token) _ = service.db.DeleteToken(ctx, token.Token)
return nil return nil
} }
func (service HandlerImpl) SignOut(ctx context.Context, sessionId string) error { func (service ServiceImpl) SignOut(ctx context.Context, sessionId string) error {
return service.db.DeleteSession(ctx, sessionId) return service.db.DeleteSession(ctx, sessionId)
} }
func (service HandlerImpl) DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error { func (service ServiceImpl) DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error {
userDb, err := service.db.GetUser(ctx, user.Id) userDb, err := service.db.GetUser(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
currHash := GetHashPassword(currPass, userDb.Salt) currHash := GetHashPassword(currPass, userDb.Salt)
@@ -285,7 +284,7 @@ func (service HandlerImpl) DeleteAccount(ctx context.Context, user *auth_types.U
return nil return nil
} }
func (service HandlerImpl) ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error { func (service ServiceImpl) ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
} }
@@ -310,13 +309,13 @@ func (service HandlerImpl) ChangePassword(ctx context.Context, user *auth_types.
sessions, err := service.db.GetSessions(ctx, user.Id) sessions, err := service.db.GetSessions(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
for _, s := range sessions { for _, s := range sessions {
if s.Id != sessionId { if s.Id != sessionId {
err = service.db.DeleteSession(ctx, s.Id) err = service.db.DeleteSession(ctx, s.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
} }
} }
@@ -324,7 +323,7 @@ func (service HandlerImpl) ChangePassword(ctx context.Context, user *auth_types.
return nil return nil
} }
func (service HandlerImpl) SendForgotPasswordMail(ctx context.Context, email string) error { func (service ServiceImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
tokenStr, err := service.random.String(ctx, 32) tokenStr, err := service.random.String(ctx, 32)
if err != nil { if err != nil {
return err return err
@@ -332,38 +331,38 @@ func (service HandlerImpl) SendForgotPasswordMail(ctx context.Context, email str
user, err := service.db.GetUserByEmail(ctx, email) user, err := service.db.GetUserByEmail(ctx, email)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil return nil
} else { } else {
return types.ErrInternal return core.ErrInternal
} }
} }
token := types.NewToken( token := auth_types.NewToken(
user.Id, user.Id,
"", "",
tokenStr, tokenStr,
types.TokenTypePasswordReset, auth_types.TokenTypePasswordReset,
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(15*time.Minute)) service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(ctx, token) err = service.db.InsertToken(ctx, token)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
var mail strings.Builder var mail strings.Builder
err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail) err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not render reset password email", "err", err) slog.ErrorContext(ctx, "Could not render reset password email", "err", err)
return types.ErrInternal return core.ErrInternal
} }
service.mail.SendMail(ctx, email, "Reset Password", mail.String()) service.mail.SendMail(ctx, email, "Reset Password", mail.String())
return nil return nil
} }
func (service HandlerImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error { func (service ServiceImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
} }
@@ -378,7 +377,7 @@ func (service HandlerImpl) ForgotPassword(ctx context.Context, tokenStr string,
return err return err
} }
if token.Type != types.TokenTypePasswordReset || if token.Type != auth_types.TokenTypePasswordReset ||
token.ExpiresAt.Before(service.clock.Now()) { token.ExpiresAt.Before(service.clock.Now()) {
return ErrTokenInvalid return ErrTokenInvalid
} }
@@ -386,7 +385,7 @@ func (service HandlerImpl) ForgotPassword(ctx context.Context, tokenStr string,
user, err := service.db.GetUser(ctx, token.UserId) user, err := service.db.GetUser(ctx, token.UserId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "Could not get user from token", "err", err) slog.ErrorContext(ctx, "Could not get user from token", "err", err)
return types.ErrInternal return core.ErrInternal
} }
passHash := GetHashPassword(newPass, user.Salt) passHash := GetHashPassword(newPass, user.Salt)
@@ -399,26 +398,26 @@ func (service HandlerImpl) ForgotPassword(ctx context.Context, tokenStr string,
sessions, err := service.db.GetSessions(ctx, user.Id) sessions, err := service.db.GetSessions(ctx, user.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
for _, session := range sessions { for _, session := range sessions {
err = service.db.DeleteSession(ctx, session.Id) err = service.db.DeleteSession(ctx, session.Id)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
} }
return nil return nil
} }
func (service HandlerImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool { func (service ServiceImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
token, err := service.db.GetToken(ctx, tokenStr) token, err := service.db.GetToken(ctx, tokenStr)
if err != nil { if err != nil {
return false return false
} }
if token.Type != types.TokenTypeCsrf || if token.Type != auth_types.TokenTypeCsrf ||
token.SessionId != sessionId || token.SessionId != sessionId ||
token.ExpiresAt.Before(service.clock.Now()) { token.ExpiresAt.Before(service.clock.Now()) {
return false return false
@@ -427,12 +426,12 @@ func (service HandlerImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string
return true return true
} }
func (service HandlerImpl) GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error) { func (service ServiceImpl) GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error) {
if session == nil { if session == nil {
return "", types.ErrInternal return "", core.ErrInternal
} }
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf)
if len(tokens) > 0 { if len(tokens) > 0 {
return tokens[0].Token, nil return tokens[0].Token, nil
@@ -440,19 +439,19 @@ func (service HandlerImpl) GetCsrfToken(ctx context.Context, session *auth_types
tokenStr, err := service.random.String(ctx, 32) tokenStr, err := service.random.String(ctx, 32)
if err != nil { if err != nil {
return "", types.ErrInternal return "", core.ErrInternal
} }
token := types.NewToken( token := auth_types.NewToken(
session.UserId, session.UserId,
session.Id, session.Id,
tokenStr, tokenStr,
types.TokenTypeCsrf, auth_types.TokenTypeCsrf,
service.clock.Now(), service.clock.Now(),
service.clock.Now().Add(8*time.Hour)) service.clock.Now().Add(8*time.Hour))
err = service.db.InsertToken(ctx, token) err = service.db.InsertToken(ctx, token)
if err != nil { if err != nil {
return "", types.ErrInternal return "", core.ErrInternal
} }
slog.InfoContext(ctx, "CSRF-Token created", "token", tokenStr) slog.InfoContext(ctx, "CSRF-Token created", "token", tokenStr)
@@ -460,34 +459,34 @@ func (service HandlerImpl) GetCsrfToken(ctx context.Context, session *auth_types
return tokenStr, nil return tokenStr, nil
} }
func (service HandlerImpl) CleanupSessionsAndTokens(ctx context.Context) error { func (service ServiceImpl) CleanupSessionsAndTokens(ctx context.Context) error {
err := service.db.DeleteOldSessions(ctx) err := service.db.DeleteOldSessions(ctx)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
err = service.db.DeleteOldTokens(ctx) err = service.db.DeleteOldTokens(ctx)
if err != nil { if err != nil {
return types.ErrInternal return core.ErrInternal
} }
return nil return nil
} }
func (service HandlerImpl) createSession(ctx context.Context, userId uuid.UUID) (*auth_types.Session, error) { func (service ServiceImpl) createSession(ctx context.Context, userId uuid.UUID) (*auth_types.Session, error) {
sessionId, err := service.random.String(ctx, 32) sessionId, err := service.random.String(ctx, 32)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
createAt := service.clock.Now() createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
session := types.NewSession(sessionId, userId, createAt, expiresAt) session := auth_types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(ctx, session) err = service.db.InsertSession(ctx, session)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return session, nil return session, nil

View File

@@ -1,4 +1,4 @@
package service package core
import "time" import "time"

View File

@@ -1,4 +1,4 @@
package service package core
import ( import (
"context" "context"

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"os/signal" "os/signal"
"spend-sparrow/internal/account" "spend-sparrow/internal/account"
"spend-sparrow/internal/authentication"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/handler" "spend-sparrow/internal/handler"
@@ -107,13 +108,13 @@ func shutdownServer(ctx context.Context, s *http.Server, wg *sync.WaitGroup) {
func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler { func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
var router = http.NewServeMux() var router = http.NewServeMux()
authDb := db.NewAuthSqlite(d) authDb := authentication.NewDbSqlite(d)
randomService := service.NewRandom() randomService := core.NewRandom()
clockService := service.NewClock() clockService := core.NewClock()
mailService := service.NewMail(serverSettings) mailService := core.NewMail(serverSettings)
authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings) authService := authentication.NewService(authDb, randomService, clockService, mailService, serverSettings)
accountService := account.NewServiceImpl(d, randomService, clockService) accountService := account.NewServiceImpl(d, randomService, clockService)
treasureChestService := service.NewTreasureChest(d, randomService, clockService) treasureChestService := service.NewTreasureChest(d, randomService, clockService)
transactionService := service.NewTransaction(d, randomService, clockService) transactionService := service.NewTransaction(d, randomService, clockService)
@@ -123,7 +124,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *
render := core.NewRender() render := core.NewRender()
indexHandler := handler.NewIndex(render, clockService) indexHandler := handler.NewIndex(render, clockService)
dashboardHandler := handler.NewDashboard(render, dashboardService, treasureChestService) dashboardHandler := handler.NewDashboard(render, dashboardService, treasureChestService)
authHandler := handler.NewAuth(authService, render) authHandler := authentication.NewHandler(authService, render)
accountHandler := account.NewHandler(accountService, render) accountHandler := account.NewHandler(accountService, render)
treasureChestHandler := handler.NewTreasureChest(treasureChestService, transactionRecurringService, render) treasureChestHandler := handler.NewTreasureChest(treasureChestService, transactionRecurringService, render)
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render) transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
@@ -157,7 +158,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *
return wrapper return wrapper
} }
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) { func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) {
runDailyTasks(ctx, transactionRecurring, auth) runDailyTasks(ctx, transactionRecurring, auth)
ticker := time.NewTicker(24 * time.Hour) ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop() defer ticker.Stop()
@@ -172,7 +173,7 @@ func dailyTaskTimer(ctx context.Context, transactionRecurring service.Transactio
} }
} }
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) { func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) {
slog.InfoContext(ctx, "Running daily tasks") slog.InfoContext(ctx, "Running daily tasks")
_ = transactionRecurring.GenerateTransactions(ctx) _ = transactionRecurring.GenerateTransactions(ctx)
_ = auth.CleanupSessionsAndTokens(ctx) _ = auth.CleanupSessionsAndTokens(ctx)

View File

@@ -193,7 +193,7 @@ func (handler DashboardImpl) handleDashboardTreasureChest() http.HandlerFunc {
if treasureChestStr != "" { if treasureChestStr != "" {
id, err := uuid.Parse(treasureChestStr) id, err := uuid.Parse(treasureChestStr)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", core.ErrBadRequest))
return return
} }

View File

@@ -3,12 +3,12 @@ package middleware
import ( import (
"context" "context"
"net/http" "net/http"
"spend-sparrow/internal/authentication"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/service"
"strings" "strings"
) )
func Authenticate(service service.Auth) func(http.Handler) http.Handler { func Authenticate(service authentication.Service) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
@@ -31,7 +31,7 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
return return
} }
cookie := CreateSessionCookie(session.Id) cookie := core.CreateSessionCookie(session.Id)
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
} }

View File

@@ -3,8 +3,8 @@ package middleware
import ( import (
"log/slog" "log/slog"
"net/http" "net/http"
"spend-sparrow/internal/authentication"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/service"
"spend-sparrow/internal/utils" "spend-sparrow/internal/utils"
"strings" "strings"
) )
@@ -31,7 +31,7 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
return rr.ResponseWriter.Write([]byte(dataStr)) return rr.ResponseWriter.Write([]byte(dataStr))
} }
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { func CrossSiteRequestForgery(auth authentication.Service) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()

View File

@@ -3,7 +3,6 @@ package handler
import ( import (
"net/http" "net/http"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/service"
"spend-sparrow/internal/template" "spend-sparrow/internal/template"
"spend-sparrow/internal/utils" "spend-sparrow/internal/utils"
@@ -16,10 +15,10 @@ type Index interface {
type IndexImpl struct { type IndexImpl struct {
r *core.Render r *core.Render
c service.Clock c core.Clock
} }
func NewIndex(r *core.Render, c service.Clock) Index { func NewIndex(r *core.Render, c core.Clock) Index {
return IndexImpl{ return IndexImpl{
r: r, r: r,
c: c, c: c,

View File

@@ -157,7 +157,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
if idStr != "new" { if idStr != "new" {
id, err = uuid.Parse(idStr) id, err = uuid.Parse(idStr)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest))
return return
} }
} }
@@ -167,7 +167,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
if accountIdStr != "" { if accountIdStr != "" {
i, err := uuid.Parse(accountIdStr) i, err := uuid.Parse(accountIdStr)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", core.ErrBadRequest))
return return
} }
accountId = &i accountId = &i
@@ -178,7 +178,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
if treasureChestIdStr != "" { if treasureChestIdStr != "" {
i, err := uuid.Parse(treasureChestIdStr) i, err := uuid.Parse(treasureChestIdStr)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", core.ErrBadRequest))
return return
} }
treasureChestId = &i treasureChestId = &i
@@ -186,14 +186,14 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
valueF, err := strconv.ParseFloat(r.FormValue("value"), 64) valueF, err := strconv.ParseFloat(r.FormValue("value"), 64)
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse value: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse value: %w", core.ErrBadRequest))
return return
} }
value := int64(math.Round(valueF * service.DECIMALS_MULTIPLIER)) value := int64(math.Round(valueF * service.DECIMALS_MULTIPLIER))
timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp")) timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp"))
if err != nil { if err != nil {
core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", service.ErrBadRequest)) core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest))
return return
} }

View File

@@ -2,6 +2,7 @@ package handler
import ( import (
"net/http" "net/http"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core" "spend-sparrow/internal/core"
"spend-sparrow/internal/service" "spend-sparrow/internal/service"
t "spend-sparrow/internal/template/transaction_recurring" t "spend-sparrow/internal/template/transaction_recurring"
@@ -111,7 +112,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
} }
} }
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) { func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *auth_types.User, id, accountId, treasureChestId string) {
var transactionsRecurring []*types.TransactionRecurring var transactionsRecurring []*types.TransactionRecurring
var err error var err error
if accountId == "" && treasureChestId == "" { if accountId == "" && treasureChestId == "" {

View File

@@ -2,6 +2,8 @@ package service
import ( import (
"context" "context"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"time" "time"
@@ -22,10 +24,10 @@ func NewDashboard(db *sqlx.DB) *Dashboard {
func (s Dashboard) MainChart( func (s Dashboard) MainChart(
ctx context.Context, ctx context.Context,
user *types.User, user *auth_types.User,
) ([]types.DashboardMainChartEntry, error) { ) ([]types.DashboardMainChartEntry, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
transactions := make([]types.Transaction, 0) transactions := make([]types.Transaction, 0)
@@ -82,10 +84,10 @@ func (s Dashboard) MainChart(
func (s Dashboard) TreasureChests( func (s Dashboard) TreasureChests(
ctx context.Context, ctx context.Context,
user *types.User, user *auth_types.User,
) ([]*types.DashboardTreasureChest, error) { ) ([]*types.DashboardTreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
treasureChests := make([]*types.TreasureChest, 0) treasureChests := make([]*types.TreasureChest, 0)
@@ -120,11 +122,11 @@ func (s Dashboard) TreasureChests(
func (s Dashboard) TreasureChest( func (s Dashboard) TreasureChest(
ctx context.Context, ctx context.Context,
user *types.User, user *auth_types.User,
treausureChestId *uuid.UUID, treausureChestId *uuid.UUID,
) ([]types.DashboardMainChartEntry, error) { ) ([]types.DashboardMainChartEntry, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
transactions := make([]types.Transaction, 0) transactions := make([]types.Transaction, 0)

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"fmt" "fmt"
"regexp" "regexp"
"spend-sparrow/internal/core"
) )
const ( const (
@@ -16,9 +17,9 @@ var (
func ValidateString(value string, fieldName string) error { func ValidateString(value string, fieldName string) error {
switch { switch {
case value == "": case value == "":
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest) return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, core.ErrBadRequest)
case !safeInputRegex.MatchString(value): case !safeInputRegex.MatchString(value):
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest) return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, core.ErrBadRequest)
default: default:
return nil return nil
} }

View File

@@ -5,7 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
"spend-sparrow/internal/authentication" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"strconv" "strconv"
@@ -18,22 +19,22 @@ import (
const page_size = 25 const page_size = 25
type Transaction interface { type Transaction interface {
Add(ctx context.Context, tx *sqlx.Tx, user *authentication.User, transaction types.Transaction) (*types.Transaction, error) Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error)
Update(ctx context.Context, user *authentication.User, transaction types.Transaction) (*types.Transaction, error) Update(ctx context.Context, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error)
Get(ctx context.Context, user *authentication.User, id string) (*types.Transaction, error) Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error)
GetAll(ctx context.Context, user *authentication.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
Delete(ctx context.Context, user *authentication.User, id string) error Delete(ctx context.Context, user *auth_types.User, id string) error
RecalculateBalances(ctx context.Context, user *authentication.User) error RecalculateBalances(ctx context.Context, user *auth_types.User) error
} }
type TransactionImpl struct { type TransactionImpl struct {
db *sqlx.DB db *sqlx.DB
clock Clock clock core.Clock
random Random random core.Random
} }
func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction { func NewTransaction(db *sqlx.DB, random core.Random, clock core.Clock) Transaction {
return TransactionImpl{ return TransactionImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -41,9 +42,9 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
} }
} }
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) { func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transactionInput types.Transaction) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
var err error var err error
@@ -108,9 +109,9 @@ func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User,
return transaction, nil return transaction, nil
} }
func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) { func (s TransactionImpl) Update(ctx context.Context, user *auth_types.User, input types.Transaction) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -126,10 +127,10 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ
err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id) err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
err = db.TransformAndLogDbError(ctx, "transaction Update", nil, err) err = db.TransformAndLogDbError(ctx, "transaction Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest) return nil, fmt.Errorf("transaction %v not found: %w", input.Id, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
if transaction.Error == nil && transaction.AccountId != nil { if transaction.Error == nil && transaction.AccountId != nil {
@@ -207,32 +208,32 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ
return transaction, nil return transaction, nil
} }
func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) { func (s TransactionImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transaction get", "err", err) slog.ErrorContext(ctx, "transaction get", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
var transaction types.Transaction var transaction types.Transaction
err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError(ctx, "transaction Get", nil, err) err = db.TransformAndLogDbError(ctx, "transaction Get", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest) return nil, fmt.Errorf("transaction %v not found: %w", id, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return &transaction, nil return &transaction, nil
} }
func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) { func (s TransactionImpl) GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
var ( var (
@@ -278,14 +279,14 @@ func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter ty
return transactions, nil return transactions, nil
} }
func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error { func (s TransactionImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
if user == nil { if user == nil {
return ErrUnauthorized return core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transaction delete", "err", err) slog.ErrorContext(ctx, "transaction delete", "err", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -311,7 +312,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
WHERE id = ? WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err) err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
} }
@@ -323,7 +324,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
WHERE id = ? WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err) err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
} }
@@ -343,9 +344,9 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
return nil return nil
} }
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error { func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *auth_types.User) error {
if user == nil { if user == nil {
return ErrUnauthorized return core.ErrUnauthorized
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -362,7 +363,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
SET current_balance = 0 SET current_balance = 0
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
@@ -371,7 +372,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
SET current_balance = 0 SET current_balance = 0
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
@@ -380,7 +381,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
FROM "transaction" FROM "transaction"
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", nil, err) err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", nil, err)
if err != nil && !errors.Is(err, db.ErrNotFound) { if err != nil && !errors.Is(err, core.ErrNotFound) {
return err return err
} }
defer func() { defer func() {
@@ -459,7 +460,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
if oldTransaction == nil { if oldTransaction == nil {
id, err = s.random.UUID(ctx) id, err = s.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
createdAt = s.clock.Now() createdAt = s.clock.Now()
createdBy = userId createdBy = userId
@@ -480,7 +481,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
} }
if rowCount == 0 { if rowCount == 0 {
slog.ErrorContext(ctx, "transaction validate", "err", err) slog.ErrorContext(ctx, "transaction validate", "err", err)
return nil, fmt.Errorf("account not found: %w", ErrBadRequest) return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest)
} }
} }
@@ -489,13 +490,13 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId) err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
err = db.TransformAndLogDbError(ctx, "transaction validate", nil, err) err = db.TransformAndLogDbError(ctx, "transaction validate", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest)
} }
return nil, err return nil, err
} }
if treasureChest.ParentId == nil { if treasureChest.ParentId == nil {
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest)
} }
} }

View File

@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"strconv" "strconv"
@@ -16,24 +18,24 @@ import (
) )
type TransactionRecurring interface { type TransactionRecurring interface {
Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Add(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Update(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error)
GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error)
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) GetAllByTreasureChest(ctx context.Context, user *auth_types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(ctx context.Context, user *types.User, id string) error Delete(ctx context.Context, user *auth_types.User, id string) error
GenerateTransactions(ctx context.Context) error GenerateTransactions(ctx context.Context) error
} }
type TransactionRecurringImpl struct { type TransactionRecurringImpl struct {
db *sqlx.DB db *sqlx.DB
clock Clock clock core.Clock
random Random random core.Random
transaction Transaction transaction Transaction
} }
func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transaction Transaction) TransactionRecurring { func NewTransactionRecurring(db *sqlx.DB, random core.Random, clock core.Clock, transaction Transaction) TransactionRecurring {
return TransactionRecurringImpl{ return TransactionRecurringImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -43,11 +45,11 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio
} }
func (s TransactionRecurringImpl) Add(ctx context.Context, func (s TransactionRecurringImpl) Add(ctx context.Context,
user *types.User, user *auth_types.User,
transactionRecurringInput types.TransactionRecurringInput, transactionRecurringInput types.TransactionRecurringInput,
) (*types.TransactionRecurring, error) { ) (*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -85,16 +87,16 @@ func (s TransactionRecurringImpl) Add(ctx context.Context,
} }
func (s TransactionRecurringImpl) Update(ctx context.Context, func (s TransactionRecurringImpl) Update(ctx context.Context,
user *types.User, user *auth_types.User,
input types.TransactionRecurringInput, input types.TransactionRecurringInput,
) (*types.TransactionRecurring, error) { ) (*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
uuid, err := uuid.Parse(input.Id) uuid, err := uuid.Parse(input.Id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring update", "err", err) slog.ErrorContext(ctx, "transactionRecurring update", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -110,10 +112,10 @@ func (s TransactionRecurringImpl) Update(ctx context.Context,
err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError(ctx, "transactionRecurring Update", nil, err) err = db.TransformAndLogDbError(ctx, "transactionRecurring Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest) return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input) transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
@@ -149,9 +151,9 @@ func (s TransactionRecurringImpl) Update(ctx context.Context,
return transactionRecurring, nil return transactionRecurring, nil
} }
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
@@ -169,15 +171,15 @@ func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User)
return transactionRecurrings, nil return transactionRecurrings, nil
} }
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
accountUuid, err := uuid.Parse(accountId) accountUuid, err := uuid.Parse(accountId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring GetAllByAccount", "err", err) slog.ErrorContext(ctx, "transactionRecurring GetAllByAccount", "err", err)
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -193,10 +195,10 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByAccount", nil, err) err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByAccount", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest) return nil, fmt.Errorf("account %v not found: %w", accountId, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
@@ -222,17 +224,17 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ
} }
func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context, func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
user *types.User, user *auth_types.User,
treasureChestId string, treasureChestId string,
) ([]*types.TransactionRecurring, error) { ) ([]*types.TransactionRecurring, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
treasureChestUuid, err := uuid.Parse(treasureChestId) treasureChestUuid, err := uuid.Parse(treasureChestId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring GetAllByTreasureChest", "err", err) slog.ErrorContext(ctx, "transactionRecurring GetAllByTreasureChest", "err", err)
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -248,10 +250,10 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByTreasureChest", nil, err) err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest) return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, core.ErrBadRequest)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
transactionRecurrings := make([]*types.TransactionRecurring, 0) transactionRecurrings := make([]*types.TransactionRecurring, 0)
@@ -276,14 +278,14 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
return transactionRecurrings, nil return transactionRecurrings, nil
} }
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error { func (s TransactionRecurringImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
if user == nil { if user == nil {
return ErrUnauthorized return core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring delete", "err", err) slog.ErrorContext(ctx, "transactionRecurring delete", "err", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -339,7 +341,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) erro
} }
for _, transactionRecurring := range recurringTransactions { for _, transactionRecurring := range recurringTransactions {
user := &types.User{ user := &auth_types.User{
Id: transactionRecurring.UserId, Id: transactionRecurring.UserId,
} }
transaction := types.Transaction{ transaction := types.Transaction{
@@ -397,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
if oldTransactionRecurring == nil { if oldTransactionRecurring == nil {
id, err = s.random.UUID(ctx) id, err = s.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
createdAt = s.clock.Now() createdAt = s.clock.Now()
createdBy = userId createdBy = userId
@@ -416,7 +418,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
temp, err := uuid.Parse(input.AccountId) temp, err := uuid.Parse(input.AccountId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest)
} }
accountUuid = &temp accountUuid = &temp
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId) err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
@@ -426,7 +428,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
} }
if rowCount == 0 { if rowCount == 0 {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("account not found: %w", ErrBadRequest) return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest)
} }
hasAccount = true hasAccount = true
@@ -436,37 +438,37 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
temp, err := uuid.Parse(input.TreasureChestId) temp, err := uuid.Parse(input.TreasureChestId)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest)
} }
treasureChestUuid = &temp treasureChestUuid = &temp
var treasureChest types.TreasureChest var treasureChest types.TreasureChest
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError(ctx, "transactionRecurring validate", nil, err) err = db.TransformAndLogDbError(ctx, "transactionRecurring validate", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest)
} }
return nil, err return nil, err
} }
if treasureChest.ParentId == nil { if treasureChest.ParentId == nil {
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest)
} }
hasTreasureChest = true hasTreasureChest = true
} }
if !hasAccount && !hasTreasureChest { if !hasAccount && !hasTreasureChest {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("either account or treasure chest is required: %w", ErrBadRequest) return nil, fmt.Errorf("either account or treasure chest is required: %w", core.ErrBadRequest)
} }
if hasAccount && hasTreasureChest { if hasAccount && hasTreasureChest {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", ErrBadRequest) return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", core.ErrBadRequest)
} }
valueFloat, err := strconv.ParseFloat(input.Value, 64) valueFloat, err := strconv.ParseFloat(input.Value, 64)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse value: %w", core.ErrBadRequest)
} }
value := int64(math.Round(valueFloat * DECIMALS_MULTIPLIER)) value := int64(math.Round(valueFloat * DECIMALS_MULTIPLIER))
@@ -485,18 +487,18 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0) intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("could not parse intervalMonths: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse intervalMonths: %w", core.ErrBadRequest)
} }
if intervalMonths < 1 { if intervalMonths < 1 {
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err) slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", ErrBadRequest) return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", core.ErrBadRequest)
} }
var nextExecution *time.Time = nil var nextExecution *time.Time = nil
if input.NextExecution != "" { if input.NextExecution != "" {
t, err := time.Parse("2006-01-02", input.NextExecution) t, err := time.Parse("2006-01-02", input.NextExecution)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "transaction validate", "err", err) slog.ErrorContext(ctx, "transaction validate", "err", err)
return nil, fmt.Errorf("could not parse timestamp: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest)
} }
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location()) t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())

View File

@@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"slices" "slices"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
@@ -14,20 +16,20 @@ import (
) )
type TreasureChest interface { type TreasureChest interface {
Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error)
Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error) Update(ctx context.Context, user *auth_types.User, id, parentId, name string) (*types.TreasureChest, error)
Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error)
GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error)
Delete(ctx context.Context, user *types.User, id string) error Delete(ctx context.Context, user *auth_types.User, id string) error
} }
type TreasureChestImpl struct { type TreasureChestImpl struct {
db *sqlx.DB db *sqlx.DB
clock Clock clock core.Clock
random Random random core.Random
} }
func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest { func NewTreasureChest(db *sqlx.DB, random core.Random, clock core.Clock) TreasureChest {
return TreasureChestImpl{ return TreasureChestImpl{
db: db, db: db,
clock: clock, clock: clock,
@@ -35,14 +37,14 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
} }
} }
func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
newId, err := s.random.UUID(ctx) newId, err := s.random.UUID(ctx)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, core.ErrInternal
} }
err = ValidateString(name, "name") err = ValidateString(name, "name")
@@ -57,7 +59,7 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId,
return nil, err return nil, err
} }
if parent.ParentId != nil { if parent.ParentId != nil {
return nil, fmt.Errorf("only a depth of 1 allowed: %w", ErrBadRequest) return nil, fmt.Errorf("only a depth of 1 allowed: %w", core.ErrBadRequest)
} }
parentUuid = &parent.Id parentUuid = &parent.Id
} }
@@ -88,9 +90,9 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId,
return treasureChest, nil return treasureChest, nil
} }
func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Update(ctx context.Context, user *auth_types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
err := ValidateString(name, "name") err := ValidateString(name, "name")
if err != nil { if err != nil {
@@ -99,7 +101,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
id, err := uuid.Parse(idStr) id, err := uuid.Parse(idStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "treasureChest update", "err", err) slog.ErrorContext(ctx, "treasureChest update", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -115,10 +117,10 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id) err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
err = db.TransformAndLogDbError(ctx, "treasureChest Update", nil, err) err = db.TransformAndLogDbError(ctx, "treasureChest Update", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err) return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
var parentUuid *uuid.UUID var parentUuid *uuid.UUID
@@ -134,7 +136,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
return nil, err return nil, err
} }
if parent.ParentId != nil || childCount > 0 { if parent.ParentId != nil || childCount > 0 {
return nil, fmt.Errorf("only one level allowed: %w", ErrBadRequest) return nil, fmt.Errorf("only one level allowed: %w", core.ErrBadRequest)
} }
parentUuid = &parent.Id parentUuid = &parent.Id
@@ -170,32 +172,32 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
return treasureChest, nil return treasureChest, nil
} }
func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) { func (s TreasureChestImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
uuid, err := uuid.Parse(id) uuid, err := uuid.Parse(id)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "treasureChest get", "err", err) slog.ErrorContext(ctx, "treasureChest get", "err", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
var treasureChest types.TreasureChest var treasureChest types.TreasureChest
err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError(ctx, "treasureChest Get", nil, err) err = db.TransformAndLogDbError(ctx, "treasureChest Get", nil, err)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, core.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err) return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
} }
return nil, types.ErrInternal return nil, core.ErrInternal
} }
return &treasureChest, nil return &treasureChest, nil
} }
func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) { func (s TreasureChestImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error) {
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, core.ErrUnauthorized
} }
treasureChests := make([]*types.TreasureChest, 0) treasureChests := make([]*types.TreasureChest, 0)
@@ -208,14 +210,14 @@ func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*typ
return sortTreasureChests(treasureChests), nil return sortTreasureChests(treasureChests), nil
} }
func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error { func (s TreasureChestImpl) Delete(ctx context.Context, user *auth_types.User, idStr string) error {
if user == nil { if user == nil {
return ErrUnauthorized return core.ErrUnauthorized
} }
id, err := uuid.Parse(idStr) id, err := uuid.Parse(idStr)
if err != nil { if err != nil {
slog.ErrorContext(ctx, "treasureChest delete", "err", err) slog.ErrorContext(ctx, "treasureChest delete", "err", err)
return fmt.Errorf("could not parse Id: %w", ErrBadRequest) return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
} }
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -235,7 +237,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
} }
if childCount > 0 { if childCount > 0 {
return fmt.Errorf("treasure chest has children: %w", ErrBadRequest) return fmt.Errorf("treasure chest has children: %w", core.ErrBadRequest)
} }
transactionsCount := 0 transactionsCount := 0
@@ -247,7 +249,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
return err return err
} }
if transactionsCount > 0 { if transactionsCount > 0 {
return fmt.Errorf("treasure chest has transactions: %w", ErrBadRequest) return fmt.Errorf("treasure chest has transactions: %w", core.ErrBadRequest)
} }
recurringCount := 0 recurringCount := 0
@@ -259,7 +261,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
return err return err
} }
if recurringCount > 0 { if recurringCount > 0 {
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest) return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", core.ErrBadRequest)
} }
r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id) r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)

View File

@@ -1 +0,0 @@
package mocks

View File

@@ -2,8 +2,10 @@ package test_test
import ( import (
"context" "context"
"spend-sparrow/internal/auth_types"
"spend-sparrow/internal/authentication"
"spend-sparrow/internal/core"
"spend-sparrow/internal/db" "spend-sparrow/internal/db"
"spend-sparrow/internal/types"
"testing" "testing"
"time" "time"
@@ -42,11 +44,11 @@ func TestUser(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) expected := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(context.Background(), expected) err := underTest.InsertUser(context.Background(), expected)
require.NoError(t, err) require.NoError(t, err)
@@ -63,38 +65,38 @@ func TestUser(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
_, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail") _, err := underTest.GetUserByEmail(context.Background(), "nonExistentEmail")
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, core.ErrNotFound, err)
}) })
t.Run("should return ErrUserExist", func(t *testing.T) { t.Run("should return ErrUserExist", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(context.Background(), user) err := underTest.InsertUser(context.Background(), user)
require.NoError(t, err) require.NoError(t, err)
err = underTest.InsertUser(context.Background(), user) err = underTest.InsertUser(context.Background(), user)
assert.Equal(t, db.ErrAlreadyExists, err) assert.Equal(t, core.ErrAlreadyExists, err)
}) })
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := auth_types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
err := underTest.InsertUser(context.Background(), user) err := underTest.InsertUser(context.Background(), user)
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, core.ErrInternal, err)
}) })
} }
@@ -105,11 +107,11 @@ func TestToken(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt) expected := auth_types.NewToken(uuid.New(), "sessionId", "token", auth_types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(context.Background(), expected) err := underTest.InsertToken(context.Background(), expected)
require.NoError(t, err) require.NoError(t, err)
@@ -121,25 +123,25 @@ func TestToken(t *testing.T) {
expected.SessionId = "" expected.SessionId = ""
actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type) actuals, err := underTest.GetTokensByUserIdAndType(context.Background(), expected.UserId, expected.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals) assert.Equal(t, []*auth_types.Token{expected}, actuals)
expected.SessionId = "sessionId" expected.SessionId = "sessionId"
expected.UserId = uuid.Nil expected.UserId = uuid.Nil
actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type) actuals, err = underTest.GetTokensBySessionIdAndType(context.Background(), expected.SessionId, expected.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals) assert.Equal(t, []*auth_types.Token{expected}, actuals)
}) })
t.Run("should insert and return multiple tokens", func(t *testing.T) { t.Run("should insert and return multiple tokens", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
userId := uuid.New() userId := uuid.New()
expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt) expected1 := auth_types.NewToken(userId, "sessionId", "token1", auth_types.TokenTypeCsrf, createAt, expiresAt)
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt) expected2 := auth_types.NewToken(userId, "sessionId", "token2", auth_types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(context.Background(), expected1) err := underTest.InsertToken(context.Background(), expected1)
require.NoError(t, err) require.NoError(t, err)
@@ -150,7 +152,7 @@ func TestToken(t *testing.T) {
expected2.UserId = uuid.Nil expected2.UserId = uuid.Nil
actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type) actuals, err := underTest.GetTokensBySessionIdAndType(context.Background(), expected1.SessionId, expected1.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals) assert.Equal(t, []*auth_types.Token{expected1, expected2}, actuals)
expected1.SessionId = "" expected1.SessionId = ""
expected2.SessionId = "" expected2.SessionId = ""
@@ -158,49 +160,49 @@ func TestToken(t *testing.T) {
expected2.UserId = userId expected2.UserId = userId
actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type) actuals, err = underTest.GetTokensByUserIdAndType(context.Background(), userId, expected1.Type)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals) assert.Equal(t, []*auth_types.Token{expected1, expected2}, actuals)
}) })
t.Run("should return ErrNotFound", func(t *testing.T) { t.Run("should return ErrNotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
_, err := underTest.GetToken(context.Background(), "nonExistent") _, err := underTest.GetToken(context.Background(), "nonExistent")
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, core.ErrNotFound, err)
_, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), types.TokenTypeEmailVerify) _, err = underTest.GetTokensByUserIdAndType(context.Background(), uuid.New(), auth_types.TokenTypeEmailVerify)
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, core.ErrNotFound, err)
_, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", types.TokenTypeEmailVerify) _, err = underTest.GetTokensBySessionIdAndType(context.Background(), "sessionId", auth_types.TokenTypeEmailVerify)
assert.Equal(t, db.ErrNotFound, err) assert.Equal(t, core.ErrNotFound, err)
}) })
t.Run("should return ErrAlreadyExists", func(t *testing.T) { t.Run("should return ErrAlreadyExists", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := auth_types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(context.Background(), user) err := underTest.InsertUser(context.Background(), user)
require.NoError(t, err) require.NoError(t, err)
err = underTest.InsertUser(context.Background(), user) err = underTest.InsertUser(context.Background(), user)
assert.Equal(t, db.ErrAlreadyExists, err) assert.Equal(t, core.ErrAlreadyExists, err)
}) })
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
t.Parallel() t.Parallel()
d := setupDb(t) d := setupDb(t)
underTest := db.NewAuthSqlite(d) underTest := authentication.NewDbSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := auth_types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
err := underTest.InsertUser(context.Background(), user) err := underTest.InsertUser(context.Background(), user)
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, core.ErrInternal, err)
}) })
} }

View File

@@ -2,8 +2,9 @@ package test_test
import ( import (
"context" "context"
"spend-sparrow/internal/db" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/service" "spend-sparrow/internal/authentication"
"spend-sparrow/internal/core"
"spend-sparrow/internal/types" "spend-sparrow/internal/types"
"spend-sparrow/mocks" "spend-sparrow/mocks"
"strings" "strings"
@@ -30,26 +31,26 @@ func TestSignUp(t *testing.T) {
t.Run("should check for correct email address", func(t *testing.T) { t.Run("should check for correct email address", func(t *testing.T) {
t.Parallel() t.Parallel()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!") _, err := underTest.SignUp(context.Background(), "invalid email address", "SomeStrongPassword123!")
assert.Equal(t, service.ErrInvalidEmail, err) assert.Equal(t, authentication.ErrInvalidEmail, err)
}) })
t.Run("should check for password complexity", func(t *testing.T) { t.Run("should check for password complexity", func(t *testing.T) {
t.Parallel() t.Parallel()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
weakPasswords := []string{ weakPasswords := []string{
"123!ab", // too short "123!ab", // too short
@@ -60,13 +61,13 @@ func TestSignUp(t *testing.T) {
for _, password := range weakPasswords { for _, password := range weakPasswords {
_, err := underTest.SignUp(context.Background(), "some@valid.email", password) _, err := underTest.SignUp(context.Background(), "some@valid.email", password)
assert.Equal(t, service.ErrInvalidPassword, err) assert.Equal(t, authentication.ErrInvalidPassword, err)
} }
}) })
t.Run("should signup correctly", func(t *testing.T) { t.Run("should signup correctly", func(t *testing.T) {
t.Parallel() t.Parallel()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
@@ -77,7 +78,7 @@ func TestSignUp(t *testing.T) {
salt := []byte("salt") salt := []byte("salt")
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime) expected := auth_types.NewUser(userId, email, false, nil, false, authentication.GetHashPassword(password, salt), salt, createTime)
ctx := context.Background() ctx := context.Background()
@@ -86,7 +87,7 @@ func TestSignUp(t *testing.T) {
mockClock.EXPECT().Now().Return(createTime) mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil) mockAuthDb.EXPECT().InsertUser(context.Background(), expected).Return(nil)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
actual, err := underTest.SignUp(context.Background(), email, password) actual, err := underTest.SignUp(context.Background(), email, password)
require.NoError(t, err) require.NoError(t, err)
@@ -96,7 +97,7 @@ func TestSignUp(t *testing.T) {
t.Run("should return ErrAccountExists", func(t *testing.T) { t.Run("should return ErrAccountExists", func(t *testing.T) {
t.Parallel() t.Parallel()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
@@ -106,19 +107,19 @@ func TestSignUp(t *testing.T) {
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
password := "SomeStrongPassword123!" password := "SomeStrongPassword123!"
salt := []byte("salt") salt := []byte("salt")
user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime) user := auth_types.NewUser(userId, email, false, nil, false, authentication.GetHashPassword(password, salt), salt, createTime)
ctx := context.Background() ctx := context.Background()
mockRandom.EXPECT().UUID(ctx).Return(user.Id, nil) mockRandom.EXPECT().UUID(ctx).Return(user.Id, nil)
mockRandom.EXPECT().Bytes(ctx, 16).Return(salt, nil) mockRandom.EXPECT().Bytes(ctx, 16).Return(salt, nil)
mockClock.EXPECT().Now().Return(createTime) mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(db.ErrAlreadyExists) mockAuthDb.EXPECT().InsertUser(context.Background(), user).Return(core.ErrAlreadyExists)
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp(context.Background(), user.Email, password) _, err := underTest.SignUp(context.Background(), user.Email, password)
assert.Equal(t, service.ErrAccountExists, err) assert.Equal(t, authentication.ErrAccountExists, err)
}) })
} }
@@ -127,30 +128,30 @@ func TestSendVerificationMail(t *testing.T) {
t.Run("should use stored token and send mail", func(t *testing.T) { t.Run("should use stored token and send mail", func(t *testing.T) {
t.Parallel() t.Parallel()
token := types.NewToken( token := auth_types.NewToken(
uuid.New(), uuid.New(),
"sessionId", "sessionId",
"someRandomTokenToUse", "someRandomTokenToUse",
types.TokenTypeEmailVerify, auth_types.TokenTypeEmailVerify,
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
tokens := []*types.Token{token} tokens := []*auth_types.Token{token}
email := "some@email.de" email := "some@email.de"
userId := uuid.New() userId := uuid.New()
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockDb(t)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
ctx := context.Background() ctx := context.Background()
mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, types.TokenTypeEmailVerify).Return(tokens, nil) mockAuthDb.EXPECT().GetTokensByUserIdAndType(context.Background(), userId, auth_types.TokenTypeEmailVerify).Return(tokens, nil)
mockMail.EXPECT().SendMail(ctx, email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool { mockMail.EXPECT().SendMail(ctx, email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
return strings.Contains(message, token.Token) return strings.Contains(message, token.Token)
})).Return() })).Return()
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest := authentication.NewService(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
underTest.SendVerificationMail(context.Background(), userId, email) underTest.SendVerificationMail(context.Background(), userId, email)
}) })

View File

@@ -7,8 +7,9 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"spend-sparrow/internal" "spend-sparrow/internal"
"spend-sparrow/internal/service" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/types" "spend-sparrow/internal/authentication"
"spend-sparrow/internal/core"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -117,7 +118,7 @@ func waitForReady(
default: default:
if time.Since(startTime) >= timeout { if time.Since(startTime) >= timeout {
t.Fatal("timeout reached while waiting for endpoint") t.Fatal("timeout reached while waiting for endpoint")
return types.ErrInternal return core.ErrInternal
} }
// wait a little while between checks // wait a little while between checks
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
@@ -178,7 +179,7 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s
t.Helper() t.Helper()
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" + add sessionId := "session-id" + add
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
csrfToken := "my-verifying-token" + add csrfToken := "my-verifying-token" + add
email := add + "mail@mail.de" email := add + "mail@mail.de"
@@ -193,7 +194,7 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, s
_, err = db.ExecContext(context.Background(), ` _, err = db.ExecContext(context.Background(), `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf) VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, auth_types.TokenTypeCsrf)
require.NoError(t, err) require.NoError(t, err)
return userId, csrfToken, sessionId return userId, csrfToken, sessionId

View File

@@ -3,8 +3,8 @@ package test_test
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"spend-sparrow/internal/service" "spend-sparrow/internal/auth_types"
"spend-sparrow/internal/types" "spend-sparrow/internal/authentication"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -110,7 +110,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -136,7 +136,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -163,7 +163,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -206,7 +206,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -247,7 +247,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
@@ -295,7 +295,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
@@ -414,7 +414,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
@@ -467,7 +467,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -550,7 +550,7 @@ func TestIntegrationAuth(t *testing.T) {
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), authentication.GetHashPassword("password", []byte("salt")), []byte("salt"))
require.NoError(t, err) require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil)
@@ -631,7 +631,7 @@ func TestIntegrationAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
var token string var token string
err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token) err = db.QueryRowContext(ctx, "SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", auth_types.TokenTypeEmailVerify).Scan(&token)
require.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, token) assert.NotEmpty(t, token)
}) })
@@ -676,7 +676,7 @@ func TestIntegrationAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.ExecContext(ctx, ` _, err = db.ExecContext(ctx, `
INSERT INTO token (token, user_id, type, created_at, expires_at) INSERT INTO token (token, user_id, type, created_at, expires_at)
VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify) VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, auth_types.TokenTypeEmailVerify)
require.NoError(t, err) require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil)
@@ -706,7 +706,7 @@ func TestIntegrationAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.ExecContext(ctx, ` _, err = db.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify) VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, auth_types.TokenTypeEmailVerify)
require.NoError(t, err) require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil)
@@ -746,7 +746,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
sessionId := "session-id" sessionId := "session-id"
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -765,7 +765,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var csrfToken string var csrfToken string
err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken) err = db.QueryRowContext(ctx, "SELECT token FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypeCsrf).Scan(&csrfToken)
require.NoError(t, err) require.NoError(t, err)
req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil) req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil)
@@ -824,7 +824,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -870,7 +870,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1039,7 +1039,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1078,7 +1078,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -1128,7 +1128,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
@@ -1180,7 +1180,7 @@ func TestIntegrationAuth(t *testing.T) {
userId := uuid.New() userId := uuid.New()
userIdOther := uuid.New() userIdOther := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1230,7 +1230,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt")) pass = authentication.GetHashPassword("MyNewSecurePassword1!", []byte("salt"))
var rows int var rows int
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
@@ -1259,7 +1259,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1287,7 +1287,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1317,7 +1317,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
var rows int var rows int
err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) err = d.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypePasswordReset).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 0, rows) assert.Equal(t, 0, rows)
}) })
@@ -1362,7 +1362,7 @@ func TestIntegrationAuth(t *testing.T) {
db, basePath, ctx := setupIntegrationTest(t) db, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := db.ExecContext(ctx, ` _, err := db.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1399,7 +1399,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg) assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg)
var rows int var rows int
err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, auth_types.TokenTypePasswordReset).Scan(&rows)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 1, rows) assert.Equal(t, 1, rows)
}) })
@@ -1412,7 +1412,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1455,7 +1455,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1475,7 +1475,7 @@ func TestIntegrationAuth(t *testing.T) {
token := "password-reset-token" token := "password-reset-token"
_, err = d.ExecContext(ctx, ` _, err = d.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset) VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", auth_types.TokenTypePasswordReset)
require.NoError(t, err) require.NoError(t, err)
formData := url.Values{ formData := url.Values{
@@ -1504,7 +1504,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1524,7 +1524,7 @@ func TestIntegrationAuth(t *testing.T) {
token := "password-reset-token" token := "password-reset-token"
_, err = d.ExecContext(ctx, ` _, err = d.ExecContext(ctx, `
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset) VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", auth_types.TokenTypePasswordReset)
require.NoError(t, err) require.NoError(t, err)
formData := url.Values{ formData := url.Values{
@@ -1553,7 +1553,7 @@ func TestIntegrationAuth(t *testing.T) {
d, basePath, ctx := setupIntegrationTest(t) d, basePath, ctx := setupIntegrationTest(t)
userId := uuid.New() userId := uuid.New()
pass := service.GetHashPassword("password", []byte("salt")) pass := authentication.GetHashPassword("password", []byte("salt"))
_, err := d.ExecContext(ctx, ` _, err := d.ExecContext(ctx, `
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt"))
@@ -1590,7 +1590,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var token string var token string
err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token) err = d.QueryRowContext(ctx, "SELECT token FROM token WHERE type = ?", auth_types.TokenTypePasswordReset).Scan(&token)
require.NoError(t, err) require.NoError(t, err)
formData = url.Values{ formData = url.Values{