feat(auth): #182 cleanup expired tokens
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 5m19s
Build and Push Docker Image / Build-And-Push-Docker-Image (push) Successful in 5m7s

This commit was merged in pull request #184.
This commit is contained in:
2025-06-16 22:42:23 +02:00
parent 3df9fab25b
commit 596cc602d0
5 changed files with 76 additions and 66 deletions

View File

@@ -30,7 +30,8 @@ type Auth interface {
GetSession(ctx context.Context, sessionId string) (*types.Session, error) GetSession(ctx context.Context, sessionId string) (*types.Session, error)
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error) GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
DeleteSession(ctx context.Context, sessionId string) error DeleteSession(ctx context.Context, sessionId string) error
DeleteOldSessions(ctx context.Context, userId uuid.UUID) error DeleteOldSessions(ctx context.Context) error
DeleteOldTokens(ctx context.Context) error
} }
type AuthSqlite struct { type AuthSqlite struct {
@@ -369,18 +370,6 @@ func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*type
return sessions, nil return sessions, nil
} }
func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM session
WHERE expires_at < datetime('now')
AND user_id = ?`, userId)
if err != nil {
slog.Error("Could not delete old sessions", "err", err)
return types.ErrInternal
}
return nil
}
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error { func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
if sessionId != "" { if sessionId != "" {
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId) _, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
@@ -392,3 +381,25 @@ func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error
return nil return nil
} }
func (db AuthSqlite) DeleteOldSessions(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM session
WHERE expires_at < datetime('now')`)
if err != nil {
slog.Error("Could not delete old sessions", "err", err)
return types.ErrInternal
}
return nil
}
func (db AuthSqlite) DeleteOldTokens(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM token
WHERE expires_at < datetime('now')`)
if err != nil {
slog.Error("Could not delete old tokens", "err", err)
return types.ErrInternal
}
return nil
}

View File

@@ -64,7 +64,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
// init server // init server
httpServer := &http.Server{ httpServer := &http.Server{
Addr: ":" + serverSettings.Port, Addr: ":" + serverSettings.Port,
Handler: createHandler(database, serverSettings), Handler: createHandlerWithServices(ctx, database, serverSettings),
ReadHeaderTimeout: 2 * time.Second, ReadHeaderTimeout: 2 * time.Second,
} }
go startServer(httpServer) go startServer(httpServer)
@@ -73,6 +73,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go shutdownServer(httpServer, ctx, &wg) go shutdownServer(httpServer, ctx, &wg)
wg.Wait() wg.Wait()
return nil return nil
@@ -102,7 +103,7 @@ func shutdownServer(s *http.Server, ctx context.Context, wg *sync.WaitGroup) {
} }
} }
func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler { func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
var router = http.NewServeMux() var router = http.NewServeMux()
authDb := db.NewAuthSqlite(d) authDb := db.NewAuthSqlite(d)
@@ -126,6 +127,8 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render) transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
transactionRecurringHandler := handler.NewTransactionRecurring(transactionRecurringService, render) transactionRecurringHandler := handler.NewTransactionRecurring(transactionRecurringService, render)
go dailyTaskTimer(ctx, transactionRecurringService, authService)
indexHandler.Handle(router) indexHandler.Handle(router)
accountHandler.Handle(router) accountHandler.Handle(router)
treasureChestHandler.Handle(router) treasureChestHandler.Handle(router)
@@ -138,7 +141,6 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
wrapper := middleware.Wrapper( wrapper := middleware.Wrapper(
router, router,
middleware.GenerateRecurringTransactions(transactionRecurringService),
middleware.SecurityHeaders(serverSettings), middleware.SecurityHeaders(serverSettings),
middleware.CacheControl, middleware.CacheControl,
middleware.CrossSiteRequestForgery(authService), middleware.CrossSiteRequestForgery(authService),
@@ -151,3 +153,24 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
return wrapper return wrapper
} }
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
runDailyTasks(ctx, transactionRecurring, auth)
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
runDailyTasks(ctx, transactionRecurring, auth)
}
}
}
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
slog.Info("Running daily tasks")
_ = transactionRecurring.GenerateTransactions(ctx)
_ = auth.CleanupSessionsAndTokens(ctx)
}

View File

@@ -1,22 +0,0 @@
package middleware
import (
"net/http"
"spend-sparrow/internal/service"
)
func GenerateRecurringTransactions(transactionRecurring service.TransactionRecurring) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := GetUser(r)
if user == nil || r.Method != http.MethodGet {
next.ServeHTTP(w, r)
return
}
_ = transactionRecurring.GenerateTransactions(r.Context(), user)
next.ServeHTTP(w, r)
})
}
}

View File

@@ -44,6 +44,8 @@ type Auth interface {
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
GetCsrfToken(ctx context.Context, session *types.Session) (string, error) GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
CleanupSessionsAndTokens(ctx context.Context) error
} }
type AuthImpl struct { type AuthImpl struct {
@@ -80,17 +82,28 @@ func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, emai
return nil, nil, ErrInvalidCredentials return nil, nil, ErrInvalidCredentials
} }
err = service.cleanUpSessionWithTokens(ctx, session) newSession, err := service.createSession(ctx, user.Id)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, types.ErrInternal
} }
session, err = service.createSession(ctx, user.Id) err = service.db.DeleteSession(ctx, session.Id)
if err != nil { if err != nil {
return nil, nil, types.ErrInternal return nil, nil, types.ErrInternal
} }
return session, user, nil tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
if err != nil {
return nil, nil, types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(ctx, token.Token)
if err != nil {
return nil, nil, types.ErrInternal
}
}
return newSession, user, nil
} }
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) { func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
@@ -445,26 +458,16 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session
return tokenStr, nil return tokenStr, nil
} }
func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error { func (service AuthImpl) CleanupSessionsAndTokens(ctx context.Context) error {
if session == nil { err := service.db.DeleteOldSessions(ctx)
return nil
}
err := service.db.DeleteSession(ctx, session.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf) err = service.db.DeleteOldTokens(ctx)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
for _, token := range tokens {
err = service.db.DeleteToken(ctx, token.Token)
if err != nil {
return types.ErrInternal
}
}
return nil return nil
} }
@@ -475,11 +478,6 @@ func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*t
return nil, types.ErrInternal return nil, types.ErrInternal
} }
err = service.db.DeleteOldSessions(ctx, userId)
if err != nil {
return nil, types.ErrInternal
}
createAt := service.clock.Now() createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)

View File

@@ -23,7 +23,7 @@ type TransactionRecurring interface {
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(ctx context.Context, user *types.User, id string) error Delete(ctx context.Context, user *types.User, id string) error
GenerateTransactions(ctx context.Context, user *types.User) error GenerateTransactions(ctx context.Context) error
} }
type TransactionRecurringImpl struct { type TransactionRecurringImpl struct {
@@ -317,10 +317,7 @@ func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User,
return nil return nil
} }
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error { func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) error {
if user == nil {
return ErrUnauthorized
}
now := s.clock.Now() now := s.clock.Now()
tx, err := s.db.BeginTxx(ctx, nil) tx, err := s.db.BeginTxx(ctx, nil)
@@ -334,14 +331,17 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user
recurringTransactions := make([]*types.TransactionRecurring, 0) recurringTransactions := make([]*types.TransactionRecurring, 0)
err = tx.SelectContext(ctx, &recurringTransactions, ` err = tx.SelectContext(ctx, &recurringTransactions, `
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`, SELECT * FROM transaction_recurring WHERE next_execution <= ?`,
user.Id, now) now)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err) err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
if err != nil { if err != nil {
return err return err
} }
for _, transactionRecurring := range recurringTransactions { for _, transactionRecurring := range recurringTransactions {
user := &types.User{
Id: transactionRecurring.UserId,
}
transaction := types.Transaction{ transaction := types.Transaction{
Timestamp: *transactionRecurring.NextExecution, Timestamp: *transactionRecurring.NextExecution,
Party: transactionRecurring.Party, Party: transactionRecurring.Party,