feat(auth): #182 cleanup expired tokens
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 5m19s
Build and Push Docker Image / Build-And-Push-Docker-Image (push) Successful in 5m7s

This commit was merged in pull request #184.
This commit is contained in:
2025-06-16 22:42:23 +02:00
parent 3df9fab25b
commit 596cc602d0
5 changed files with 76 additions and 66 deletions

View File

@@ -30,7 +30,8 @@ type Auth interface {
GetSession(ctx context.Context, sessionId string) (*types.Session, error)
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
DeleteSession(ctx context.Context, sessionId string) error
DeleteOldSessions(ctx context.Context, userId uuid.UUID) error
DeleteOldSessions(ctx context.Context) error
DeleteOldTokens(ctx context.Context) error
}
type AuthSqlite struct {
@@ -369,18 +370,6 @@ func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*type
return sessions, nil
}
func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM session
WHERE expires_at < datetime('now')
AND user_id = ?`, userId)
if err != nil {
slog.Error("Could not delete old sessions", "err", err)
return types.ErrInternal
}
return nil
}
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
if sessionId != "" {
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
@@ -392,3 +381,25 @@ func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error
return nil
}
func (db AuthSqlite) DeleteOldSessions(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM session
WHERE expires_at < datetime('now')`)
if err != nil {
slog.Error("Could not delete old sessions", "err", err)
return types.ErrInternal
}
return nil
}
func (db AuthSqlite) DeleteOldTokens(ctx context.Context) error {
_, err := db.db.ExecContext(ctx, `
DELETE FROM token
WHERE expires_at < datetime('now')`)
if err != nil {
slog.Error("Could not delete old tokens", "err", err)
return types.ErrInternal
}
return nil
}

View File

@@ -64,7 +64,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
// init server
httpServer := &http.Server{
Addr: ":" + serverSettings.Port,
Handler: createHandler(database, serverSettings),
Handler: createHandlerWithServices(ctx, database, serverSettings),
ReadHeaderTimeout: 2 * time.Second,
}
go startServer(httpServer)
@@ -73,6 +73,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
var wg sync.WaitGroup
wg.Add(1)
go shutdownServer(httpServer, ctx, &wg)
wg.Wait()
return nil
@@ -102,7 +103,7 @@ func shutdownServer(s *http.Server, ctx context.Context, wg *sync.WaitGroup) {
}
}
func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
var router = http.NewServeMux()
authDb := db.NewAuthSqlite(d)
@@ -126,6 +127,8 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
transactionRecurringHandler := handler.NewTransactionRecurring(transactionRecurringService, render)
go dailyTaskTimer(ctx, transactionRecurringService, authService)
indexHandler.Handle(router)
accountHandler.Handle(router)
treasureChestHandler.Handle(router)
@@ -138,7 +141,6 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
wrapper := middleware.Wrapper(
router,
middleware.GenerateRecurringTransactions(transactionRecurringService),
middleware.SecurityHeaders(serverSettings),
middleware.CacheControl,
middleware.CrossSiteRequestForgery(authService),
@@ -151,3 +153,24 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
return wrapper
}
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
runDailyTasks(ctx, transactionRecurring, auth)
ticker := time.NewTicker(24 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
runDailyTasks(ctx, transactionRecurring, auth)
}
}
}
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
slog.Info("Running daily tasks")
_ = transactionRecurring.GenerateTransactions(ctx)
_ = auth.CleanupSessionsAndTokens(ctx)
}

View File

@@ -1,22 +0,0 @@
package middleware
import (
"net/http"
"spend-sparrow/internal/service"
)
func GenerateRecurringTransactions(transactionRecurring service.TransactionRecurring) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := GetUser(r)
if user == nil || r.Method != http.MethodGet {
next.ServeHTTP(w, r)
return
}
_ = transactionRecurring.GenerateTransactions(r.Context(), user)
next.ServeHTTP(w, r)
})
}
}

View File

@@ -44,6 +44,8 @@ type Auth interface {
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
CleanupSessionsAndTokens(ctx context.Context) error
}
type AuthImpl struct {
@@ -80,17 +82,28 @@ func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, emai
return nil, nil, ErrInvalidCredentials
}
err = service.cleanUpSessionWithTokens(ctx, session)
newSession, err := service.createSession(ctx, user.Id)
if err != nil {
return nil, nil, types.ErrInternal
}
session, err = service.createSession(ctx, user.Id)
err = service.db.DeleteSession(ctx, session.Id)
if err != nil {
return nil, nil, types.ErrInternal
}
return session, user, nil
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
if err != nil {
return nil, nil, types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(ctx, token.Token)
if err != nil {
return nil, nil, types.ErrInternal
}
}
return newSession, user, nil
}
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
@@ -445,26 +458,16 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session
return tokenStr, nil
}
func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error {
if session == nil {
return nil
}
err := service.db.DeleteSession(ctx, session.Id)
func (service AuthImpl) CleanupSessionsAndTokens(ctx context.Context) error {
err := service.db.DeleteOldSessions(ctx)
if err != nil {
return types.ErrInternal
}
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
err = service.db.DeleteOldTokens(ctx)
if err != nil {
return types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(ctx, token.Token)
if err != nil {
return types.ErrInternal
}
}
return nil
}
@@ -475,11 +478,6 @@ func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*t
return nil, types.ErrInternal
}
err = service.db.DeleteOldSessions(ctx, userId)
if err != nil {
return nil, types.ErrInternal
}
createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour)

View File

@@ -23,7 +23,7 @@ type TransactionRecurring interface {
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(ctx context.Context, user *types.User, id string) error
GenerateTransactions(ctx context.Context, user *types.User) error
GenerateTransactions(ctx context.Context) error
}
type TransactionRecurringImpl struct {
@@ -317,10 +317,7 @@ func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User,
return nil
}
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error {
if user == nil {
return ErrUnauthorized
}
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) error {
now := s.clock.Now()
tx, err := s.db.BeginTxx(ctx, nil)
@@ -334,14 +331,17 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user
recurringTransactions := make([]*types.TransactionRecurring, 0)
err = tx.SelectContext(ctx, &recurringTransactions, `
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
user.Id, now)
SELECT * FROM transaction_recurring WHERE next_execution <= ?`,
now)
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
if err != nil {
return err
}
for _, transactionRecurring := range recurringTransactions {
user := &types.User{
Id: transactionRecurring.UserId,
}
transaction := types.Transaction{
Timestamp: *transactionRecurring.NextExecution,
Party: transactionRecurring.Party,