feat(auth): #182 cleanup expired tokens
This commit was merged in pull request #184.
This commit is contained in:
@@ -30,7 +30,8 @@ type Auth interface {
|
||||
GetSession(ctx context.Context, sessionId string) (*types.Session, error)
|
||||
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
|
||||
DeleteSession(ctx context.Context, sessionId string) error
|
||||
DeleteOldSessions(ctx context.Context, userId uuid.UUID) error
|
||||
DeleteOldSessions(ctx context.Context) error
|
||||
DeleteOldTokens(ctx context.Context) error
|
||||
}
|
||||
|
||||
type AuthSqlite struct {
|
||||
@@ -369,18 +370,6 @@ func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*type
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error {
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
DELETE FROM session
|
||||
WHERE expires_at < datetime('now')
|
||||
AND user_id = ?`, userId)
|
||||
if err != nil {
|
||||
slog.Error("Could not delete old sessions", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
|
||||
if sessionId != "" {
|
||||
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
|
||||
@@ -392,3 +381,25 @@ func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) DeleteOldSessions(ctx context.Context) error {
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
DELETE FROM session
|
||||
WHERE expires_at < datetime('now')`)
|
||||
if err != nil {
|
||||
slog.Error("Could not delete old sessions", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db AuthSqlite) DeleteOldTokens(ctx context.Context) error {
|
||||
_, err := db.db.ExecContext(ctx, `
|
||||
DELETE FROM token
|
||||
WHERE expires_at < datetime('now')`)
|
||||
if err != nil {
|
||||
slog.Error("Could not delete old tokens", "err", err)
|
||||
return types.ErrInternal
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
|
||||
// init server
|
||||
httpServer := &http.Server{
|
||||
Addr: ":" + serverSettings.Port,
|
||||
Handler: createHandler(database, serverSettings),
|
||||
Handler: createHandlerWithServices(ctx, database, serverSettings),
|
||||
ReadHeaderTimeout: 2 * time.Second,
|
||||
}
|
||||
go startServer(httpServer)
|
||||
@@ -73,6 +73,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go shutdownServer(httpServer, ctx, &wg)
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
@@ -102,7 +103,7 @@ func shutdownServer(s *http.Server, ctx context.Context, wg *sync.WaitGroup) {
|
||||
}
|
||||
}
|
||||
|
||||
func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||
func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||
var router = http.NewServeMux()
|
||||
|
||||
authDb := db.NewAuthSqlite(d)
|
||||
@@ -126,6 +127,8 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
|
||||
transactionRecurringHandler := handler.NewTransactionRecurring(transactionRecurringService, render)
|
||||
|
||||
go dailyTaskTimer(ctx, transactionRecurringService, authService)
|
||||
|
||||
indexHandler.Handle(router)
|
||||
accountHandler.Handle(router)
|
||||
treasureChestHandler.Handle(router)
|
||||
@@ -138,7 +141,6 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||
|
||||
wrapper := middleware.Wrapper(
|
||||
router,
|
||||
middleware.GenerateRecurringTransactions(transactionRecurringService),
|
||||
middleware.SecurityHeaders(serverSettings),
|
||||
middleware.CacheControl,
|
||||
middleware.CrossSiteRequestForgery(authService),
|
||||
@@ -151,3 +153,24 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||
|
||||
return wrapper
|
||||
}
|
||||
|
||||
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
|
||||
runDailyTasks(ctx, transactionRecurring, auth)
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
runDailyTasks(ctx, transactionRecurring, auth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
|
||||
slog.Info("Running daily tasks")
|
||||
_ = transactionRecurring.GenerateTransactions(ctx)
|
||||
_ = auth.CleanupSessionsAndTokens(ctx)
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"spend-sparrow/internal/service"
|
||||
)
|
||||
|
||||
func GenerateRecurringTransactions(transactionRecurring service.TransactionRecurring) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user := GetUser(r)
|
||||
if user == nil || r.Method != http.MethodGet {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
_ = transactionRecurring.GenerateTransactions(r.Context(), user)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -44,6 +44,8 @@ type Auth interface {
|
||||
|
||||
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
|
||||
GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
|
||||
|
||||
CleanupSessionsAndTokens(ctx context.Context) error
|
||||
}
|
||||
|
||||
type AuthImpl struct {
|
||||
@@ -80,17 +82,28 @@ func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, emai
|
||||
return nil, nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
err = service.cleanUpSessionWithTokens(ctx, session)
|
||||
newSession, err := service.createSession(ctx, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
|
||||
session, err = service.createSession(ctx, user.Id)
|
||||
err = service.db.DeleteSession(ctx, session.Id)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
|
||||
return session, user, nil
|
||||
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
for _, token := range tokens {
|
||||
err = service.db.DeleteToken(ctx, token.Token)
|
||||
if err != nil {
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
}
|
||||
|
||||
return newSession, user, nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
|
||||
@@ -445,26 +458,16 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session
|
||||
return tokenStr, nil
|
||||
}
|
||||
|
||||
func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error {
|
||||
if session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := service.db.DeleteSession(ctx, session.Id)
|
||||
func (service AuthImpl) CleanupSessionsAndTokens(ctx context.Context) error {
|
||||
err := service.db.DeleteOldSessions(ctx)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
|
||||
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||
err = service.db.DeleteOldTokens(ctx)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
for _, token := range tokens {
|
||||
err = service.db.DeleteToken(ctx, token.Token)
|
||||
if err != nil {
|
||||
return types.ErrInternal
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -475,11 +478,6 @@ func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*t
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
err = service.db.DeleteOldSessions(ctx, userId)
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
createAt := service.clock.Now()
|
||||
expiresAt := createAt.Add(24 * time.Hour)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ type TransactionRecurring interface {
|
||||
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||
Delete(ctx context.Context, user *types.User, id string) error
|
||||
|
||||
GenerateTransactions(ctx context.Context, user *types.User) error
|
||||
GenerateTransactions(ctx context.Context) error
|
||||
}
|
||||
|
||||
type TransactionRecurringImpl struct {
|
||||
@@ -317,10 +317,7 @@ func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error {
|
||||
if user == nil {
|
||||
return ErrUnauthorized
|
||||
}
|
||||
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) error {
|
||||
now := s.clock.Now()
|
||||
|
||||
tx, err := s.db.BeginTxx(ctx, nil)
|
||||
@@ -334,14 +331,17 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user
|
||||
|
||||
recurringTransactions := make([]*types.TransactionRecurring, 0)
|
||||
err = tx.SelectContext(ctx, &recurringTransactions, `
|
||||
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
|
||||
user.Id, now)
|
||||
SELECT * FROM transaction_recurring WHERE next_execution <= ?`,
|
||||
now)
|
||||
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, transactionRecurring := range recurringTransactions {
|
||||
user := &types.User{
|
||||
Id: transactionRecurring.UserId,
|
||||
}
|
||||
transaction := types.Transaction{
|
||||
Timestamp: *transactionRecurring.NextExecution,
|
||||
Party: transactionRecurring.Party,
|
||||
|
||||
Reference in New Issue
Block a user