feat(auth): #182 cleanup expired tokens
This commit was merged in pull request #184.
This commit is contained in:
@@ -30,7 +30,8 @@ type Auth interface {
|
|||||||
GetSession(ctx context.Context, sessionId string) (*types.Session, error)
|
GetSession(ctx context.Context, sessionId string) (*types.Session, error)
|
||||||
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
|
GetSessions(ctx context.Context, userId uuid.UUID) ([]*types.Session, error)
|
||||||
DeleteSession(ctx context.Context, sessionId string) error
|
DeleteSession(ctx context.Context, sessionId string) error
|
||||||
DeleteOldSessions(ctx context.Context, userId uuid.UUID) error
|
DeleteOldSessions(ctx context.Context) error
|
||||||
|
DeleteOldTokens(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthSqlite struct {
|
type AuthSqlite struct {
|
||||||
@@ -369,18 +370,6 @@ func (db AuthSqlite) GetSessions(ctx context.Context, userId uuid.UUID) ([]*type
|
|||||||
return sessions, nil
|
return sessions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteOldSessions(ctx context.Context, userId uuid.UUID) error {
|
|
||||||
_, err := db.db.ExecContext(ctx, `
|
|
||||||
DELETE FROM session
|
|
||||||
WHERE expires_at < datetime('now')
|
|
||||||
AND user_id = ?`, userId)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("Could not delete old sessions", "err", err)
|
|
||||||
return types.ErrInternal
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
|
func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error {
|
||||||
if sessionId != "" {
|
if sessionId != "" {
|
||||||
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
|
_, err := db.db.ExecContext(ctx, "DELETE FROM session WHERE session_id = ?", sessionId)
|
||||||
@@ -392,3 +381,25 @@ func (db AuthSqlite) DeleteSession(ctx context.Context, sessionId string) error
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db AuthSqlite) DeleteOldSessions(ctx context.Context) error {
|
||||||
|
_, err := db.db.ExecContext(ctx, `
|
||||||
|
DELETE FROM session
|
||||||
|
WHERE expires_at < datetime('now')`)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Could not delete old sessions", "err", err)
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db AuthSqlite) DeleteOldTokens(ctx context.Context) error {
|
||||||
|
_, err := db.db.ExecContext(ctx, `
|
||||||
|
DELETE FROM token
|
||||||
|
WHERE expires_at < datetime('now')`)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("Could not delete old tokens", "err", err)
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
|
|||||||
// init server
|
// init server
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: ":" + serverSettings.Port,
|
Addr: ":" + serverSettings.Port,
|
||||||
Handler: createHandler(database, serverSettings),
|
Handler: createHandlerWithServices(ctx, database, serverSettings),
|
||||||
ReadHeaderTimeout: 2 * time.Second,
|
ReadHeaderTimeout: 2 * time.Second,
|
||||||
}
|
}
|
||||||
go startServer(httpServer)
|
go startServer(httpServer)
|
||||||
@@ -73,6 +73,7 @@ func Run(ctx context.Context, database *sqlx.DB, migrationsPrefix string, env fu
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go shutdownServer(httpServer, ctx, &wg)
|
go shutdownServer(httpServer, ctx, &wg)
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -102,7 +103,7 @@ func shutdownServer(s *http.Server, ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
func createHandlerWithServices(ctx context.Context, d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
||||||
var router = http.NewServeMux()
|
var router = http.NewServeMux()
|
||||||
|
|
||||||
authDb := db.NewAuthSqlite(d)
|
authDb := db.NewAuthSqlite(d)
|
||||||
@@ -126,6 +127,8 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
|||||||
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
|
transactionHandler := handler.NewTransaction(transactionService, accountService, treasureChestService, render)
|
||||||
transactionRecurringHandler := handler.NewTransactionRecurring(transactionRecurringService, render)
|
transactionRecurringHandler := handler.NewTransactionRecurring(transactionRecurringService, render)
|
||||||
|
|
||||||
|
go dailyTaskTimer(ctx, transactionRecurringService, authService)
|
||||||
|
|
||||||
indexHandler.Handle(router)
|
indexHandler.Handle(router)
|
||||||
accountHandler.Handle(router)
|
accountHandler.Handle(router)
|
||||||
treasureChestHandler.Handle(router)
|
treasureChestHandler.Handle(router)
|
||||||
@@ -138,7 +141,6 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
|||||||
|
|
||||||
wrapper := middleware.Wrapper(
|
wrapper := middleware.Wrapper(
|
||||||
router,
|
router,
|
||||||
middleware.GenerateRecurringTransactions(transactionRecurringService),
|
|
||||||
middleware.SecurityHeaders(serverSettings),
|
middleware.SecurityHeaders(serverSettings),
|
||||||
middleware.CacheControl,
|
middleware.CacheControl,
|
||||||
middleware.CrossSiteRequestForgery(authService),
|
middleware.CrossSiteRequestForgery(authService),
|
||||||
@@ -151,3 +153,24 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
|||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dailyTaskTimer(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
|
||||||
|
runDailyTasks(ctx, transactionRecurring, auth)
|
||||||
|
ticker := time.NewTicker(24 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
runDailyTasks(ctx, transactionRecurring, auth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDailyTasks(ctx context.Context, transactionRecurring service.TransactionRecurring, auth service.Auth) {
|
||||||
|
slog.Info("Running daily tasks")
|
||||||
|
_ = transactionRecurring.GenerateTransactions(ctx)
|
||||||
|
_ = auth.CleanupSessionsAndTokens(ctx)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"spend-sparrow/internal/service"
|
|
||||||
)
|
|
||||||
|
|
||||||
func GenerateRecurringTransactions(transactionRecurring service.TransactionRecurring) func(http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := GetUser(r)
|
|
||||||
if user == nil || r.Method != http.MethodGet {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = transactionRecurring.GenerateTransactions(r.Context(), user)
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -44,6 +44,8 @@ type Auth interface {
|
|||||||
|
|
||||||
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
|
IsCsrfTokenValid(ctx context.Context, tokenStr string, sessionId string) bool
|
||||||
GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
|
GetCsrfToken(ctx context.Context, session *types.Session) (string, error)
|
||||||
|
|
||||||
|
CleanupSessionsAndTokens(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthImpl struct {
|
type AuthImpl struct {
|
||||||
@@ -80,17 +82,28 @@ func (service AuthImpl) SignIn(ctx context.Context, session *types.Session, emai
|
|||||||
return nil, nil, ErrInvalidCredentials
|
return nil, nil, ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.cleanUpSessionWithTokens(ctx, session)
|
newSession, err := service.createSession(ctx, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err = service.createSession(ctx, user.Id)
|
err = service.db.DeleteSession(ctx, session.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return session, user, nil
|
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, types.ErrInternal
|
||||||
|
}
|
||||||
|
for _, token := range tokens {
|
||||||
|
err = service.db.DeleteToken(ctx, token.Token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, types.ErrInternal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return newSession, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
|
func (service AuthImpl) SignInSession(ctx context.Context, sessionId string) (*types.Session, *types.User, error) {
|
||||||
@@ -445,26 +458,16 @@ func (service AuthImpl) GetCsrfToken(ctx context.Context, session *types.Session
|
|||||||
return tokenStr, nil
|
return tokenStr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) cleanUpSessionWithTokens(ctx context.Context, session *types.Session) error {
|
func (service AuthImpl) CleanupSessionsAndTokens(ctx context.Context) error {
|
||||||
if session == nil {
|
err := service.db.DeleteOldSessions(ctx)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := service.db.DeleteSession(ctx, session.Id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens, err := service.db.GetTokensBySessionIdAndType(ctx, session.Id, types.TokenTypeCsrf)
|
err = service.db.DeleteOldTokens(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
for _, token := range tokens {
|
|
||||||
err = service.db.DeleteToken(ctx, token.Token)
|
|
||||||
if err != nil {
|
|
||||||
return types.ErrInternal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -475,11 +478,6 @@ func (service AuthImpl) createSession(ctx context.Context, userId uuid.UUID) (*t
|
|||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.DeleteOldSessions(ctx, userId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, types.ErrInternal
|
|
||||||
}
|
|
||||||
|
|
||||||
createAt := service.clock.Now()
|
createAt := service.clock.Now()
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ type TransactionRecurring interface {
|
|||||||
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
GetAllByTreasureChest(ctx context.Context, user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||||
Delete(ctx context.Context, user *types.User, id string) error
|
Delete(ctx context.Context, user *types.User, id string) error
|
||||||
|
|
||||||
GenerateTransactions(ctx context.Context, user *types.User) error
|
GenerateTransactions(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TransactionRecurringImpl struct {
|
type TransactionRecurringImpl struct {
|
||||||
@@ -317,10 +317,7 @@ func (s TransactionRecurringImpl) Delete(ctx context.Context, user *types.User,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user *types.User) error {
|
func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context) error {
|
||||||
if user == nil {
|
|
||||||
return ErrUnauthorized
|
|
||||||
}
|
|
||||||
now := s.clock.Now()
|
now := s.clock.Now()
|
||||||
|
|
||||||
tx, err := s.db.BeginTxx(ctx, nil)
|
tx, err := s.db.BeginTxx(ctx, nil)
|
||||||
@@ -334,14 +331,17 @@ func (s TransactionRecurringImpl) GenerateTransactions(ctx context.Context, user
|
|||||||
|
|
||||||
recurringTransactions := make([]*types.TransactionRecurring, 0)
|
recurringTransactions := make([]*types.TransactionRecurring, 0)
|
||||||
err = tx.SelectContext(ctx, &recurringTransactions, `
|
err = tx.SelectContext(ctx, &recurringTransactions, `
|
||||||
SELECT * FROM transaction_recurring WHERE user_id = ? AND next_execution <= ?`,
|
SELECT * FROM transaction_recurring WHERE next_execution <= ?`,
|
||||||
user.Id, now)
|
now)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GenerateTransactions", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, transactionRecurring := range recurringTransactions {
|
for _, transactionRecurring := range recurringTransactions {
|
||||||
|
user := &types.User{
|
||||||
|
Id: transactionRecurring.UserId,
|
||||||
|
}
|
||||||
transaction := types.Transaction{
|
transaction := types.Transaction{
|
||||||
Timestamp: *transactionRecurring.NextExecution,
|
Timestamp: *transactionRecurring.NextExecution,
|
||||||
Party: transactionRecurring.Party,
|
Party: transactionRecurring.Party,
|
||||||
|
|||||||
Reference in New Issue
Block a user