fix: lint errors
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 5m22s
Build and Push Docker Image / Build-And-Push-Docker-Image (push) Successful in 5m26s

This commit was merged in pull request #130.
This commit is contained in:
2025-05-25 16:36:30 +02:00
parent 2ba5ddd9f2
commit 128a2fc4d7
36 changed files with 1024 additions and 968 deletions

28
.golangci.yaml Normal file
View File

@@ -0,0 +1,28 @@
version: '2'
linters:
default: all
disable:
- wsl
- wrapcheck
- varnamelen
- revive # should probably be enabled
- nlreturn
- mnd # should probably be enabled
- lll # should probably be enabled
- ireturn # should probably be enabled
- interfacebloat
- iface
- goconst # should probably be enabled
- gocognit # should probably be enabled
- gochecknoglobals # should probably be enabled
- funlen
- maintidx
- exhaustruct
- dupword # should probably be enabled
- dupl # should probably be enabled
- depguard
- cyclop
- contextcheck
settings:
nestif:
min-complexity: 6

View File

@@ -1,6 +1,7 @@
package db package db
import ( import (
"errors"
"spend-sparrow/log" "spend-sparrow/log"
"spend-sparrow/types" "spend-sparrow/types"
@@ -89,7 +90,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
FROM user FROM user
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, ErrNotFound
} else { } else {
log.Error("SQL error GetUser: %v", err) log.Error("SQL error GetUser: %v", err)
@@ -116,7 +117,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
FROM user FROM user
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound return nil, ErrNotFound
} else { } else {
log.Error("SQL error GetUser %v", err) log.Error("SQL error GetUser %v", err)
@@ -128,7 +129,6 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
} }
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
tx, err := db.db.Begin() tx, err := db.db.Begin()
if err != nil { if err != nil {
log.Error("Could not start transaction: %v", err) log.Error("Could not start transaction: %v", err)
@@ -216,7 +216,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr) WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
log.Info("Token '%v' not found", token) log.Info("Token '%v' not found", token)
return nil, ErrNotFound return nil, ErrNotFound
} else { } else {
@@ -241,7 +241,6 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
} }
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) { func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(` query, err := db.db.Query(`
SELECT token, created_at, expires_at SELECT token, created_at, expires_at
FROM token FROM token
@@ -257,7 +256,6 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.
} }
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) { func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(` query, err := db.db.Query(`
SELECT token, created_at, expires_at SELECT token, created_at, expires_at
FROM token FROM token
@@ -325,7 +323,6 @@ func (db AuthSqlite) DeleteToken(token string) error {
} }
func (db AuthSqlite) InsertSession(session *types.Session) error { func (db AuthSqlite) InsertSession(session *types.Session) error {
_, err := db.db.Exec(` _, err := db.db.Exec(`
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt) VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
@@ -339,7 +336,6 @@ func (db AuthSqlite) InsertSession(session *types.Session) error {
} }
func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
createdAt time.Time createdAt time.Time
@@ -360,9 +356,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
} }
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) { func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
var sessions []*types.Session
sessions, err := db.db.Query(` err := db.db.Select(&sessions, `
SELECT session_id, created_at, expires_at SELECT *
FROM session FROM session
WHERE user_id = ?`, userId) WHERE user_id = ?`, userId)
if err != nil { if err != nil {
@@ -370,26 +366,7 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
return nil, types.ErrInternal return nil, types.ErrInternal
} }
var result []*types.Session return sessions, nil
for sessions.Next() {
var (
sessionId string
createdAt time.Time
expiresAt time.Time
)
err := sessions.Scan(&sessionId, &createdAt, &expiresAt)
if err != nil {
log.Error("Could not scan session: %v", err)
return nil, types.ErrInternal
}
session := types.NewSession(sessionId, userId, createdAt, expiresAt)
result = append(result, session)
}
return result, nil
} }
func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error { func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {

View File

@@ -1,6 +1,7 @@
package db package db_test
import ( import (
"spend-sparrow/db"
"spend-sparrow/types" "spend-sparrow/types"
"testing" "testing"
"time" "time"
@@ -8,26 +9,29 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func setupDb(t *testing.T) *sqlx.DB { func setupDb(t *testing.T) *sqlx.DB {
db, err := sqlx.Open("sqlite3", ":memory:") t.Helper()
d, err := sqlx.Open("sqlite3", ":memory:")
if err != nil { if err != nil {
t.Fatalf("Error opening database: %v", err) t.Fatalf("Error opening database: %v", err)
} }
t.Cleanup(func() { t.Cleanup(func() {
err := db.Close() err := d.Close()
if err != nil { if err != nil {
panic(err) panic(err)
} }
}) })
err = RunMigrations(db, "../") err = db.RunMigrations(d, "../")
if err != nil { if err != nil {
t.Fatalf("Error running migrations: %v", err) t.Fatalf("Error running migrations: %v", err)
} }
return db return d
} }
func TestUser(t *testing.T) { func TestUser(t *testing.T) {
@@ -35,55 +39,55 @@ func TestUser(t *testing.T) {
t.Run("should insert and get the same", func(t *testing.T) { t.Run("should insert and get the same", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(expected) err := underTest.InsertUser(expected)
assert.Nil(t, err) require.NoError(t, err)
actual, err := underTest.GetUser(expected.Id) actual, err := underTest.GetUser(expected.Id)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
actual, err = underTest.GetUserByEmail(expected.Email) actual, err = underTest.GetUserByEmail(expected.Email)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
}) })
t.Run("should return ErrNotFound", func(t *testing.T) { t.Run("should return ErrNotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
_, err := underTest.GetUserByEmail("nonExistentEmail") _, err := underTest.GetUserByEmail("nonExistentEmail")
assert.Equal(t, ErrNotFound, err) assert.Equal(t, db.ErrNotFound, err)
}) })
t.Run("should return ErrUserExist", func(t *testing.T) { t.Run("should return ErrUserExist", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(user)
assert.Nil(t, err) require.NoError(t, err)
err = underTest.InsertUser(user) err = underTest.InsertUser(user)
assert.Equal(t, ErrAlreadyExists, err) assert.Equal(t, db.ErrAlreadyExists, err)
}) })
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
@@ -98,37 +102,37 @@ func TestToken(t *testing.T) {
t.Run("should insert and get the same", func(t *testing.T) { t.Run("should insert and get the same", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt) expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(expected) err := underTest.InsertToken(expected)
assert.Nil(t, err) require.NoError(t, err)
actual, err := underTest.GetToken(expected.Token) actual, err := underTest.GetToken(expected.Token)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
expected.SessionId = "" expected.SessionId = ""
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type) actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals) assert.Equal(t, []*types.Token{expected}, actuals)
expected.SessionId = "sessionId" expected.SessionId = "sessionId"
expected.UserId = uuid.Nil expected.UserId = uuid.Nil
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type) actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals) assert.Equal(t, []*types.Token{expected}, actuals)
}) })
t.Run("should insert and return multiple tokens", func(t *testing.T) { t.Run("should insert and return multiple tokens", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
@@ -137,14 +141,14 @@ func TestToken(t *testing.T) {
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt) expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(expected1) err := underTest.InsertToken(expected1)
assert.Nil(t, err) require.NoError(t, err)
err = underTest.InsertToken(expected2) err = underTest.InsertToken(expected2)
assert.Nil(t, err) require.NoError(t, err)
expected1.UserId = uuid.Nil expected1.UserId = uuid.Nil
expected2.UserId = uuid.Nil expected2.UserId = uuid.Nil
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type) actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals) assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
expected1.SessionId = "" expected1.SessionId = ""
@@ -152,46 +156,45 @@ func TestToken(t *testing.T) {
expected1.UserId = userId expected1.UserId = userId
expected2.UserId = userId expected2.UserId = userId
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type) actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals) assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
}) })
t.Run("should return ErrNotFound", func(t *testing.T) { t.Run("should return ErrNotFound", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
_, err := underTest.GetToken("nonExistent") _, err := underTest.GetToken("nonExistent")
assert.Equal(t, ErrNotFound, err) assert.Equal(t, db.ErrNotFound, err)
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify) _, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify)
assert.Equal(t, ErrNotFound, err) assert.Equal(t, db.ErrNotFound, err)
_, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify) _, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify)
assert.Equal(t, ErrNotFound, err) assert.Equal(t, db.ErrNotFound, err)
}) })
t.Run("should return ErrAlreadyExists", func(t *testing.T) { t.Run("should return ErrAlreadyExists", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(user)
assert.Nil(t, err) require.NoError(t, err)
err = underTest.InsertUser(user) err = underTest.InsertUser(user)
assert.Equal(t, ErrAlreadyExists, err) assert.Equal(t, db.ErrAlreadyExists, err)
}) })
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
t.Parallel() t.Parallel()
db := setupDb(t) d := setupDb(t)
underTest := AuthSqlite{db: db} underTest := db.NewAuthSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)

View File

@@ -14,7 +14,7 @@ var (
func TransformAndLogDbError(module string, r sql.Result, err error) error { func TransformAndLogDbError(module string, r sql.Result, err error) error {
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return ErrNotFound return ErrNotFound
} }
log.Error("%v: %v", module, err) log.Error("%v: %v", module, err)

View File

@@ -77,7 +77,6 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
func (handler AuthImpl) handleSignIn() http.HandlerFunc { func (handler AuthImpl) handleSignIn() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) { user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) {
session := middleware.GetSession(r) session := middleware.GetSession(r)
email := r.FormValue("email") email := r.FormValue("email")
@@ -95,7 +94,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
}) })
if err != nil { if err != nil {
if err == service.ErrInvalidCredentials { if errors.Is(err, service.ErrInvalidCredentials) {
utils.TriggerToastWithStatus(w, r, "error", "Invalid email or password", http.StatusUnauthorized) utils.TriggerToastWithStatus(w, r, "error", "Invalid email or password", http.StatusUnauthorized)
} else { } else {
utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError) utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError)
@@ -166,7 +165,6 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc { func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token") token := r.URL.Query().Get("token")
err := handler.service.VerifyUserEmail(token) err := handler.service.VerifyUserEmail(token)
@@ -203,13 +201,14 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
}) })
if err != nil { if err != nil {
if errors.Is(err, types.ErrInternal) { switch {
case errors.Is(err, types.ErrInternal):
utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError) utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError)
return return
} else if errors.Is(err, service.ErrInvalidEmail) { case errors.Is(err, service.ErrInvalidEmail):
utils.TriggerToastWithStatus(w, r, "error", "The email provided is invalid", http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", "The email provided is invalid", http.StatusBadRequest)
return return
} else if errors.Is(err, service.ErrInvalidPassword) { case errors.Is(err, service.ErrInvalidPassword):
utils.TriggerToastWithStatus(w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest)
return return
} }
@@ -272,7 +271,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
err := handler.service.DeleteAccount(user, password) err := handler.service.DeleteAccount(user, password)
if err != nil { if err != nil {
if err == service.ErrInvalidCredentials { if errors.Is(err, service.ErrInvalidCredentials) {
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
} else { } else {
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
@@ -286,7 +285,6 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc { func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
isPasswordReset := r.URL.Query().Has("token") isPasswordReset := r.URL.Query().Has("token")
user := middleware.GetUser(r) user := middleware.GetUser(r)
@@ -303,7 +301,6 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := middleware.GetSession(r) session := middleware.GetSession(r)
user := middleware.GetUser(r) user := middleware.GetUser(r)
if session == nil || user == nil { if session == nil || user == nil {
@@ -326,7 +323,6 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc { func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUser(r) user := middleware.GetUser(r)
if user != nil { if user != nil {
utils.DoRedirect(w, r, "/") utils.DoRedirect(w, r, "/")
@@ -340,7 +336,6 @@ func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
email := r.FormValue("email") email := r.FormValue("email")
if email == "" { if email == "" {
utils.TriggerToastWithStatus(w, r, "error", "Please enter an email", http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", "Please enter an email", http.StatusBadRequest)
@@ -362,7 +357,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
pageUrl, err := url.Parse(r.Header.Get("HX-Current-URL")) pageUrl, err := url.Parse(r.Header.Get("Hx-Current-Url"))
if err != nil { if err != nil {
log.Error("Could not get current URL: %v", err) log.Error("Could not get current URL: %v", err)
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)

View File

@@ -10,13 +10,14 @@ import (
) )
func handleError(w http.ResponseWriter, r *http.Request, err error) { func handleError(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, service.ErrUnauthorized) { switch {
case errors.Is(err, service.ErrUnauthorized):
utils.TriggerToastWithStatus(w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized) utils.TriggerToastWithStatus(w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized)
return return
} else if errors.Is(err, service.ErrBadRequest) { case errors.Is(err, service.ErrBadRequest):
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
return return
} else if errors.Is(err, db.ErrNotFound) { case errors.Is(err, db.ErrNotFound):
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound) utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound)
return return
} }

View File

@@ -16,7 +16,6 @@ var UserKey ContextKey = "user"
func Authenticate(service service.Auth) func(http.Handler) http.Handler { func Authenticate(service service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionId := getSessionID(r) sessionId := getSessionID(r)
session, user, _ := service.SignInSession(sessionId) session, user, _ := service.SignInSession(sessionId)
@@ -49,7 +48,12 @@ func GetUser(r *http.Request) *types.User {
return nil return nil
} }
return obj.(*types.User) user, ok := obj.(*types.User)
if !ok {
return nil
}
return user
} }
func GetSession(r *http.Request) *types.Session { func GetSession(r *http.Request) *types.Session {
@@ -58,7 +62,12 @@ func GetSession(r *http.Request) *types.Session {
return nil return nil
} }
return obj.(*types.Session) session, ok := obj.(*types.Session)
if !ok {
return nil
}
return session
} }
func getSessionID(r *http.Request) string { func getSessionID(r *http.Request) string {

View File

@@ -7,7 +7,6 @@ import (
func CacheControl(next http.Handler) http.Handler { func CacheControl(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
shouldCache := strings.HasPrefix(r.URL.Path, "/static") shouldCache := strings.HasPrefix(r.URL.Path, "/static")
if !shouldCache { if !shouldCache {

View File

@@ -37,19 +37,17 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := GetSession(r) session := GetSession(r)
if r.Method == http.MethodPost || if r.Method == http.MethodPost ||
r.Method == http.MethodPut || r.Method == http.MethodPut ||
r.Method == http.MethodDelete || r.Method == http.MethodDelete ||
r.Method == http.MethodPatch { r.Method == http.MethodPatch {
csrfToken := r.Header.Get("Csrf-Token")
csrfToken := r.Header.Get("csrf-token")
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
log.Info("CSRF-Token \"%s\" not correct", csrfToken) log.Info("CSRF-Token \"%s\" not correct", csrfToken)
if r.Header.Get("HX-Request") == "true" { if r.Header.Get("Hx-Request") == "true" {
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest) utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
} else { } else {
http.Error(w, "CSRF-Token not correct", http.StatusBadRequest) http.Error(w, "CSRF-Token not correct", http.StatusBadRequest)

View File

@@ -2,6 +2,7 @@ package middleware
import ( import (
"compress/gzip" "compress/gzip"
"errors"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@@ -32,8 +33,7 @@ func Gzip(next http.Handler) http.Handler {
next.ServeHTTP(wrapper, r) next.ServeHTTP(wrapper, r)
err := gz.Close() err := gz.Close()
if err != nil && err != http.ErrBodyNotAllowed { if err != nil && !errors.Is(err, http.ErrBodyNotAllowed) {
// if err != nil {
log.Error("Gzip: could not close Writer: %v", err) log.Error("Gzip: could not close Writer: %v", err)
} }
}) })

View File

@@ -7,7 +7,6 @@ import (
) )
func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Handler { func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-Content-Type-Options", "nosniff")
@@ -30,7 +29,7 @@ func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Han
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload") w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
if r.Method == "OPTIONS" { if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return return
} }

View File

@@ -2,12 +2,12 @@ package middleware
import "net/http" import "net/http"
// Chain list of handlers together // Chain list of handlers together.
func Wrapper(next http.Handler, handlers ...func(http.Handler) http.Handler) http.Handler { func Wrapper(next http.Handler, handlers ...func(http.Handler) http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
lastHandler := next lastHandler := next
for i := 0; i < len(handlers); i++ { for _, handler := range handlers {
lastHandler = handlers[i](lastHandler) lastHandler = handler(lastHandler)
} }
lastHandler.ServeHTTP(w, r) lastHandler.ServeHTTP(w, r)
}) })

View File

@@ -44,7 +44,6 @@ func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWri
} }
func (render *Render) getUserComp(user *types.User) templ.Component { func (render *Render) getUserComp(user *types.User) templ.Component {
if user != nil { if user != nil {
return auth.UserComp(user.Email) return auth.UserComp(user.Email)
} else { } else {

View File

@@ -106,7 +106,6 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
} }
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) { func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) {
var transactionsRecurring []*types.TransactionRecurring var transactionsRecurring []*types.TransactionRecurring
var err error var err error
if accountId == "" && treasureChestId == "" { if accountId == "" && treasureChestId == "" {

18
main.go
View File

@@ -1,6 +1,8 @@
package main package main
import ( import (
"errors"
"fmt"
"spend-sparrow/db" "spend-sparrow/db"
"spend-sparrow/handler" "spend-sparrow/handler"
"spend-sparrow/handler/middleware" "spend-sparrow/handler/middleware"
@@ -37,10 +39,14 @@ func main() {
log.Fatal("Could not close Database data.db: %v", err) log.Fatal("Could not close Database data.db: %v", err)
}() }()
run(context.Background(), db, os.Getenv) err = run(context.Background(), db, os.Getenv)
if err != nil {
log.Error("Error running server: %v", err)
return
}
} }
func run(ctx context.Context, database *sqlx.DB, env func(string) string) { func run(ctx context.Context, database *sqlx.DB, env func(string) string) error {
ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer cancel() defer cancel()
@@ -52,7 +58,7 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
// init db // init db
err := db.RunMigrations(database, "") err := db.RunMigrations(database, "")
if err != nil { if err != nil {
log.Fatal("Could not run migrations: %v", err) return fmt.Errorf("could not run migrations: %w", err)
} }
// init servers // init servers
@@ -61,6 +67,7 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
prometheusServer := &http.Server{ prometheusServer := &http.Server{
Addr: ":8081", Addr: ":8081",
Handler: promhttp.Handler(), Handler: promhttp.Handler(),
ReadHeaderTimeout: 10 * time.Second,
} }
go startServer(prometheusServer) go startServer(prometheusServer)
} }
@@ -68,6 +75,7 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
httpServer := &http.Server{ httpServer := &http.Server{
Addr: ":" + serverSettings.Port, Addr: ":" + serverSettings.Port,
Handler: createHandler(database, serverSettings), Handler: createHandler(database, serverSettings),
ReadHeaderTimeout: 10 * time.Second,
} }
go startServer(httpServer) go startServer(httpServer)
@@ -77,11 +85,13 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
go shutdownServer(httpServer, ctx, &wg) go shutdownServer(httpServer, ctx, &wg)
go shutdownServer(prometheusServer, ctx, &wg) go shutdownServer(prometheusServer, ctx, &wg)
wg.Wait() wg.Wait()
return nil
} }
func startServer(s *http.Server) { func startServer(s *http.Server) {
log.Info("Starting server on %q", s.Addr) log.Info("Starting server on %q", s.Addr)
if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed { if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Error("error listening and serving: %v", err) log.Error("error listening and serving: %v", err)
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"errors"
"fmt" "fmt"
"spend-sparrow/db" "spend-sparrow/db"
@@ -119,7 +120,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Update", nil, err) err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest) return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
} }
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -164,8 +165,8 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
account := &types.Account{} var account types.Account
err = s.db.Get(account, ` err = s.db.Get(&account, `
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Get", nil, err) err = db.TransformAndLogDbError("account Get", nil, err)
if err != nil { if err != nil {
@@ -173,7 +174,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
return nil, err return nil, err
} }
return account, nil return &account, nil
} }
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {

View File

@@ -94,30 +94,6 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
return session, user, nil return session, user, nil
} }
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
if session == nil {
return nil
}
err := service.db.DeleteSession(session.Id)
if err != nil {
return types.ErrInternal
}
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
if err != nil {
return types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(token.Token)
if err != nil {
return types.ErrInternal
}
}
return nil
}
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) { func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
if sessionId == "" { if sessionId == "" {
return nil, nil, ErrSessionIdInvalid return nil, nil, ErrSessionIdInvalid
@@ -155,30 +131,6 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
return session, nil return session, nil
} }
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
sessionId, err := service.random.String(32)
if err != nil {
return nil, types.ErrInternal
}
err = service.db.DeleteOldSessions(userId)
if err != nil {
return nil, types.ErrInternal
}
createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour)
session := types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session)
if err != nil {
return nil, types.ErrInternal
}
return session, nil
}
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) { func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
_, err := mail.ParseAddress(email) _, err := mail.ParseAddress(email)
if err != nil { if err != nil {
@@ -205,7 +157,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
err = service.db.InsertUser(user) err = service.db.InsertUser(user)
if err != nil { if err != nil {
if err == db.ErrAlreadyExists { if errors.Is(err, db.ErrAlreadyExists) {
return nil, ErrAccountExists return nil, ErrAccountExists
} else { } else {
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -216,9 +168,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
} }
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify) tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
if err != nil && err != db.ErrNotFound { if err != nil && !errors.Is(err, db.ErrNotFound) {
return return
} }
@@ -234,7 +185,13 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
return return
} }
token = types.NewToken(userId, "", newTokenStr, types.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) token = types.NewToken(
userId,
"",
newTokenStr,
types.TokenTypeEmailVerify,
service.clock.Now(),
service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token) err = service.db.InsertToken(token)
if err != nil { if err != nil {
@@ -253,7 +210,6 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
} }
func (service AuthImpl) VerifyUserEmail(tokenStr string) error { func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
if tokenStr == "" { if tokenStr == "" {
return types.ErrInternal return types.ErrInternal
} }
@@ -291,12 +247,10 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
} }
func (service AuthImpl) SignOut(sessionId string) error { func (service AuthImpl) SignOut(sessionId string) error {
return service.db.DeleteSession(sessionId) return service.db.DeleteSession(sessionId)
} }
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error { func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
userDb, err := service.db.GetUser(user.Id) userDb, err := service.db.GetUser(user.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
@@ -318,7 +272,6 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
} }
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error { func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
} }
@@ -365,14 +318,20 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
user, err := service.db.GetUserByEmail(email) user, err := service.db.GetUserByEmail(email)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil return nil
} else { } else {
return types.ErrInternal return types.ErrInternal
} }
} }
token := types.NewToken(user.Id, "", tokenStr, types.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute)) token := types.NewToken(
user.Id,
"",
tokenStr,
types.TokenTypePasswordReset,
service.clock.Now(),
service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(token) err = service.db.InsertToken(token)
if err != nil { if err != nil {
@@ -391,7 +350,6 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
} }
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
} }
@@ -449,7 +407,6 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
if token.Type != types.TokenTypeCsrf || if token.Type != types.TokenTypeCsrf ||
token.SessionId != sessionId || token.SessionId != sessionId ||
token.ExpiresAt.Before(service.clock.Now()) { token.ExpiresAt.Before(service.clock.Now()) {
return false return false
} }
@@ -472,7 +429,13 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
return "", types.ErrInternal return "", types.ErrInternal
} }
token := types.NewToken(session.UserId, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(8*time.Hour)) token := types.NewToken(
session.UserId,
session.Id,
tokenStr,
types.TokenTypeCsrf,
service.clock.Now(),
service.clock.Now().Add(8*time.Hour))
err = service.db.InsertToken(token) err = service.db.InsertToken(token)
if err != nil { if err != nil {
return "", types.ErrInternal return "", types.ErrInternal
@@ -483,12 +446,59 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
return tokenStr, nil return tokenStr, nil
} }
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
if session == nil {
return nil
}
err := service.db.DeleteSession(session.Id)
if err != nil {
return types.ErrInternal
}
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
if err != nil {
return types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(token.Token)
if err != nil {
return types.ErrInternal
}
}
return nil
}
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
sessionId, err := service.random.String(32)
if err != nil {
return nil, types.ErrInternal
}
err = service.db.DeleteOldSessions(userId)
if err != nil {
return nil, types.ErrInternal
}
createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour)
session := types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session)
if err != nil {
return nil, types.ErrInternal
}
return session, nil
}
func GetHashPassword(password string, salt []byte) []byte { func GetHashPassword(password string, salt []byte) []byte {
return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16)
} }
func isPasswordValid(password string) bool { func isPasswordValid(password string) bool {
if len(password) < 8 || if len(password) < 8 ||
!strings.ContainsAny(password, "0123456789") || !strings.ContainsAny(password, "0123456789") ||
!strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") ||

View File

@@ -1,8 +1,9 @@
package service package service_test
import ( import (
"spend-sparrow/db" "spend-sparrow/db"
"spend-sparrow/mocks" "spend-sparrow/mocks"
"spend-sparrow/service"
"spend-sparrow/types" "spend-sparrow/types"
"strings" "strings"
@@ -12,6 +13,17 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
var (
settings = types.Settings{
Port: "",
PrometheusEnabled: false,
BaseUrl: "",
Environment: "test",
Smtp: nil,
}
) )
func TestSignUp(t *testing.T) { func TestSignUp(t *testing.T) {
@@ -24,11 +36,11 @@ func TestSignUp(t *testing.T) {
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!") _, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!")
assert.Equal(t, ErrInvalidEmail, err) assert.Equal(t, service.ErrInvalidEmail, err)
}) })
t.Run("should check for password complexity", func(t *testing.T) { t.Run("should check for password complexity", func(t *testing.T) {
t.Parallel() t.Parallel()
@@ -38,7 +50,7 @@ func TestSignUp(t *testing.T) {
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
weakPasswords := []string{ weakPasswords := []string{
"123!ab", // too short "123!ab", // too short
@@ -49,7 +61,7 @@ func TestSignUp(t *testing.T) {
for _, password := range weakPasswords { for _, password := range weakPasswords {
_, err := underTest.SignUp("some@valid.email", password) _, err := underTest.SignUp("some@valid.email", password)
assert.Equal(t, ErrInvalidPassword, err) assert.Equal(t, service.ErrInvalidPassword, err)
} }
}) })
t.Run("should signup correctly", func(t *testing.T) { t.Run("should signup correctly", func(t *testing.T) {
@@ -66,17 +78,17 @@ func TestSignUp(t *testing.T) {
salt := []byte("salt") salt := []byte("salt")
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
expected := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime) expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
mockRandom.EXPECT().UUID().Return(userId, nil) mockRandom.EXPECT().UUID().Return(userId, nil)
mockRandom.EXPECT().Bytes(16).Return(salt, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil)
mockClock.EXPECT().Now().Return(createTime) mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(expected).Return(nil) mockAuthDb.EXPECT().InsertUser(expected).Return(nil)
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
actual, err := underTest.SignUp(email, password) actual, err := underTest.SignUp(email, password)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, expected, actual) assert.Equal(t, expected, actual)
}) })
@@ -93,7 +105,7 @@ func TestSignUp(t *testing.T) {
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
password := "SomeStrongPassword123!" password := "SomeStrongPassword123!"
salt := []byte("salt") salt := []byte("salt")
user := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime) user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
mockRandom.EXPECT().UUID().Return(user.Id, nil) mockRandom.EXPECT().UUID().Return(user.Id, nil)
mockRandom.EXPECT().Bytes(16).Return(salt, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil)
@@ -101,20 +113,25 @@ func TestSignUp(t *testing.T) {
mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists) mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists)
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp(user.Email, password) _, err := underTest.SignUp(user.Email, password)
assert.Equal(t, ErrAccountExists, err) assert.Equal(t, service.ErrAccountExists, err)
}) })
} }
func TestSendVerificationMail(t *testing.T) { func TestSendVerificationMail(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("should use stored token and send mail", func(t *testing.T) { t.Run("should use stored token and send mail", func(t *testing.T) {
t.Parallel() t.Parallel()
token := types.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", types.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) token := types.NewToken(
uuid.New(),
"sessionId",
"someRandomTokenToUse",
types.TokenTypeEmailVerify,
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
tokens := []*types.Token{token} tokens := []*types.Token{token}
email := "some@email.de" email := "some@email.de"
@@ -131,7 +148,7 @@ func TestSendVerificationMail(t *testing.T) {
return strings.Contains(message, token.Token) return strings.Contains(message, token.Token)
})).Return() })).Return()
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
underTest.SendVerificationMail(userId, email) underTest.SendVerificationMail(userId, email)
}) })

View File

@@ -5,16 +5,21 @@ import (
"regexp" "regexp"
) )
const (
DECIMALS_MULTIPLIER = 100
)
var ( var (
safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9ÄÖÜäöüß,&'" -]+$`) safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9ÄÖÜäöüß,&'" -]+$`)
) )
func validateString(value string, fieldName string) error { func validateString(value string, fieldName string) error {
if value == "" { switch {
case value == "":
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest) return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest)
} else if !safeInputRegex.MatchString(value) { case !safeInputRegex.MatchString(value):
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest) return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest)
} else { default:
return nil return nil
} }
} }

View File

@@ -34,7 +34,19 @@ func (m MailImpl) internalSendMail(to string, subject string, message string) {
auth := smtp.PlainAuth("", s.User, s.Pass, s.Host) auth := smtp.PlainAuth("", s.User, s.Pass, s.Host)
msg := fmt.Sprintf("From: %v <%v>\nTo: %v\nSubject: %v\nMIME-version: 1.0;\nContent-Type: text/html; charset=\"UTF-8\";\n\n%v", s.FromName, s.FromMail, to, subject, message) msg := fmt.Sprintf(
`From: %v <%v>
To: %v
Subject: %v
MIME-version: 1.0;
Content-Type: text/html; charset="UTF-8";
%v`,
s.FromName,
s.FromMail,
to,
subject,
message)
log.Info("Sending mail to %v", to) log.Info("Sending mail to %v", to)
err := smtp.SendMail(s.Host+":"+s.Port, auth, s.FromMail, []string{to}, []byte(msg)) err := smtp.SendMail(s.Host+":"+s.Port, auth, s.FromMail, []string{to}, []byte(msg))

View File

@@ -1,8 +0,0 @@
package service
type MoneyImpl struct {
}
func NewMoneyImpl() *MoneyImpl {
return &MoneyImpl{}
}

View File

@@ -1,80 +0,0 @@
package service
import (
"testing"
)
func TestMoneyCalculation(t *testing.T) {
t.Parallel()
t.Run("should calculate correct oink balance", func(t *testing.T) {
// t.Parallel()
//
// underTest := NewMoneyImpl()
//
// // GIVEN
// timestamp := time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC)
//
// userId := uuid.New()
//
// account := types.Account{
// Id: uuid.New(),
// UserId: userId,
//
// Type: "Bank",
// Name: "Bank",
//
// CurrentBalance: 0,
// LastTransaction: time.Time{},
// OinkBalance: 0,
// }
//
// // The PiggyBank is a fictional account. The money it "holds" is actually in the Account
// piggyBank := types.PiggyBank{
// Id: uuid.New(),
// UserId: userId,
//
// AccountId: account.Id,
// Name: "Car",
//
// CurrentBalance: 0,
// }
//
// savingsPlan := types.SavingsPlan{
// Id: uuid.New(),
// UserId: userId,
// PiggyBankId: piggyBank.Id,
//
// MonthlySaving: 10,
//
// ValidFrom: timestamp,
// }
//
// transaction1 := types.Transaction{
// Id: uuid.New(),
// UserId: userId,
//
// AccountId: account.Id,
//
// Value: 20,
// Timestamp: timestamp,
// }
//
// transaction2 := types.Transaction{
// Id: uuid.New(),
// UserId: userId,
//
// AccountId: account.Id,
// PiggyBankId: &piggyBank.Id,
//
// Value: -1,
// Timestamp: timestamp.Add(1 * time.Hour),
// }
//
// // WHEN
// actual, err := underTest.CalculateAllBalancesInTime(account, piggyBank, savingsPlan, []types.Transaction{transaction1, transaction2})
//
// // THEN
// assert.Nil(t, err)
// assert.ElementsMatch(t, expected, actual)
})
}

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
@@ -73,8 +74,10 @@ func (s TransactionImpl) Add(user *types.User, transactionInput types.Transactio
} }
r, err := tx.NamedExec(` r, err := tx.NamedExec(`
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp, party, description, error, created_at, created_by) INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp, :party, :description, :error, :created_at, :created_by)`, transaction) party, description, error, created_at, created_by)
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
:party, :description, :error, :created_at, :created_by)`, transaction)
err = db.TransformAndLogDbError("transaction Insert", r, err) err = db.TransformAndLogDbError("transaction Insert", r, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -135,7 +138,7 @@ func (s TransactionImpl) Update(user *types.User, input types.TransactionInput)
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Update", nil, err) err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest) return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest)
} }
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -232,7 +235,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Get", nil, err) err = db.TransformAndLogDbError("transaction Get", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest) return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest)
} }
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -259,7 +262,10 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF
OR (? = "false" AND error IS NULL) OR (? = "false" AND error IS NULL)
) )
ORDER BY timestamp DESC`, ORDER BY timestamp DESC`,
user.Id, filter.AccountId, filter.AccountId, filter.TreasureChestId, filter.TreasureChestId, filter.Error, filter.Error, filter.Error) user.Id,
filter.AccountId, filter.AccountId,
filter.TreasureChestId, filter.TreasureChestId,
filter.Error, filter.Error, filter.Error)
err = db.TransformAndLogDbError("transaction GetAll", nil, err) err = db.TransformAndLogDbError("transaction GetAll", nil, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -302,7 +308,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
WHERE id = ? WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err) err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil && err != db.ErrNotFound { if err != nil && !errors.Is(err, db.ErrNotFound) {
return err return err
} }
} }
@@ -314,7 +320,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
WHERE id = ? WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err) err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil && err != db.ErrNotFound { if err != nil && !errors.Is(err, db.ErrNotFound) {
return err return err
} }
} }
@@ -354,7 +360,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
SET current_balance = 0 SET current_balance = 0
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err) err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil && err != db.ErrNotFound { if err != nil && !errors.Is(err, db.ErrNotFound) {
return err return err
} }
@@ -363,7 +369,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
SET current_balance = 0 SET current_balance = 0
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err) err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil && err != db.ErrNotFound { if err != nil && !errors.Is(err, db.ErrNotFound) {
return err return err
} }
@@ -372,7 +378,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
FROM "transaction" FROM "transaction"
WHERE user_id = ?`, user.Id) WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err) err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil && err != db.ErrNotFound { if err != nil && !errors.Is(err, db.ErrNotFound) {
return err return err
} }
defer func() { defer func() {
@@ -382,15 +388,15 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
} }
}() }()
transaction := &types.Transaction{} var transaction types.Transaction
for rows.Next() { for rows.Next() {
err = rows.StructScan(transaction) err = rows.StructScan(&transaction)
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err) err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil { if err != nil {
return err return err
} }
s.updateErrors(transaction) s.updateErrors(&transaction)
r, err = tx.Exec(` r, err = tx.Exec(`
UPDATE "transaction" UPDATE "transaction"
SET error = ? SET error = ?
@@ -424,7 +430,6 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
if err != nil { if err != nil {
return err return err
} }
} }
} }
@@ -438,7 +443,6 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
} }
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.TransactionInput) (*types.Transaction, error) { func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.TransactionInput) (*types.Transaction, error) {
var ( var (
id uuid.UUID id uuid.UUID
accountUuid *uuid.UUID accountUuid *uuid.UUID
@@ -484,7 +488,6 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
log.Error("transaction validate: %v", err) log.Error("transaction validate: %v", err)
return nil, fmt.Errorf("account not found: %w", ErrBadRequest) return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
} }
} }
if input.TreasureChestId != "" { if input.TreasureChestId != "" {
@@ -498,7 +501,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError("transaction validate", nil, err) err = db.TransformAndLogDbError("transaction validate", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
} }
return nil, err return nil, err
@@ -513,7 +516,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
log.Error("transaction validate: %v", err) log.Error("transaction validate: %v", err)
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
} }
valueInt := int64(valueFloat * 100) valueInt := int64(valueFloat * DECIMALS_MULTIPLIER)
timestamp, err := time.Parse("2006-01-02", input.Timestamp) timestamp, err := time.Parse("2006-01-02", input.Timestamp)
if err != nil { if err != nil {
@@ -544,6 +547,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
Timestamp: timestamp, Timestamp: timestamp,
Party: input.Party, Party: input.Party,
Description: input.Description, Description: input.Description,
Error: nil,
CreatedAt: createdAt, CreatedAt: createdAt,
CreatedBy: createdBy, CreatedBy: createdBy,
@@ -557,25 +561,26 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
} }
func (s TransactionImpl) updateErrors(transaction *types.Transaction) { func (s TransactionImpl) updateErrors(transaction *types.Transaction) {
error := "" errorStr := ""
if transaction.Value < 0 { switch {
case transaction.Value < 0:
if transaction.TreasureChestId == nil { if transaction.TreasureChestId == nil {
error = "no treasure chest specified" errorStr = "no treasure chest specified"
} }
} else if transaction.Value > 0 { case transaction.Value > 0:
if transaction.AccountId == nil && transaction.TreasureChestId == nil { if transaction.AccountId == nil && transaction.TreasureChestId == nil {
error = "either an account or a treasure chest needs to be specified" errorStr = "either an account or a treasure chest needs to be specified"
} else if transaction.AccountId != nil && transaction.TreasureChestId != nil { } else if transaction.AccountId != nil && transaction.TreasureChestId != nil {
error = "positive amounts can only be applied to either an account or a treasure chest" errorStr = "positive amounts can only be applied to either an account or a treasure chest"
} }
} else { default:
error = "\"value\" needs to be specified" errorStr = "\"value\" needs to be specified"
} }
if error == "" { if errorStr == "" {
transaction.Error = nil transaction.Error = nil
} else { } else {
transaction.Error = &error transaction.Error = &errorStr
} }
} }

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
@@ -18,8 +19,8 @@ import (
var ( var (
transactionRecurringMetric = promauto.NewCounterVec( transactionRecurringMetric = promauto.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "spendsparrow_transactionRecurring_recurring_total", Name: "spendsparrow_transaction_recurring_total",
Help: "The total of transactionRecurring recurring operations", Help: "The total of transactionRecurring operations",
}, },
[]string{"operation"}, []string{"operation"},
) )
@@ -28,7 +29,6 @@ var (
type TransactionRecurring interface { type TransactionRecurring interface {
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Get(user *types.User, id string) (*types.TransactionRecurring, error)
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error)
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(user *types.User, id string) error Delete(user *types.User, id string) error
@@ -50,7 +50,9 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, settings *
} }
} }
func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) Add(
user *types.User,
transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("add").Inc() transactionRecurringMetric.WithLabelValues("add").Inc()
if user == nil { if user == nil {
@@ -72,8 +74,11 @@ func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInpu
} }
r, err := tx.NamedExec(` r, err := tx.NamedExec(`
INSERT INTO "transaction_recurring" (id, user_id, interval_months, active, party, description, account_id, treasure_chest_id, value, created_at, created_by) INSERT INTO "transaction_recurring" (id, user_id, interval_months,
VALUES (:id, :user_id, :interval_months, :active, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`, transactionRecurring) active, party, description, account_id, treasure_chest_id, value, created_at, created_by)
VALUES (:id, :user_id, :interval_months,
:active, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`,
transactionRecurring)
err = db.TransformAndLogDbError("transactionRecurring Insert", r, err) err = db.TransformAndLogDbError("transactionRecurring Insert", r, err)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -88,7 +93,9 @@ func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInpu
return transactionRecurring, nil return transactionRecurring, nil
} }
func (s TransactionRecurringImpl) Update(user *types.User, input types.TransactionRecurringInput) (*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) Update(
user *types.User,
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("update").Inc() transactionRecurringMetric.WithLabelValues("update").Inc()
if user == nil { if user == nil {
return nil, ErrUnauthorized return nil, ErrUnauthorized
@@ -112,7 +119,7 @@ func (s TransactionRecurringImpl) Update(user *types.User, input types.Transacti
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err) err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest) return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest)
} }
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -151,31 +158,6 @@ func (s TransactionRecurringImpl) Update(user *types.User, input types.Transacti
return transactionRecurring, nil return transactionRecurring, nil
} }
func (s TransactionRecurringImpl) Get(user *types.User, id string) (*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("get").Inc()
if user == nil {
return nil, ErrUnauthorized
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("transactionRecurring get: %v", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
var transactionRecurring types.TransactionRecurring
err = s.db.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Get", nil, err)
if err != nil {
if err == db.ErrNotFound {
return nil, fmt.Errorf("transactionRecurring %v not found: %w", id, ErrBadRequest)
}
return nil, types.ErrInternal
}
return &transactionRecurring, nil
}
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) { func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("get_all_by_account").Inc() transactionRecurringMetric.WithLabelValues("get_all_by_account").Inc()
if user == nil { if user == nil {
@@ -201,7 +183,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id) err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err) err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest) return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest)
} }
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -254,7 +236,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(user *types.User, treasu
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id) err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err) err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest) return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest)
} }
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -329,7 +311,6 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
oldTransactionRecurring *types.TransactionRecurring, oldTransactionRecurring *types.TransactionRecurring,
userId uuid.UUID, userId uuid.UUID,
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) { input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
var ( var (
id uuid.UUID id uuid.UUID
accountUuid *uuid.UUID accountUuid *uuid.UUID
@@ -393,7 +374,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err) err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
} }
return nil, err return nil, err
@@ -418,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
log.Error("transactionRecurring validate: %v", err) log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
} }
valueInt := int64(valueFloat * 100) valueInt := int64(valueFloat * DECIMALS_MULTIPLIER)
if input.Party != "" { if input.Party != "" {
err = validateString(input.Party, "party") err = validateString(input.Party, "party")
@@ -444,12 +425,12 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
active := input.Active == "on" active := input.Active == "on"
transactionRecurring := types.TransactionRecurring{ transactionRecurring := types.TransactionRecurring{
Id: id, Id: id,
UserId: userId, UserId: userId,
IntervalMonths: intervalMonths, IntervalMonths: intervalMonths,
Active: active, Active: active,
LastExecution: nil,
Party: input.Party, Party: input.Party,
Description: input.Description, Description: input.Description,

View File

@@ -1,6 +1,7 @@
package service package service
import ( import (
"errors"
"fmt" "fmt"
"slices" "slices"
@@ -131,7 +132,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id) err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Update", nil, err) err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err) return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
} }
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -198,17 +199,17 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
} }
treasureChest := &types.TreasureChest{} var treasureChest types.TreasureChest
err = s.db.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("treasureChest Get", nil, err) err = db.TransformAndLogDbError("treasureChest Get", nil, err)
if err != nil { if err != nil {
if err == db.ErrNotFound { if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err) return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
} }
return nil, types.ErrInternal return nil, types.ErrInternal
} }
return treasureChest, nil return &treasureChest, nil
} }
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) { func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) {
@@ -259,7 +260,9 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
} }
transactionsCount := 0 transactionsCount := 0
err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`, user.Id, id) err = tx.Get(&transactionsCount,
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err) err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil { if err != nil {
return err return err
@@ -284,12 +287,11 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
} }
func sortTree(nodes []*types.TreasureChest) []*types.TreasureChest { func sortTree(nodes []*types.TreasureChest) []*types.TreasureChest {
var ( var (
roots []*types.TreasureChest roots []*types.TreasureChest
result []*types.TreasureChest
) )
children := make(map[uuid.UUID][]*types.TreasureChest) children := make(map[uuid.UUID][]*types.TreasureChest)
result := make([]*types.TreasureChest, 0)
for _, node := range nodes { for _, node := range nodes {
if node.ParentId == nil { if node.ParentId == nil {

View File

@@ -111,6 +111,7 @@ templ AccountItem(account *types.Account) {
hx-target="closest #account" hx-target="closest #account"
hx-swap="outerHTML" hx-swap="outerHTML"
class="button button-neglect px-1 flex items-center gap-2" class="button button-neglect px-1 flex items-center gap-2"
hx-confirm="Are you sure you want to delete this account?"
> >
@svg.Delete() @svg.Delete()
<span> <span>

View File

@@ -28,7 +28,7 @@ templ Layout(slot templ.Component, user templ.Component, loggedIn bool, path str
<script src="/static/js/toast.js"></script> <script src="/static/js/toast.js"></script>
<script src="/static/js/time.js"></script> <script src="/static/js/time.js"></script>
</head> </head>
<body class="h-screen flex flex-col" hx-headers='{"csrf-token": "CSRF_TOKEN"}'> <body class="h-screen flex flex-col" hx-headers='{"Csrf-Token": "CSRF_TOKEN"}'>
// Header // Header
<nav class="flex bg-white items-center gap-2 py-1 px-2 h-12 md:gap-10 md:px-10 md:py-2"> <nav class="flex bg-white items-center gap-2 py-1 px-2 h-12 md:gap-10 md:px-10 md:py-2">
<a href="/" class="flex gap-2 mr-20"> <a href="/" class="flex gap-2 mr-20">

View File

@@ -6,13 +6,13 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// The Account holds money // The Account holds money.
type Account struct { type Account struct {
Id uuid.UUID Id uuid.UUID `db:"id"`
UserId uuid.UUID `db:"user_id"` UserId uuid.UUID `db:"user_id"`
// Custom Name of the account, e.g. "Bank", "Cash", "Credit Card" // Custom Name of the account, e.g. "Bank", "Cash", "Credit Card"
Name string Name string `db:"name"`
CurrentBalance int64 `db:"current_balance"` CurrentBalance int64 `db:"current_balance"`
LastTransaction *time.Time `db:"last_transaction"` LastTransaction *time.Time `db:"last_transaction"`

View File

@@ -17,7 +17,15 @@ type User struct {
CreateAt time.Time CreateAt time.Time
} }
func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User { func NewUser(
id uuid.UUID,
email string,
emailVerified bool,
emailVerifiedAt *time.Time,
isAdmin bool,
password []byte,
salt []byte,
createAt time.Time) *User {
return &User{ return &User{
Id: id, Id: id,
Email: email, Email: email,
@@ -31,10 +39,10 @@ func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *ti
} }
type Session struct { type Session struct {
Id string Id string `db:"session_id"`
UserId uuid.UUID UserId uuid.UUID `db:"user_id"`
CreatedAt time.Time CreatedAt time.Time `db:"created_at"`
ExpiresAt time.Time ExpiresAt time.Time `db:"expires_at"`
} }
func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session { func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session {
@@ -63,7 +71,13 @@ var (
TokenTypeCsrf TokenType = "csrf" TokenTypeCsrf TokenType = "csrf"
) )
func NewToken(userId uuid.UUID, sessionId string, token string, tokenType TokenType, createdAt time.Time, expiresAt time.Time) *Token { func NewToken(
userId uuid.UUID,
sessionId string,
token string,
tokenType TokenType,
createdAt time.Time,
expiresAt time.Time) *Token {
return &Token{ return &Token{
UserId: userId, UserId: userId,
SessionId: sessionId, SessionId: sessionId,

View File

@@ -1,26 +0,0 @@
package types
import (
"time"
"github.com/google/uuid"
)
// The SavingsPlan is applied every interval to the TreasureChest/Account as a transaction
type SavingsPlan struct {
Id uuid.UUID
UserId uuid.UUID `db:"user_id"`
TreasureChestId uuid.UUID `db:"treasure_chest_id"`
MonthlySaving int64 `db:"monthly_saving"`
ValidFrom time.Time `db:"valid_from"`
/// nil means it is valid indefinitely
ValidTo *time.Time `db:"valid_to"`
CreatedAt time.Time `db:"created_at"`
CreatedBy uuid.UUID `db:"created_by"`
UpdatedAt *time.Time `db:"updated_at"`
UpdatedBy *uuid.UUID `db:"updated_by"`
}

View File

@@ -23,10 +23,37 @@ type SmtpSettings struct {
} }
func NewSettingsFromEnv(env func(string) string) *Settings { func NewSettingsFromEnv(env func(string) string) *Settings {
var smtp *SmtpSettings var smtp *SmtpSettings
if env("SMTP_ENABLED") == "true" { if env("SMTP_ENABLED") == "true" {
smtp = &SmtpSettings{ smtp = getSmtpSettings(env)
}
settings := &Settings{
Port: env("PORT"),
PrometheusEnabled: env("PROMETHEUS_ENABLED") == "true",
BaseUrl: env("BASE_URL"),
Environment: env("ENVIRONMENT"),
Smtp: smtp,
}
if settings.BaseUrl == "" {
log.Fatal("BASE_URL must be set")
}
if settings.Port == "" {
log.Fatal("PORT must be set")
}
if settings.Environment == "" {
log.Fatal("ENVIRONMENT must be set")
}
log.Info("BASE_URL is %q", settings.BaseUrl)
log.Info("ENVIRONMENT is %q", settings.Environment)
return settings
}
func getSmtpSettings(env func(string) string) *SmtpSettings {
smtp := SmtpSettings{
Host: env("SMTP_HOST"), Host: env("SMTP_HOST"),
Port: env("SMTP_PORT"), Port: env("SMTP_PORT"),
User: env("SMTP_USER"), User: env("SMTP_USER"),
@@ -53,28 +80,6 @@ func NewSettingsFromEnv(env func(string) string) *Settings {
if smtp.FromName == "" { if smtp.FromName == "" {
log.Fatal("SMTP_FROM_NAME must be set") log.Fatal("SMTP_FROM_NAME must be set")
} }
}
settings := &Settings{ return &smtp
Port: env("PORT"),
PrometheusEnabled: env("PROMETHEUS_ENABLED") == "true",
BaseUrl: env("BASE_URL"),
Environment: env("ENVIRONMENT"),
Smtp: smtp,
}
if settings.BaseUrl == "" {
log.Fatal("BASE_URL must be set")
}
if settings.Port == "" {
log.Fatal("PORT must be set")
}
if settings.Environment == "" {
log.Fatal("ENVIRONMENT must be set")
}
log.Info("BASE_URL is %q", settings.BaseUrl)
log.Info("ENVIRONMENT is %q", settings.Environment)
return settings
} }

View File

@@ -14,19 +14,19 @@ import (
// If it becomes necessary to precalculate snapshots for performance reasons, this can be done in the future. // If it becomes necessary to precalculate snapshots for performance reasons, this can be done in the future.
// But the transaction should always be the source of truth. // But the transaction should always be the source of truth.
type Transaction struct { type Transaction struct {
Id uuid.UUID Id uuid.UUID `db:"id"`
UserId uuid.UUID `db:"user_id"` UserId uuid.UUID `db:"user_id"`
Timestamp time.Time Timestamp time.Time `db:"timestamp"`
Party string Party string `db:"party"`
Description string Description string `db:"description"`
AccountId *uuid.UUID `db:"account_id"` AccountId *uuid.UUID `db:"account_id"`
TreasureChestId *uuid.UUID `db:"treasure_chest_id"` TreasureChestId *uuid.UUID `db:"treasure_chest_id"`
Value int64 Value int64 `db:"value"`
// If an error is present, then the transaction is not valid and should not be used for calculations. // If an error is present, then the transaction is not valid and should not be used for calculations.
Error *string Error *string `db:"error"`
CreatedAt time.Time `db:"created_at"` CreatedAt time.Time `db:"created_at"`
// Either a user_id or a transaction_recurring_id // Either a user_id or a transaction_recurring_id
CreatedBy uuid.UUID `db:"created_by"` CreatedBy uuid.UUID `db:"created_by"`

View File

@@ -7,19 +7,19 @@ import (
) )
type TransactionRecurring struct { type TransactionRecurring struct {
Id uuid.UUID Id uuid.UUID `db:"id"`
UserId uuid.UUID `db:"user_id"` UserId uuid.UUID `db:"user_id"`
IntervalMonths int64 `db:"interval_months"` IntervalMonths int64 `db:"interval_months"`
LastExecution *time.Time `db:"last_execution"` LastExecution *time.Time `db:"last_execution"`
Active bool Active bool `db:"active"`
Party string Party string `db:"party"`
Description string Description string `db:"description"`
AccountId *uuid.UUID `db:"account_id"` AccountId *uuid.UUID `db:"account_id"`
TreasureChestId *uuid.UUID `db:"treasure_chest_id"` TreasureChestId *uuid.UUID `db:"treasure_chest_id"`
Value int64 Value int64 `db:"value"`
CreatedAt time.Time `db:"created_at"` CreatedAt time.Time `db:"created_at"`
CreatedBy uuid.UUID `db:"created_by"` CreatedBy uuid.UUID `db:"created_by"`

View File

@@ -10,13 +10,13 @@ import (
// The money it "holds" distributed across all accounts // The money it "holds" distributed across all accounts
// //
// At the time of writing this, linking it to a specific account doesn't really make sense // At the time of writing this, linking it to a specific account doesn't really make sense
// Imagne a TreasureChest for free time activities, where some money is spend in cash and some other with credit card // Imagine a TreasureChest for free time activities, where some money is spend in cash and some other with credit card.
type TreasureChest struct { type TreasureChest struct {
Id uuid.UUID Id uuid.UUID `db:"id"`
ParentId *uuid.UUID `db:"parent_id"` ParentId *uuid.UUID `db:"parent_id"`
UserId uuid.UUID `db:"user_id"` UserId uuid.UUID `db:"user_id"`
Name string Name string `db:"name"`
CurrentBalance int64 `db:"current_balance"` CurrentBalance int64 `db:"current_balance"`

View File

@@ -11,7 +11,7 @@ import (
func TriggerToast(w http.ResponseWriter, r *http.Request, class string, message string) { func TriggerToast(w http.ResponseWriter, r *http.Request, class string, message string) {
if IsHtmx(r) { if IsHtmx(r) {
w.Header().Set("HX-Trigger", fmt.Sprintf(`{"toast": "%v|%v"}`, class, strings.ReplaceAll(message, `"`, `\"`))) w.Header().Set("Hx-Trigger", fmt.Sprintf(`{"toast": "%v|%v"}`, class, strings.ReplaceAll(message, `"`, `\"`)))
} else { } else {
log.Error("Trying to trigger toast in non-HTMX request") log.Error("Trying to trigger toast in non-HTMX request")
} }
@@ -24,19 +24,19 @@ func TriggerToastWithStatus(w http.ResponseWriter, r *http.Request, class string
func DoRedirect(w http.ResponseWriter, r *http.Request, url string) { func DoRedirect(w http.ResponseWriter, r *http.Request, url string) {
if IsHtmx(r) { if IsHtmx(r) {
w.Header().Add("HX-Redirect", url) w.Header().Add("Hx-Redirect", url)
} else { } else {
http.Redirect(w, r, url, http.StatusSeeOther) http.Redirect(w, r, url, http.StatusSeeOther)
} }
} }
func WaitMinimumTime[T interface{}](waitTime time.Duration, function func() (T, error)) (T, error) { func WaitMinimumTime[T interface{}](waitTime time.Duration, f func() (T, error)) (T, error) {
start := time.Now() start := time.Now()
result, err := function() result, err := f()
time.Sleep(waitTime - time.Since(start)) time.Sleep(waitTime - time.Since(start))
return result, err return result, err
} }
func IsHtmx(r *http.Request) bool { func IsHtmx(r *http.Request) bool {
return r.Header.Get("HX-Request") == "true" return r.Header.Get("Hx-Request") == "true"
} }