This commit is contained in:
@@ -4,29 +4,31 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"log/slog"
|
||||
"spend-sparrow/internal/auth_types"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/db"
|
||||
"spend-sparrow/internal/service"
|
||||
"spend-sparrow/internal/types"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
Add(ctx context.Context, user *types.User, name string) (*Account, error)
|
||||
UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error)
|
||||
Get(ctx context.Context, user *types.User, id string) (*Account, error)
|
||||
GetAll(ctx context.Context, user *types.User) ([]*Account, error)
|
||||
Delete(ctx context.Context, user *types.User, id string) error
|
||||
Add(ctx context.Context, user *auth_types.User, name string) (*Account, error)
|
||||
UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error)
|
||||
Get(ctx context.Context, user *auth_types.User, id string) (*Account, error)
|
||||
GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error)
|
||||
Delete(ctx context.Context, user *auth_types.User, id string) error
|
||||
}
|
||||
|
||||
type ServiceImpl struct {
|
||||
db *sqlx.DB
|
||||
clock service.Clock
|
||||
random service.Random
|
||||
clock core.Clock
|
||||
random core.Random
|
||||
}
|
||||
|
||||
func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Service {
|
||||
func NewServiceImpl(db *sqlx.DB, random core.Random, clock core.Clock) Service {
|
||||
return ServiceImpl{
|
||||
db: db,
|
||||
clock: clock,
|
||||
@@ -34,14 +36,14 @@ func NewServiceImpl(db *sqlx.DB, random service.Random, clock service.Clock) Ser
|
||||
}
|
||||
}
|
||||
|
||||
func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*Account, error) {
|
||||
func (s ServiceImpl) Add(ctx context.Context, user *auth_types.User, name string) (*Account, error) {
|
||||
if user == nil {
|
||||
return nil, types.ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
newId, err := s.random.UUID(ctx)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
err = service.ValidateString(name, "name")
|
||||
@@ -76,9 +78,9 @@ func (s ServiceImpl) Add(ctx context.Context, user *types.User, name string) (*A
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string, name string) (*Account, error) {
|
||||
func (s ServiceImpl) UpdateName(ctx context.Context, user *auth_types.User, id string, name string) (*Account, error) {
|
||||
if user == nil {
|
||||
return nil, types.ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
err := service.ValidateString(name, "name")
|
||||
if err != nil {
|
||||
@@ -87,7 +89,7 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
|
||||
uuid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "account update", "err", err)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -103,10 +105,10 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
|
||||
err = tx.GetContext(ctx, &account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError(ctx, "account Update", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, fmt.Errorf("account %v not found: %w", id, service.ErrBadRequest)
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("account %v not found: %w", id, core.ErrBadRequest)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
timestamp := s.clock.Now()
|
||||
@@ -136,14 +138,14 @@ func (s ServiceImpl) UpdateName(ctx context.Context, user *types.User, id string
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Account, error) {
|
||||
func (s ServiceImpl) Get(ctx context.Context, user *auth_types.User, id string) (*Account, error) {
|
||||
if user == nil {
|
||||
return nil, service.ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
uuid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "account get", "err", err)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
var account Account
|
||||
@@ -158,9 +160,9 @@ func (s ServiceImpl) Get(ctx context.Context, user *types.User, id string) (*Acc
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account, error) {
|
||||
func (s ServiceImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*Account, error) {
|
||||
if user == nil {
|
||||
return nil, service.ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
accounts := make([]*Account, 0)
|
||||
@@ -174,14 +176,14 @@ func (s ServiceImpl) GetAll(ctx context.Context, user *types.User) ([]*Account,
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||
func (s ServiceImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
|
||||
if user == nil {
|
||||
return service.ErrUnauthorized
|
||||
return core.ErrUnauthorized
|
||||
}
|
||||
uuid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "account delete", "err", err)
|
||||
return fmt.Errorf("could not parse Id: %w", service.ErrBadRequest)
|
||||
return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -200,7 +202,7 @@ func (s ServiceImpl) Delete(ctx context.Context, user *types.User, id string) er
|
||||
return err
|
||||
}
|
||||
if transactionsCount > 0 {
|
||||
return fmt.Errorf("account has transactions, cannot delete: %w", service.ErrBadRequest)
|
||||
return fmt.Errorf("account has transactions, cannot delete: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
res, err := tx.ExecContext(ctx, "DELETE FROM account WHERE id = ? and user_id = ?", uuid, user.Id)
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"spend-sparrow/internal/auth_types"
|
||||
"spend-sparrow/internal/authentication/template"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/types"
|
||||
"spend-sparrow/internal/utils"
|
||||
"time"
|
||||
)
|
||||
@@ -79,7 +79,7 @@ func (handler HandlerImpl) handleSignIn() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
core.UpdateSpan(r)
|
||||
|
||||
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*User, error) {
|
||||
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*auth_types.User, error) {
|
||||
session := core.GetSession(r)
|
||||
email := r.FormValue("email")
|
||||
password := r.FormValue("password")
|
||||
@@ -89,14 +89,14 @@ func (handler HandlerImpl) handleSignIn() http.HandlerFunc {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cookie := middleware.CreateSessionCookie(session.Id)
|
||||
cookie := core.CreateSessionCookie(session.Id)
|
||||
http.SetCookie(w, &cookie)
|
||||
|
||||
return user, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrInvalidCredentials) {
|
||||
if errors.Is(err, ErrInvalidCredentials) {
|
||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Invalid email or password", http.StatusUnauthorized)
|
||||
} else {
|
||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
|
||||
@@ -127,7 +127,7 @@ func (handler HandlerImpl) handleSignUpPage() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
signUpComp := auth.SignInOrUpComp(false)
|
||||
signUpComp := template.SignInOrUpComp(false)
|
||||
handler.render.RenderLayout(r, w, signUpComp, nil)
|
||||
}
|
||||
}
|
||||
@@ -147,7 +147,7 @@ func (handler HandlerImpl) handleSignUpVerifyPage() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
signIn := auth.VerifyComp()
|
||||
signIn := template.VerifyComp()
|
||||
handler.render.RenderLayout(r, w, signIn, user)
|
||||
}
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func (handler HandlerImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
||||
err := handler.service.VerifyUserEmail(r.Context(), token)
|
||||
|
||||
isVerified := err == nil
|
||||
comp := auth.VerifyResponseComp(isVerified)
|
||||
comp := template.VerifyResponseComp(isVerified)
|
||||
|
||||
var status int
|
||||
if isVerified {
|
||||
@@ -214,14 +214,14 @@ func (handler HandlerImpl) handleSignUp() http.HandlerFunc {
|
||||
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, types.ErrInternal):
|
||||
case errors.Is(err, core.ErrInternal):
|
||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "An error occurred", http.StatusInternalServerError)
|
||||
return
|
||||
case errors.Is(err, service.ErrInvalidEmail):
|
||||
case errors.Is(err, ErrInvalidEmail):
|
||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "The email provided is invalid", http.StatusBadRequest)
|
||||
return
|
||||
case errors.Is(err, service.ErrInvalidPassword):
|
||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest)
|
||||
case errors.Is(err, ErrInvalidPassword):
|
||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", ErrInvalidPassword.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// If err is "service.ErrAccountExists", then just continue
|
||||
@@ -270,7 +270,7 @@ func (handler HandlerImpl) handleDeleteAccountPage() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
comp := auth.DeleteAccountComp()
|
||||
comp := template.DeleteAccountComp()
|
||||
handler.render.RenderLayout(r, w, comp, user)
|
||||
}
|
||||
}
|
||||
@@ -289,7 +289,7 @@ func (handler HandlerImpl) handleDeleteAccountComp() http.HandlerFunc {
|
||||
|
||||
err := handler.service.DeleteAccount(r.Context(), user, password)
|
||||
if err != nil {
|
||||
if errors.Is(err, service.ErrInvalidCredentials) {
|
||||
if errors.Is(err, ErrInvalidCredentials) {
|
||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Password not correct", http.StatusBadRequest)
|
||||
} else {
|
||||
utils.TriggerToastWithStatus(r.Context(), w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -314,7 +314,7 @@ func (handler HandlerImpl) handleChangePasswordPage() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
comp := auth.ChangePasswordComp(isPasswordReset)
|
||||
comp := template.ChangePasswordComp(isPasswordReset)
|
||||
handler.render.RenderLayout(r, w, comp, user)
|
||||
}
|
||||
}
|
||||
@@ -353,7 +353,7 @@ func (handler HandlerImpl) handleForgotPasswordPage() http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
comp := auth.ResetPasswordComp()
|
||||
comp := template.ResetPasswordComp()
|
||||
handler.render.RenderLayout(r, w, comp, user)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/mail"
|
||||
"spend-sparrow/internal/auth_types"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/db"
|
||||
mailTemplate "spend-sparrow/internal/template/mail"
|
||||
"spend-sparrow/internal/types"
|
||||
"strings"
|
||||
@@ -53,13 +52,13 @@ type Service interface {
|
||||
type ServiceImpl struct {
|
||||
db Db
|
||||
random core.Random
|
||||
clock Clock
|
||||
mail Mail
|
||||
clock core.Clock
|
||||
mail core.Mail
|
||||
serverSettings *types.Settings
|
||||
}
|
||||
|
||||
func NewService(db db.Auth, random Random, clock Clock, mail Mail, serverSettings *types.Settings) *HandlerImpl {
|
||||
return &HandlerImpl{
|
||||
func NewService(db Db, random core.Random, clock core.Clock, mail core.Mail, serverSettings *types.Settings) *ServiceImpl {
|
||||
return &ServiceImpl{
|
||||
db: db,
|
||||
random: random,
|
||||
clock: clock,
|
||||
@@ -68,13 +67,13 @@ func NewService(db db.Auth, random Random, clock Clock, mail Mail, serverSetting
|
||||
}
|
||||
}
|
||||
|
||||
func (service HandlerImpl) SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_type.Session, *auth_types.User, error) {
|
||||
func (service ServiceImpl) SignIn(ctx context.Context, session *auth_types.Session, email string, password string) (*auth_types.Session, *auth_types.User, error) {
|
||||
user, err := service.db.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, nil, ErrInvalidCredentials
|
||||
} else {
|
||||
return nil, nil, types.ErrInternal
|
||||
return nil, nil, core.ErrInternal
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,36 +85,36 @@ func (service HandlerImpl) SignIn(ctx context.Context, session *auth_types.Sessi
|
||||
|
||||
newSession, err := service.createSession(ctx, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
return nil, nil, core.ErrInternal
|
||||
}
|
||||
|
||||
err = service.db.DeleteSession(ctx, session.Id)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
return nil, nil, core.ErrInternal
|
||||
}
|
||||
|
||||
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
return nil, nil, core.ErrInternal
|
||||
}
|
||||
for _, token := range tokens {
|
||||
err = service.db.DeleteToken(ctx, token.Token)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
return nil, nil, core.ErrInternal
|
||||
}
|
||||
}
|
||||
|
||||
return newSession, user, nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error) {
|
||||
func (service ServiceImpl) SignInSession(ctx context.Context, sessionId string) (*auth_types.Session, *auth_types.User, error) {
|
||||
if sessionId == "" {
|
||||
return nil, nil, ErrSessionIdInvalid
|
||||
}
|
||||
|
||||
session, err := service.db.GetSession(ctx, sessionId)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
return nil, nil, core.ErrInternal
|
||||
}
|
||||
if session.ExpiresAt.Before(service.clock.Now()) {
|
||||
_ = service.db.DeleteSession(ctx, sessionId)
|
||||
@@ -128,16 +127,16 @@ func (service HandlerImpl) SignInSession(ctx context.Context, sessionId string)
|
||||
|
||||
user, err := service.db.GetUser(ctx, session.UserId)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
return nil, nil, core.ErrInternal
|
||||
}
|
||||
|
||||
return session, user, nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) SignInAnonymous(ctx context.Context) (*auth_types.Session, error) {
|
||||
func (service ServiceImpl) SignInAnonymous(ctx context.Context) (*auth_types.Session, error) {
|
||||
session, err := service.createSession(ctx, uuid.Nil)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "anonymous session created", "session-id", session.Id)
|
||||
@@ -145,7 +144,7 @@ func (service HandlerImpl) SignInAnonymous(ctx context.Context) (*auth_types.Ses
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) SignUp(ctx context.Context, email string, password string) (*auth_types.User, error) {
|
||||
func (service ServiceImpl) SignUp(ctx context.Context, email string, password string) (*auth_types.User, error) {
|
||||
_, err := mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidEmail
|
||||
@@ -157,37 +156,37 @@ func (service HandlerImpl) SignUp(ctx context.Context, email string, password st
|
||||
|
||||
userId, err := service.random.UUID(ctx)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
salt, err := service.random.Bytes(ctx, 16)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
hash := GetHashPassword(password, salt)
|
||||
|
||||
user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
|
||||
user := auth_types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
|
||||
|
||||
err = service.db.InsertUser(ctx, user)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrAlreadyExists) {
|
||||
if errors.Is(err, core.ErrAlreadyExists) {
|
||||
return nil, ErrAccountExists
|
||||
} else {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
|
||||
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, types.TokenTypeEmailVerify)
|
||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||
func (service ServiceImpl) SendVerificationMail(ctx context.Context, userId uuid.UUID, email string) {
|
||||
tokens, err := service.db.GetTokensByUserIdAndType(ctx, userId, auth_types.TokenTypeEmailVerify)
|
||||
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||
return
|
||||
}
|
||||
|
||||
var token *types.Token
|
||||
var token *auth_types.Token
|
||||
|
||||
if len(tokens) > 0 {
|
||||
token = tokens[0]
|
||||
@@ -199,11 +198,11 @@ func (service HandlerImpl) SendVerificationMail(ctx context.Context, userId uuid
|
||||
return
|
||||
}
|
||||
|
||||
token = types.NewToken(
|
||||
token = auth_types.NewToken(
|
||||
userId,
|
||||
"",
|
||||
newTokenStr,
|
||||
types.TokenTypeEmailVerify,
|
||||
auth_types.TokenTypeEmailVerify,
|
||||
service.clock.Now(),
|
||||
service.clock.Now().Add(24*time.Hour))
|
||||
|
||||
@@ -223,29 +222,29 @@ func (service HandlerImpl) SendVerificationMail(ctx context.Context, userId uuid
|
||||
service.mail.SendMail(ctx, email, "Welcome to spend-sparrow", w.String())
|
||||
}
|
||||
|
||||
func (service HandlerImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
|
||||
func (service ServiceImpl) VerifyUserEmail(ctx context.Context, tokenStr string) error {
|
||||
if tokenStr == "" {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
token, err := service.db.GetToken(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
user, err := service.db.GetUser(ctx, token.UserId)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
if token.Type != types.TokenTypeEmailVerify {
|
||||
return types.ErrInternal
|
||||
if token.Type != auth_types.TokenTypeEmailVerify {
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
now := service.clock.Now()
|
||||
|
||||
if token.ExpiresAt.Before(now) {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
user.EmailVerified = true
|
||||
@@ -253,21 +252,21 @@ func (service HandlerImpl) VerifyUserEmail(ctx context.Context, tokenStr string)
|
||||
|
||||
err = service.db.UpdateUser(ctx, user)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
_ = service.db.DeleteToken(ctx, token.Token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) SignOut(ctx context.Context, sessionId string) error {
|
||||
func (service ServiceImpl) SignOut(ctx context.Context, sessionId string) error {
|
||||
return service.db.DeleteSession(ctx, sessionId)
|
||||
}
|
||||
|
||||
func (service HandlerImpl) DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error {
|
||||
func (service ServiceImpl) DeleteAccount(ctx context.Context, user *auth_types.User, currPass string) error {
|
||||
userDb, err := service.db.GetUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
currHash := GetHashPassword(currPass, userDb.Salt)
|
||||
@@ -285,7 +284,7 @@ func (service HandlerImpl) DeleteAccount(ctx context.Context, user *auth_types.U
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error {
|
||||
func (service ServiceImpl) ChangePassword(ctx context.Context, user *auth_types.User, sessionId string, currPass, newPass string) error {
|
||||
if !isPasswordValid(newPass) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
@@ -310,13 +309,13 @@ func (service HandlerImpl) ChangePassword(ctx context.Context, user *auth_types.
|
||||
|
||||
sessions, err := service.db.GetSessions(ctx, user.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
for _, s := range sessions {
|
||||
if s.Id != sessionId {
|
||||
err = service.db.DeleteSession(ctx, s.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -324,7 +323,7 @@ func (service HandlerImpl) ChangePassword(ctx context.Context, user *auth_types.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
|
||||
func (service ServiceImpl) SendForgotPasswordMail(ctx context.Context, email string) error {
|
||||
tokenStr, err := service.random.String(ctx, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -332,38 +331,38 @@ func (service HandlerImpl) SendForgotPasswordMail(ctx context.Context, email str
|
||||
|
||||
user, err := service.db.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil
|
||||
} else {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
}
|
||||
|
||||
token := types.NewToken(
|
||||
token := auth_types.NewToken(
|
||||
user.Id,
|
||||
"",
|
||||
tokenStr,
|
||||
types.TokenTypePasswordReset,
|
||||
auth_types.TokenTypePasswordReset,
|
||||
service.clock.Now(),
|
||||
service.clock.Now().Add(15*time.Minute))
|
||||
|
||||
err = service.db.InsertToken(ctx, token)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
var mail strings.Builder
|
||||
err = mailTemplate.ResetPassword(service.serverSettings.BaseUrl, token.Token).Render(context.Background(), &mail)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "Could not render reset password email", "err", err)
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
service.mail.SendMail(ctx, email, "Reset Password", mail.String())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
|
||||
func (service ServiceImpl) ForgotPassword(ctx context.Context, tokenStr string, newPass string) error {
|
||||
if !isPasswordValid(newPass) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
@@ -378,7 +377,7 @@ func (service HandlerImpl) ForgotPassword(ctx context.Context, tokenStr string,
|
||||
return err
|
||||
}
|
||||
|
||||
if token.Type != types.TokenTypePasswordReset ||
|
||||
if token.Type != auth_types.TokenTypePasswordReset ||
|
||||
token.ExpiresAt.Before(service.clock.Now()) {
|
||||
return ErrTokenInvalid
|
||||
}
|
||||
@@ -386,7 +385,7 @@ func (service HandlerImpl) ForgotPassword(ctx context.Context, tokenStr string,
|
||||
user, err := service.db.GetUser(ctx, token.UserId)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "Could not get user from token", "err", err)
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
passHash := GetHashPassword(newPass, user.Salt)
|
||||
@@ -399,26 +398,26 @@ func (service HandlerImpl) ForgotPassword(ctx context.Context, tokenStr string,
|
||||
|
||||
sessions, err := service.db.GetSessions(ctx, user.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
err = service.db.DeleteSession(ctx, session.Id)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
|
||||
func (service ServiceImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool {
|
||||
token, err := service.db.GetToken(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if token.Type != types.TokenTypeCsrf ||
|
||||
if token.Type != auth_types.TokenTypeCsrf ||
|
||||
token.SessionId != sessionId ||
|
||||
token.ExpiresAt.Before(service.clock.Now()) {
|
||||
return false
|
||||
@@ -427,12 +426,12 @@ func (service HandlerImpl) IsCsrfTokenValid(ctx context.Context, tokenStr string
|
||||
return true
|
||||
}
|
||||
|
||||
func (service HandlerImpl) GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error) {
|
||||
func (service ServiceImpl) GetCsrfToken(ctx context.Context, session *auth_types.Session) (string, error) {
|
||||
if session == nil {
|
||||
return "", types.ErrInternal
|
||||
return "", core.ErrInternal
|
||||
}
|
||||
|
||||
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||
tokens, _ := service.db.GetTokensBySessionIdAndType(ctx, session.Id, auth_types.TokenTypeCsrf)
|
||||
|
||||
if len(tokens) > 0 {
|
||||
return tokens[0].Token, nil
|
||||
@@ -440,19 +439,19 @@ func (service HandlerImpl) GetCsrfToken(ctx context.Context, session *auth_types
|
||||
|
||||
tokenStr, err := service.random.String(ctx, 32)
|
||||
if err != nil {
|
||||
return "", types.ErrInternal
|
||||
return "", core.ErrInternal
|
||||
}
|
||||
|
||||
token := types.NewToken(
|
||||
token := auth_types.NewToken(
|
||||
session.UserId,
|
||||
session.Id,
|
||||
tokenStr,
|
||||
types.TokenTypeCsrf,
|
||||
auth_types.TokenTypeCsrf,
|
||||
service.clock.Now(),
|
||||
service.clock.Now().Add(8*time.Hour))
|
||||
err = service.db.InsertToken(ctx, token)
|
||||
if err != nil {
|
||||
return "", types.ErrInternal
|
||||
return "", core.ErrInternal
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "CSRF-Token created", "token", tokenStr)
|
||||
@@ -460,34 +459,34 @@ func (service HandlerImpl) GetCsrfToken(ctx context.Context, session *auth_types
|
||||
return tokenStr, nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) CleanupSessionsAndTokens(ctx context.Context) error {
|
||||
func (service ServiceImpl) CleanupSessionsAndTokens(ctx context.Context) error {
|
||||
err := service.db.DeleteOldSessions(ctx)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
err = service.db.DeleteOldTokens(ctx)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
return core.ErrInternal
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service HandlerImpl) createSession(ctx context.Context, userId uuid.UUID) (*auth_types.Session, error) {
|
||||
func (service ServiceImpl) createSession(ctx context.Context, userId uuid.UUID) (*auth_types.Session, error) {
|
||||
sessionId, err := service.random.String(ctx, 32)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
createAt := service.clock.Now()
|
||||
expiresAt := createAt.Add(24 * time.Hour)
|
||||
|
||||
session := types.NewSession(sessionId, userId, createAt, expiresAt)
|
||||
session := auth_types.NewSession(sessionId, userId, createAt, expiresAt)
|
||||
|
||||
err = service.db.InsertSession(ctx, session)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
return session, nil
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package core
|
||||
|
||||
import "time"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"os/signal"
|
||||
"spend-sparrow/internal/account"
|
||||
"spend-sparrow/internal/authentication"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/db"
|
||||
"spend-sparrow/internal/handler"
|
||||
@@ -107,13 +108,13 @@ func shutdownServer(ctx context.Context, s *http.Server, wg *sync.WaitGroup) {
|
||||
func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||
var router = http.NewServeMux()
|
||||
|
||||
authDb := db.NewAuthSqlite(d)
|
||||
authDb := authentication.NewDbSqlite(d)
|
||||
|
||||
randomService := service.NewRandom()
|
||||
clockService := service.NewClock()
|
||||
mailService := service.NewMail(serverSettings)
|
||||
randomService := core.NewRandom()
|
||||
clockService := core.NewClock()
|
||||
mailService := core.NewMail(serverSettings)
|
||||
|
||||
authService := service.NewAuth(authDb, randomService, clockService, mailService, serverSettings)
|
||||
authService := authentication.NewService(authDb, randomService, clockService, mailService, serverSettings)
|
||||
accountService := account.NewServiceImpl(d, randomService, clockService)
|
||||
treasureChestService := service.NewTreasureChest(d, randomService, clockService)
|
||||
transactionService := service.NewTransaction(d, randomService, clockService)
|
||||
@@ -123,7 +124,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *
|
||||
render := core.NewRender()
|
||||
indexHandler := handler.NewIndex(render, clockService)
|
||||
dashboardHandler := handler.NewDashboard(render, dashboardService, treasureChestService)
|
||||
authHandler := handler.NewAuth(authService, render)
|
||||
authHandler := authentication.NewHandler(authService, render)
|
||||
accountHandler := account.NewHandler(accountService, render)
|
||||
treasureChestHandler := handler.NewTreasureChest(treasureChestService, transactionRecurringService, render)
|
||||
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
|
||||
@@ -157,7 +158,7 @@ func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *
|
||||
return wrapper
|
||||
}
|
||||
|
||||
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
|
||||
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) {
|
||||
runDailyTasks(ctx, transactionRecurring, auth)
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
@@ -172,7 +173,7 @@ func dailyTaskTimer(ctx context.Context, transactionRecurring service.Transactio
|
||||
}
|
||||
}
|
||||
|
||||
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
|
||||
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth authentication.Service) {
|
||||
slog.InfoContext(ctx, "Running daily tasks")
|
||||
_ = transactionRecurring.GenerateTransactions(ctx)
|
||||
_ = auth.CleanupSessionsAndTokens(ctx)
|
||||
|
||||
@@ -193,7 +193,7 @@ func (handler DashboardImpl) handleDashboardTreasureChest() http.HandlerFunc {
|
||||
if treasureChestStr != "" {
|
||||
id, err := uuid.Parse(treasureChestStr)
|
||||
if err != nil {
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", service.ErrBadRequest))
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest: %w", core.ErrBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -3,12 +3,12 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"spend-sparrow/internal/authentication"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
||||
func Authenticate(service authentication.Service) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
@@ -31,7 +31,7 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
||||
return
|
||||
}
|
||||
|
||||
cookie := CreateSessionCookie(session.Id)
|
||||
cookie := core.CreateSessionCookie(session.Id)
|
||||
http.SetCookie(w, &cookie)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ package middleware
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"spend-sparrow/internal/authentication"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/service"
|
||||
"spend-sparrow/internal/utils"
|
||||
"strings"
|
||||
)
|
||||
@@ -31,7 +31,7 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
|
||||
return rr.ResponseWriter.Write([]byte(dataStr))
|
||||
}
|
||||
|
||||
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
|
||||
func CrossSiteRequestForgery(auth authentication.Service) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
@@ -3,7 +3,6 @@ package handler
|
||||
import (
|
||||
"net/http"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/service"
|
||||
"spend-sparrow/internal/template"
|
||||
"spend-sparrow/internal/utils"
|
||||
|
||||
@@ -16,10 +15,10 @@ type Index interface {
|
||||
|
||||
type IndexImpl struct {
|
||||
r *core.Render
|
||||
c service.Clock
|
||||
c core.Clock
|
||||
}
|
||||
|
||||
func NewIndex(r *core.Render, c service.Clock) Index {
|
||||
func NewIndex(r *core.Render, c core.Clock) Index {
|
||||
return IndexImpl{
|
||||
r: r,
|
||||
c: c,
|
||||
|
||||
@@ -157,7 +157,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
||||
if idStr != "new" {
|
||||
id, err = uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest))
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -167,7 +167,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
||||
if accountIdStr != "" {
|
||||
i, err := uuid.Parse(accountIdStr)
|
||||
if err != nil {
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", service.ErrBadRequest))
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse account id: %w", core.ErrBadRequest))
|
||||
return
|
||||
}
|
||||
accountId = &i
|
||||
@@ -178,7 +178,7 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
||||
if treasureChestIdStr != "" {
|
||||
i, err := uuid.Parse(treasureChestIdStr)
|
||||
if err != nil {
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", service.ErrBadRequest))
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse treasure chest id: %w", core.ErrBadRequest))
|
||||
return
|
||||
}
|
||||
treasureChestId = &i
|
||||
@@ -186,14 +186,14 @@ func (h TransactionImpl) handleUpdateTransaction() http.HandlerFunc {
|
||||
|
||||
valueF, err := strconv.ParseFloat(r.FormValue("value"), 64)
|
||||
if err != nil {
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse value: %w", service.ErrBadRequest))
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse value: %w", core.ErrBadRequest))
|
||||
return
|
||||
}
|
||||
value := int64(math.Round(valueF * service.DECIMALS_MULTIPLIER))
|
||||
|
||||
timestamp, err := time.Parse("2006-01-02", r.FormValue("timestamp"))
|
||||
if err != nil {
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", service.ErrBadRequest))
|
||||
core.HandleError(w, r, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"spend-sparrow/internal/auth_types"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/service"
|
||||
t "spend-sparrow/internal/template/transaction_recurring"
|
||||
@@ -111,7 +112,7 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
|
||||
}
|
||||
}
|
||||
|
||||
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) {
|
||||
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *auth_types.User, id, accountId, treasureChestId string) {
|
||||
var transactionsRecurring []*types.TransactionRecurring
|
||||
var err error
|
||||
if accountId == "" && treasureChestId == "" {
|
||||
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"spend-sparrow/internal/auth_types"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/db"
|
||||
"spend-sparrow/internal/types"
|
||||
"time"
|
||||
@@ -22,10 +24,10 @@ func NewDashboard(db *sqlx.DB) *Dashboard {
|
||||
|
||||
func (s Dashboard) MainChart(
|
||||
ctx context.Context,
|
||||
user *types.User,
|
||||
user *auth_types.User,
|
||||
) ([]types.DashboardMainChartEntry, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
transactions := make([]types.Transaction, 0)
|
||||
@@ -82,10 +84,10 @@ func (s Dashboard) MainChart(
|
||||
|
||||
func (s Dashboard) TreasureChests(
|
||||
ctx context.Context,
|
||||
user *types.User,
|
||||
user *auth_types.User,
|
||||
) ([]*types.DashboardTreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
treasureChests := make([]*types.TreasureChest, 0)
|
||||
@@ -120,11 +122,11 @@ func (s Dashboard) TreasureChests(
|
||||
|
||||
func (s Dashboard) TreasureChest(
|
||||
ctx context.Context,
|
||||
user *types.User,
|
||||
user *auth_types.User,
|
||||
treausureChestId *uuid.UUID,
|
||||
) ([]types.DashboardMainChartEntry, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
transactions := make([]types.Transaction, 0)
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"spend-sparrow/internal/core"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -16,9 +17,9 @@ var (
|
||||
func ValidateString(value string, fieldName string) error {
|
||||
switch {
|
||||
case value == "":
|
||||
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest)
|
||||
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, core.ErrBadRequest)
|
||||
case !safeInputRegex.MatchString(value):
|
||||
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest)
|
||||
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, core.ErrBadRequest)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"spend-sparrow/internal/authentication"
|
||||
"spend-sparrow/internal/auth_types"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/db"
|
||||
"spend-sparrow/internal/types"
|
||||
"strconv"
|
||||
@@ -18,22 +19,22 @@ import (
|
||||
const page_size = 25
|
||||
|
||||
type Transaction interface {
|
||||
Add(ctx context.Context, tx *sqlx.Tx, user *authentication.User, transaction types.Transaction) (*types.Transaction, error)
|
||||
Update(ctx context.Context, user *authentication.User, transaction types.Transaction) (*types.Transaction, error)
|
||||
Get(ctx context.Context, user *authentication.User, id string) (*types.Transaction, error)
|
||||
GetAll(ctx context.Context, user *authentication.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
|
||||
Delete(ctx context.Context, user *authentication.User, id string) error
|
||||
Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||
Update(ctx context.Context, user *auth_types.User, transaction types.Transaction) (*types.Transaction, error)
|
||||
Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error)
|
||||
GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error)
|
||||
Delete(ctx context.Context, user *auth_types.User, id string) error
|
||||
|
||||
RecalculateBalances(ctx context.Context, user *authentication.User) error
|
||||
RecalculateBalances(ctx context.Context, user *auth_types.User) error
|
||||
}
|
||||
|
||||
type TransactionImpl struct {
|
||||
db *sqlx.DB
|
||||
clock Clock
|
||||
random Random
|
||||
clock core.Clock
|
||||
random core.Random
|
||||
}
|
||||
|
||||
func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
|
||||
func NewTransaction(db *sqlx.DB, random core.Random, clock core.Clock) Transaction {
|
||||
return TransactionImpl{
|
||||
db: db,
|
||||
clock: clock,
|
||||
@@ -41,9 +42,9 @@ func NewTransaction(db *sqlx.DB, random Random, clock Clock) Transaction {
|
||||
}
|
||||
}
|
||||
|
||||
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User, transactionInput types.Transaction) (*types.Transaction, error) {
|
||||
func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *auth_types.User, transactionInput types.Transaction) (*types.Transaction, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
var err error
|
||||
@@ -108,9 +109,9 @@ func (s TransactionImpl) Add(ctx context.Context, tx *sqlx.Tx, user *types.User,
|
||||
return transaction, nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) Update(ctx context.Context, user *types.User, input types.Transaction) (*types.Transaction, error) {
|
||||
func (s TransactionImpl) Update(ctx context.Context, user *auth_types.User, input types.Transaction) (*types.Transaction, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -126,10 +127,10 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ
|
||||
err = tx.GetContext(ctx, transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, input.Id)
|
||||
err = db.TransformAndLogDbError(ctx, "transaction Update", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest)
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, core.ErrBadRequest)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
if transaction.Error == nil && transaction.AccountId != nil {
|
||||
@@ -207,32 +208,32 @@ func (s TransactionImpl) Update(ctx context.Context, user *types.User, input typ
|
||||
return transaction, nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) Get(ctx context.Context, user *types.User, id string) (*types.Transaction, error) {
|
||||
func (s TransactionImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.Transaction, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
uuid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transaction get", "err", err)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
var transaction types.Transaction
|
||||
err = s.db.GetContext(ctx, &transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError(ctx, "transaction Get", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest)
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("transaction %v not found: %w", id, core.ErrBadRequest)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
return &transaction, nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
|
||||
func (s TransactionImpl) GetAll(ctx context.Context, user *auth_types.User, filter types.TransactionItemsFilter) ([]*types.Transaction, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -278,14 +279,14 @@ func (s TransactionImpl) GetAll(ctx context.Context, user *types.User, filter ty
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||
func (s TransactionImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
return core.ErrUnauthorized
|
||||
}
|
||||
uuid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transaction delete", "err", err)
|
||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -311,7 +312,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
|
||||
WHERE id = ?
|
||||
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
|
||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -323,7 +324,7 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
|
||||
WHERE id = ?
|
||||
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||
err = db.TransformAndLogDbError(ctx, "transaction Delete", r, err)
|
||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -343,9 +344,9 @@ func (s TransactionImpl) Delete(ctx context.Context, user *types.User, id string
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.User) error {
|
||||
func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *auth_types.User) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
return core.ErrUnauthorized
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -362,7 +363,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
|
||||
SET current_balance = 0
|
||||
WHERE user_id = ?`, user.Id)
|
||||
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
|
||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -371,7 +372,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
|
||||
SET current_balance = 0
|
||||
WHERE user_id = ?`, user.Id)
|
||||
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", r, err)
|
||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -380,7 +381,7 @@ func (s TransactionImpl) RecalculateBalances(ctx context.Context, user *types.Us
|
||||
FROM "transaction"
|
||||
WHERE user_id = ?`, user.Id)
|
||||
err = db.TransformAndLogDbError(ctx, "transaction RecalculateBalances", nil, err)
|
||||
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||
if err != nil && !errors.Is(err, core.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
@@ -459,7 +460,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
|
||||
if oldTransaction == nil {
|
||||
id, err = s.random.UUID(ctx)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
createdAt = s.clock.Now()
|
||||
createdBy = userId
|
||||
@@ -480,7 +481,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
|
||||
}
|
||||
if rowCount == 0 {
|
||||
slog.ErrorContext(ctx, "transaction validate", "err", err)
|
||||
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,13 +490,13 @@ func (s TransactionImpl) validateAndEnrichTransaction(ctx context.Context, tx *s
|
||||
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, input.TreasureChestId, userId)
|
||||
err = db.TransformAndLogDbError(ctx, "transaction validate", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if treasureChest.ParentId == nil {
|
||||
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"spend-sparrow/internal/auth_types"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/db"
|
||||
"spend-sparrow/internal/types"
|
||||
"strconv"
|
||||
@@ -16,24 +18,24 @@ import (
|
||||
)
|
||||
|
||||
type TransactionRecurring interface {
|
||||
Add(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||
Update(ctx context.Context, user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||
GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error)
|
||||
GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error)
|
||||
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||
Delete(ctx context.Context, user *types.User, id string) error
|
||||
Add(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||
Update(ctx context.Context, user *auth_types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||
GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error)
|
||||
GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error)
|
||||
GetAllByTreasureChest(ctx context.Context, user *auth_types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||
Delete(ctx context.Context, user *auth_types.User, id string) error
|
||||
|
||||
GenerateTransactions(ctx context.Context) error
|
||||
}
|
||||
|
||||
type TransactionRecurringImpl struct {
|
||||
db *sqlx.DB
|
||||
clock Clock
|
||||
random Random
|
||||
clock core.Clock
|
||||
random core.Random
|
||||
transaction Transaction
|
||||
}
|
||||
|
||||
func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transaction Transaction) TransactionRecurring {
|
||||
func NewTransactionRecurring(db *sqlx.DB, random core.Random, clock core.Clock, transaction Transaction) TransactionRecurring {
|
||||
return TransactionRecurringImpl{
|
||||
db: db,
|
||||
clock: clock,
|
||||
@@ -43,11 +45,11 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, transactio
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) Add(ctx context.Context,
|
||||
user *types.User,
|
||||
user *auth_types.User,
|
||||
transactionRecurringInput types.TransactionRecurringInput,
|
||||
) (*types.TransactionRecurring, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -85,16 +87,16 @@ func (s TransactionRecurringImpl) Add(ctx context.Context,
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) Update(ctx context.Context,
|
||||
user *types.User,
|
||||
user *auth_types.User,
|
||||
input types.TransactionRecurringInput,
|
||||
) (*types.TransactionRecurring, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
uuid, err := uuid.Parse(input.Id)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transactionRecurring update", "err", err)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -110,10 +112,10 @@ func (s TransactionRecurringImpl) Update(ctx context.Context,
|
||||
err = tx.GetContext(ctx, transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError(ctx, "transactionRecurring Update", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest)
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, core.ErrBadRequest)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
transactionRecurring, err = s.validateAndEnrichTransactionRecurring(ctx, tx, transactionRecurring, user.Id, input)
|
||||
@@ -149,9 +151,9 @@ func (s TransactionRecurringImpl) Update(ctx context.Context,
|
||||
return transactionRecurring, nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TransactionRecurring, error) {
|
||||
func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TransactionRecurring, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||
@@ -169,15 +171,15 @@ func (s TransactionRecurringImpl) GetAll(ctx context.Context, user *types.User)
|
||||
return transactionRecurrings, nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
||||
func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *auth_types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
accountUuid, err := uuid.Parse(accountId)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transactionRecurring GetAllByAccount", "err", err)
|
||||
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -193,10 +195,10 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ
|
||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
||||
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByAccount", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest)
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("account %v not found: %w", accountId, core.ErrBadRequest)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||
@@ -222,17 +224,17 @@ func (s TransactionRecurringImpl) GetAllByAccount(ctx context.Context, user *typ
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
||||
user *types.User,
|
||||
user *auth_types.User,
|
||||
treasureChestId string,
|
||||
) ([]*types.TransactionRecurring, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
treasureChestUuid, err := uuid.Parse(treasureChestId)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transactionRecurring GetAllByTreasureChest", "err", err)
|
||||
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -248,10 +250,10 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
||||
err = db.TransformAndLogDbError(ctx, "transactionRecurring GetAllByTreasureChest", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest)
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, core.ErrBadRequest)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
transactionRecurrings := make([]*types.TransactionRecurring, 0)
|
||||
@@ -276,14 +278,14 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(ctx context.Context,
|
||||
return transactionRecurrings, nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User, id string) error {
|
||||
func (s TransactionRecurringImpl) Delete(ctx context.Context, user *auth_types.User, id string) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
return core.ErrUnauthorized
|
||||
}
|
||||
uuid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transactionRecurring delete", "err", err)
|
||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -339,7 +341,7 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) erro
|
||||
}
|
||||
|
||||
for _, transactionRecurring := range recurringTransactions {
|
||||
user := &types.User{
|
||||
user := &auth_types.User{
|
||||
Id: transactionRecurring.UserId,
|
||||
}
|
||||
transaction := types.Transaction{
|
||||
@@ -397,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||
if oldTransactionRecurring == nil {
|
||||
id, err = s.random.UUID(ctx)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
createdAt = s.clock.Now()
|
||||
createdBy = userId
|
||||
@@ -416,7 +418,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||
temp, err := uuid.Parse(input.AccountId)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||
return nil, fmt.Errorf("could not parse accountId: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse accountId: %w", core.ErrBadRequest)
|
||||
}
|
||||
accountUuid = &temp
|
||||
err = tx.GetContext(ctx, &rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, userId)
|
||||
@@ -426,7 +428,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||
}
|
||||
if rowCount == 0 {
|
||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("account not found: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
hasAccount = true
|
||||
@@ -436,37 +438,37 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||
temp, err := uuid.Parse(input.TreasureChestId)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||
return nil, fmt.Errorf("could not parse treasureChestId: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse treasureChestId: %w", core.ErrBadRequest)
|
||||
}
|
||||
treasureChestUuid = &temp
|
||||
var treasureChest types.TreasureChest
|
||||
err = tx.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
||||
err = db.TransformAndLogDbError(ctx, "transactionRecurring validate", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("treasure chest not found: %w", core.ErrBadRequest)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if treasureChest.ParentId == nil {
|
||||
return nil, fmt.Errorf("treasure chest is a group: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("treasure chest is a group: %w", core.ErrBadRequest)
|
||||
}
|
||||
hasTreasureChest = true
|
||||
}
|
||||
|
||||
if !hasAccount && !hasTreasureChest {
|
||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||
return nil, fmt.Errorf("either account or treasure chest is required: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("either account or treasure chest is required: %w", core.ErrBadRequest)
|
||||
}
|
||||
if hasAccount && hasTreasureChest {
|
||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||
return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("either account or treasure chest is required, not both: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
valueFloat, err := strconv.ParseFloat(input.Value, 64)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse value: %w", core.ErrBadRequest)
|
||||
}
|
||||
value := int64(math.Round(valueFloat * DECIMALS_MULTIPLIER))
|
||||
|
||||
@@ -485,18 +487,18 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
||||
intervalMonths, err = strconv.ParseInt(input.IntervalMonths, 10, 0)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||
return nil, fmt.Errorf("could not parse intervalMonths: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse intervalMonths: %w", core.ErrBadRequest)
|
||||
}
|
||||
if intervalMonths < 1 {
|
||||
slog.ErrorContext(ctx, "transactionRecurring validate", "err", err)
|
||||
return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("intervalMonths needs to be greater than 0: %w", core.ErrBadRequest)
|
||||
}
|
||||
var nextExecution *time.Time = nil
|
||||
if input.NextExecution != "" {
|
||||
t, err := time.Parse("2006-01-02", input.NextExecution)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "transaction validate", "err", err)
|
||||
return nil, fmt.Errorf("could not parse timestamp: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse timestamp: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
t = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"spend-sparrow/internal/auth_types"
|
||||
"spend-sparrow/internal/core"
|
||||
"spend-sparrow/internal/db"
|
||||
"spend-sparrow/internal/types"
|
||||
|
||||
@@ -14,20 +16,20 @@ import (
|
||||
)
|
||||
|
||||
type TreasureChest interface {
|
||||
Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error)
|
||||
Update(ctx context.Context, user *types.User, id, parentId, name string) (*types.TreasureChest, error)
|
||||
Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error)
|
||||
GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error)
|
||||
Delete(ctx context.Context, user *types.User, id string) error
|
||||
Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error)
|
||||
Update(ctx context.Context, user *auth_types.User, id, parentId, name string) (*types.TreasureChest, error)
|
||||
Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error)
|
||||
GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error)
|
||||
Delete(ctx context.Context, user *auth_types.User, id string) error
|
||||
}
|
||||
|
||||
type TreasureChestImpl struct {
|
||||
db *sqlx.DB
|
||||
clock Clock
|
||||
random Random
|
||||
clock core.Clock
|
||||
random core.Random
|
||||
}
|
||||
|
||||
func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
|
||||
func NewTreasureChest(db *sqlx.DB, random core.Random, clock core.Clock) TreasureChest {
|
||||
return TreasureChestImpl{
|
||||
db: db,
|
||||
clock: clock,
|
||||
@@ -35,14 +37,14 @@ func NewTreasureChest(db *sqlx.DB, random Random, clock Clock) TreasureChest {
|
||||
}
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId, name string) (*types.TreasureChest, error) {
|
||||
func (s TreasureChestImpl) Add(ctx context.Context, user *auth_types.User, parentId, name string) (*types.TreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
newId, err := s.random.UUID(ctx)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
err = ValidateString(name, "name")
|
||||
@@ -57,7 +59,7 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId,
|
||||
return nil, err
|
||||
}
|
||||
if parent.ParentId != nil {
|
||||
return nil, fmt.Errorf("only a depth of 1 allowed: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("only a depth of 1 allowed: %w", core.ErrBadRequest)
|
||||
}
|
||||
parentUuid = &parent.Id
|
||||
}
|
||||
@@ -88,9 +90,9 @@ func (s TreasureChestImpl) Add(ctx context.Context, user *types.User, parentId,
|
||||
return treasureChest, nil
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
|
||||
func (s TreasureChestImpl) Update(ctx context.Context, user *auth_types.User, idStr, parentId, name string) (*types.TreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
err := ValidateString(name, "name")
|
||||
if err != nil {
|
||||
@@ -99,7 +101,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "treasureChest update", "err", err)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -115,10 +117,10 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
|
||||
err = tx.GetContext(ctx, treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
||||
err = db.TransformAndLogDbError(ctx, "treasureChest Update", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
var parentUuid *uuid.UUID
|
||||
@@ -134,7 +136,7 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
|
||||
return nil, err
|
||||
}
|
||||
if parent.ParentId != nil || childCount > 0 {
|
||||
return nil, fmt.Errorf("only one level allowed: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("only one level allowed: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
parentUuid = &parent.Id
|
||||
@@ -170,32 +172,32 @@ func (s TreasureChestImpl) Update(ctx context.Context, user *types.User, idStr,
|
||||
return treasureChest, nil
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) Get(ctx context.Context, user *types.User, id string) (*types.TreasureChest, error) {
|
||||
func (s TreasureChestImpl) Get(ctx context.Context, user *auth_types.User, id string) (*types.TreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
uuid, err := uuid.Parse(id)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "treasureChest get", "err", err)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
return nil, fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
var treasureChest types.TreasureChest
|
||||
err = s.db.GetContext(ctx, &treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||
err = db.TransformAndLogDbError(ctx, "treasureChest Get", nil, err)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
if errors.Is(err, core.ErrNotFound) {
|
||||
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
|
||||
}
|
||||
return nil, types.ErrInternal
|
||||
return nil, core.ErrInternal
|
||||
}
|
||||
|
||||
return &treasureChest, nil
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*types.TreasureChest, error) {
|
||||
func (s TreasureChestImpl) GetAll(ctx context.Context, user *auth_types.User) ([]*types.TreasureChest, error) {
|
||||
if user == nil {
|
||||
return nil, ErrUnauthorized
|
||||
return nil, core.ErrUnauthorized
|
||||
}
|
||||
|
||||
treasureChests := make([]*types.TreasureChest, 0)
|
||||
@@ -208,14 +210,14 @@ func (s TreasureChestImpl) GetAll(ctx context.Context, user *types.User) ([]*typ
|
||||
return sortTreasureChests(treasureChests), nil
|
||||
}
|
||||
|
||||
func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr string) error {
|
||||
func (s TreasureChestImpl) Delete(ctx context.Context, user *auth_types.User, idStr string) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
return core.ErrUnauthorized
|
||||
}
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "treasureChest delete", "err", err)
|
||||
return fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||
return fmt.Errorf("could not parse Id: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -235,7 +237,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
|
||||
}
|
||||
|
||||
if childCount > 0 {
|
||||
return fmt.Errorf("treasure chest has children: %w", ErrBadRequest)
|
||||
return fmt.Errorf("treasure chest has children: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
transactionsCount := 0
|
||||
@@ -247,7 +249,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
|
||||
return err
|
||||
}
|
||||
if transactionsCount > 0 {
|
||||
return fmt.Errorf("treasure chest has transactions: %w", ErrBadRequest)
|
||||
return fmt.Errorf("treasure chest has transactions: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
recurringCount := 0
|
||||
@@ -259,7 +261,7 @@ func (s TreasureChestImpl) Delete(ctx context.Context, user *types.User, idStr s
|
||||
return err
|
||||
}
|
||||
if recurringCount > 0 {
|
||||
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest)
|
||||
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", core.ErrBadRequest)
|
||||
}
|
||||
|
||||
r, err := tx.ExecContext(ctx, `DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
||||
|
||||
Reference in New Issue
Block a user