fix: lint errors
All checks were successful
Build Docker Image / Build-Docker-Image (push) Successful in 5m22s
Build and Push Docker Image / Build-And-Push-Docker-Image (push) Successful in 5m26s

This commit was merged in pull request #130.
This commit is contained in:
2025-05-25 16:36:30 +02:00
parent 2ba5ddd9f2
commit 128a2fc4d7
36 changed files with 1024 additions and 968 deletions

28
.golangci.yaml Normal file
View File

@@ -0,0 +1,28 @@
version: '2'
linters:
default: all
disable:
- wsl
- wrapcheck
- varnamelen
- revive # should probably be enabled
- nlreturn
- mnd # should probably be enabled
- lll # should probably be enabled
- ireturn # should probably be enabled
- interfacebloat
- iface
- goconst # should probably be enabled
- gocognit # should probably be enabled
- gochecknoglobals # should probably be enabled
- funlen
- maintidx
- exhaustruct
- dupword # should probably be enabled
- dupl # should probably be enabled
- depguard
- cyclop
- contextcheck
settings:
nestif:
min-complexity: 6

View File

@@ -1,6 +1,7 @@
package db
import (
"errors"
"spend-sparrow/log"
"spend-sparrow/types"
@@ -89,7 +90,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
FROM user
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
} else {
log.Error("SQL error GetUser: %v", err)
@@ -116,7 +117,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
FROM user
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
} else {
log.Error("SQL error GetUser %v", err)
@@ -128,7 +129,6 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
}
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
tx, err := db.db.Begin()
if err != nil {
log.Error("Could not start transaction: %v", err)
@@ -216,7 +216,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
log.Info("Token '%v' not found", token)
return nil, ErrNotFound
} else {
@@ -241,7 +241,6 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
}
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(`
SELECT token, created_at, expires_at
FROM token
@@ -257,7 +256,6 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.
}
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(`
SELECT token, created_at, expires_at
FROM token
@@ -325,7 +323,6 @@ func (db AuthSqlite) DeleteToken(token string) error {
}
func (db AuthSqlite) InsertSession(session *types.Session) error {
_, err := db.db.Exec(`
INSERT INTO session (session_id, user_id, created_at, expires_at)
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
@@ -339,7 +336,6 @@ func (db AuthSqlite) InsertSession(session *types.Session) error {
}
func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
var (
userId uuid.UUID
createdAt time.Time
@@ -360,9 +356,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
}
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
sessions, err := db.db.Query(`
SELECT session_id, created_at, expires_at
var sessions []*types.Session
err := db.db.Select(&sessions, `
SELECT *
FROM session
WHERE user_id = ?`, userId)
if err != nil {
@@ -370,26 +366,7 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
return nil, types.ErrInternal
}
var result []*types.Session
for sessions.Next() {
var (
sessionId string
createdAt time.Time
expiresAt time.Time
)
err := sessions.Scan(&sessionId, &createdAt, &expiresAt)
if err != nil {
log.Error("Could not scan session: %v", err)
return nil, types.ErrInternal
}
session := types.NewSession(sessionId, userId, createdAt, expiresAt)
result = append(result, session)
}
return result, nil
return sessions, nil
}
func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {

View File

@@ -1,6 +1,7 @@
package db
package db_test
import (
"spend-sparrow/db"
"spend-sparrow/types"
"testing"
"time"
@@ -8,26 +9,29 @@ import (
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setupDb(t *testing.T) *sqlx.DB {
db, err := sqlx.Open("sqlite3", ":memory:")
t.Helper()
d, err := sqlx.Open("sqlite3", ":memory:")
if err != nil {
t.Fatalf("Error opening database: %v", err)
}
t.Cleanup(func() {
err := db.Close()
err := d.Close()
if err != nil {
panic(err)
}
})
err = RunMigrations(db, "../")
err = db.RunMigrations(d, "../")
if err != nil {
t.Fatalf("Error running migrations: %v", err)
}
return db
return d
}
func TestUser(t *testing.T) {
@@ -35,55 +39,55 @@ func TestUser(t *testing.T) {
t.Run("should insert and get the same", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(expected)
assert.Nil(t, err)
require.NoError(t, err)
actual, err := underTest.GetUser(expected.Id)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, expected, actual)
actual, err = underTest.GetUserByEmail(expected.Email)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, expected, actual)
})
t.Run("should return ErrNotFound", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
_, err := underTest.GetUserByEmail("nonExistentEmail")
assert.Equal(t, ErrNotFound, err)
assert.Equal(t, db.ErrNotFound, err)
})
t.Run("should return ErrUserExist", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(user)
assert.Nil(t, err)
require.NoError(t, err)
err = underTest.InsertUser(user)
assert.Equal(t, ErrAlreadyExists, err)
assert.Equal(t, db.ErrAlreadyExists, err)
})
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
@@ -98,37 +102,37 @@ func TestToken(t *testing.T) {
t.Run("should insert and get the same", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour)
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(expected)
assert.Nil(t, err)
require.NoError(t, err)
actual, err := underTest.GetToken(expected.Token)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, expected, actual)
expected.SessionId = ""
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals)
expected.SessionId = "sessionId"
expected.UserId = uuid.Nil
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, []*types.Token{expected}, actuals)
})
t.Run("should insert and return multiple tokens", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour)
@@ -137,14 +141,14 @@ func TestToken(t *testing.T) {
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(expected1)
assert.Nil(t, err)
require.NoError(t, err)
err = underTest.InsertToken(expected2)
assert.Nil(t, err)
require.NoError(t, err)
expected1.UserId = uuid.Nil
expected2.UserId = uuid.Nil
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
expected1.SessionId = ""
@@ -152,46 +156,45 @@ func TestToken(t *testing.T) {
expected1.UserId = userId
expected2.UserId = userId
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
})
t.Run("should return ErrNotFound", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
_, err := underTest.GetToken("nonExistent")
assert.Equal(t, ErrNotFound, err)
assert.Equal(t, db.ErrNotFound, err)
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify)
assert.Equal(t, ErrNotFound, err)
assert.Equal(t, db.ErrNotFound, err)
_, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify)
assert.Equal(t, ErrNotFound, err)
assert.Equal(t, db.ErrNotFound, err)
})
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(user)
assert.Nil(t, err)
require.NoError(t, err)
err = underTest.InsertUser(user)
assert.Equal(t, ErrAlreadyExists, err)
assert.Equal(t, db.ErrAlreadyExists, err)
})
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
t.Parallel()
db := setupDb(t)
d := setupDb(t)
underTest := AuthSqlite{db: db}
underTest := db.NewAuthSqlite(d)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)

View File

@@ -14,7 +14,7 @@ var (
func TransformAndLogDbError(module string, r sql.Result, err error) error {
if err != nil {
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return ErrNotFound
}
log.Error("%v: %v", module, err)

View File

@@ -77,7 +77,6 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
func (handler AuthImpl) handleSignIn() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) {
session := middleware.GetSession(r)
email := r.FormValue("email")
@@ -95,7 +94,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
})
if err != nil {
if err == service.ErrInvalidCredentials {
if errors.Is(err, service.ErrInvalidCredentials) {
utils.TriggerToastWithStatus(w, r, "error", "Invalid email or password", http.StatusUnauthorized)
} else {
utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError)
@@ -166,7 +165,6 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
err := handler.service.VerifyUserEmail(token)
@@ -203,13 +201,14 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
})
if err != nil {
if errors.Is(err, types.ErrInternal) {
switch {
case errors.Is(err, types.ErrInternal):
utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError)
return
} else if errors.Is(err, service.ErrInvalidEmail) {
case errors.Is(err, service.ErrInvalidEmail):
utils.TriggerToastWithStatus(w, r, "error", "The email provided is invalid", http.StatusBadRequest)
return
} else if errors.Is(err, service.ErrInvalidPassword) {
case errors.Is(err, service.ErrInvalidPassword):
utils.TriggerToastWithStatus(w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest)
return
}
@@ -272,7 +271,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
err := handler.service.DeleteAccount(user, password)
if err != nil {
if err == service.ErrInvalidCredentials {
if errors.Is(err, service.ErrInvalidCredentials) {
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
} else {
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
@@ -286,7 +285,6 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
isPasswordReset := r.URL.Query().Has("token")
user := middleware.GetUser(r)
@@ -303,7 +301,6 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := middleware.GetSession(r)
user := middleware.GetUser(r)
if session == nil || user == nil {
@@ -326,7 +323,6 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUser(r)
if user != nil {
utils.DoRedirect(w, r, "/")
@@ -340,7 +336,6 @@ func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
email := r.FormValue("email")
if email == "" {
utils.TriggerToastWithStatus(w, r, "error", "Please enter an email", http.StatusBadRequest)
@@ -362,7 +357,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
pageUrl, err := url.Parse(r.Header.Get("HX-Current-URL"))
pageUrl, err := url.Parse(r.Header.Get("Hx-Current-Url"))
if err != nil {
log.Error("Could not get current URL: %v", err)
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)

View File

@@ -10,13 +10,14 @@ import (
)
func handleError(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, service.ErrUnauthorized) {
switch {
case errors.Is(err, service.ErrUnauthorized):
utils.TriggerToastWithStatus(w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized)
return
} else if errors.Is(err, service.ErrBadRequest) {
case errors.Is(err, service.ErrBadRequest):
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
return
} else if errors.Is(err, db.ErrNotFound) {
case errors.Is(err, db.ErrNotFound):
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound)
return
}

View File

@@ -16,7 +16,6 @@ var UserKey ContextKey = "user"
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sessionId := getSessionID(r)
session, user, _ := service.SignInSession(sessionId)
@@ -49,7 +48,12 @@ func GetUser(r *http.Request) *types.User {
return nil
}
return obj.(*types.User)
user, ok := obj.(*types.User)
if !ok {
return nil
}
return user
}
func GetSession(r *http.Request) *types.Session {
@@ -58,7 +62,12 @@ func GetSession(r *http.Request) *types.Session {
return nil
}
return obj.(*types.Session)
session, ok := obj.(*types.Session)
if !ok {
return nil
}
return session
}
func getSessionID(r *http.Request) string {

View File

@@ -7,7 +7,6 @@ import (
func CacheControl(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
shouldCache := strings.HasPrefix(r.URL.Path, "/static")
if !shouldCache {

View File

@@ -37,19 +37,17 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := GetSession(r)
if r.Method == http.MethodPost ||
r.Method == http.MethodPut ||
r.Method == http.MethodDelete ||
r.Method == http.MethodPatch {
csrfToken := r.Header.Get("csrf-token")
csrfToken := r.Header.Get("Csrf-Token")
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
log.Info("CSRF-Token \"%s\" not correct", csrfToken)
if r.Header.Get("HX-Request") == "true" {
if r.Header.Get("Hx-Request") == "true" {
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
} else {
http.Error(w, "CSRF-Token not correct", http.StatusBadRequest)

View File

@@ -2,6 +2,7 @@ package middleware
import (
"compress/gzip"
"errors"
"io"
"net/http"
"strings"
@@ -32,8 +33,7 @@ func Gzip(next http.Handler) http.Handler {
next.ServeHTTP(wrapper, r)
err := gz.Close()
if err != nil && err != http.ErrBodyNotAllowed {
// if err != nil {
if err != nil && !errors.Is(err, http.ErrBodyNotAllowed) {
log.Error("Gzip: could not close Writer: %v", err)
}
})

View File

@@ -7,7 +7,6 @@ import (
)
func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
@@ -30,7 +29,7 @@ func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Han
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
if r.Method == "OPTIONS" {
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusOK)
return
}

View File

@@ -2,12 +2,12 @@ package middleware
import "net/http"
// Chain list of handlers together
// Chain list of handlers together.
func Wrapper(next http.Handler, handlers ...func(http.Handler) http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
lastHandler := next
for i := 0; i < len(handlers); i++ {
lastHandler = handlers[i](lastHandler)
for _, handler := range handlers {
lastHandler = handler(lastHandler)
}
lastHandler.ServeHTTP(w, r)
})

View File

@@ -44,7 +44,6 @@ func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWri
}
func (render *Render) getUserComp(user *types.User) templ.Component {
if user != nil {
return auth.UserComp(user.Email)
} else {

View File

@@ -106,7 +106,6 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
}
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) {
var transactionsRecurring []*types.TransactionRecurring
var err error
if accountId == "" && treasureChestId == "" {

26
main.go
View File

@@ -1,6 +1,8 @@
package main
import (
"errors"
"fmt"
"spend-sparrow/db"
"spend-sparrow/handler"
"spend-sparrow/handler/middleware"
@@ -37,10 +39,14 @@ func main() {
log.Fatal("Could not close Database data.db: %v", err)
}()
run(context.Background(), db, os.Getenv)
err = run(context.Background(), db, os.Getenv)
if err != nil {
log.Error("Error running server: %v", err)
return
}
}
func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
func run(ctx context.Context, database *sqlx.DB, env func(string) string) error {
ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer cancel()
@@ -52,22 +58,24 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
// init db
err := db.RunMigrations(database, "")
if err != nil {
log.Fatal("Could not run migrations: %v", err)
return fmt.Errorf("could not run migrations: %w", err)
}
// init servers
var prometheusServer *http.Server
if serverSettings.PrometheusEnabled {
prometheusServer := &http.Server{
Addr: ":8081",
Handler: promhttp.Handler(),
Addr: ":8081",
Handler: promhttp.Handler(),
ReadHeaderTimeout: 10 * time.Second,
}
go startServer(prometheusServer)
}
httpServer := &http.Server{
Addr: ":" + serverSettings.Port,
Handler: createHandler(database, serverSettings),
Addr: ":" + serverSettings.Port,
Handler: createHandler(database, serverSettings),
ReadHeaderTimeout: 10 * time.Second,
}
go startServer(httpServer)
@@ -77,11 +85,13 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
go shutdownServer(httpServer, ctx, &wg)
go shutdownServer(prometheusServer, ctx, &wg)
wg.Wait()
return nil
}
func startServer(s *http.Server) {
log.Info("Starting server on %q", s.Addr)
if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed {
if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Error("error listening and serving: %v", err)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
package service
import (
"errors"
"fmt"
"spend-sparrow/db"
@@ -119,7 +120,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Update", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
}
return nil, types.ErrInternal
@@ -164,8 +165,8 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
account := &types.Account{}
err = s.db.Get(account, `
var account types.Account
err = s.db.Get(&account, `
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("account Get", nil, err)
if err != nil {
@@ -173,7 +174,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
return nil, err
}
return account, nil
return &account, nil
}
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {

View File

@@ -94,30 +94,6 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
return session, user, nil
}
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
if session == nil {
return nil
}
err := service.db.DeleteSession(session.Id)
if err != nil {
return types.ErrInternal
}
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
if err != nil {
return types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(token.Token)
if err != nil {
return types.ErrInternal
}
}
return nil
}
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
if sessionId == "" {
return nil, nil, ErrSessionIdInvalid
@@ -155,30 +131,6 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
return session, nil
}
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
sessionId, err := service.random.String(32)
if err != nil {
return nil, types.ErrInternal
}
err = service.db.DeleteOldSessions(userId)
if err != nil {
return nil, types.ErrInternal
}
createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour)
session := types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session)
if err != nil {
return nil, types.ErrInternal
}
return session, nil
}
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
_, err := mail.ParseAddress(email)
if err != nil {
@@ -205,7 +157,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
err = service.db.InsertUser(user)
if err != nil {
if err == db.ErrAlreadyExists {
if errors.Is(err, db.ErrAlreadyExists) {
return nil, ErrAccountExists
} else {
return nil, types.ErrInternal
@@ -216,9 +168,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
}
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
if err != nil && err != db.ErrNotFound {
if err != nil && !errors.Is(err, db.ErrNotFound) {
return
}
@@ -234,7 +185,13 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
return
}
token = types.NewToken(userId, "", newTokenStr, types.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour))
token = types.NewToken(
userId,
"",
newTokenStr,
types.TokenTypeEmailVerify,
service.clock.Now(),
service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token)
if err != nil {
@@ -253,7 +210,6 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
}
func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
if tokenStr == "" {
return types.ErrInternal
}
@@ -291,12 +247,10 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
}
func (service AuthImpl) SignOut(sessionId string) error {
return service.db.DeleteSession(sessionId)
}
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
userDb, err := service.db.GetUser(user.Id)
if err != nil {
return types.ErrInternal
@@ -318,7 +272,6 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
}
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
if !isPasswordValid(newPass) {
return ErrInvalidPassword
}
@@ -365,14 +318,20 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
user, err := service.db.GetUserByEmail(email)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil
} else {
return types.ErrInternal
}
}
token := types.NewToken(user.Id, "", tokenStr, types.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute))
token := types.NewToken(
user.Id,
"",
tokenStr,
types.TokenTypePasswordReset,
service.clock.Now(),
service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(token)
if err != nil {
@@ -391,7 +350,6 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
}
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
if !isPasswordValid(newPass) {
return ErrInvalidPassword
}
@@ -449,7 +407,6 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
if token.Type != types.TokenTypeCsrf ||
token.SessionId != sessionId ||
token.ExpiresAt.Before(service.clock.Now()) {
return false
}
@@ -472,7 +429,13 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
return "", types.ErrInternal
}
token := types.NewToken(session.UserId, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(8*time.Hour))
token := types.NewToken(
session.UserId,
session.Id,
tokenStr,
types.TokenTypeCsrf,
service.clock.Now(),
service.clock.Now().Add(8*time.Hour))
err = service.db.InsertToken(token)
if err != nil {
return "", types.ErrInternal
@@ -483,12 +446,59 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
return tokenStr, nil
}
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
if session == nil {
return nil
}
err := service.db.DeleteSession(session.Id)
if err != nil {
return types.ErrInternal
}
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
if err != nil {
return types.ErrInternal
}
for _, token := range tokens {
err = service.db.DeleteToken(token.Token)
if err != nil {
return types.ErrInternal
}
}
return nil
}
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
sessionId, err := service.random.String(32)
if err != nil {
return nil, types.ErrInternal
}
err = service.db.DeleteOldSessions(userId)
if err != nil {
return nil, types.ErrInternal
}
createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour)
session := types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session)
if err != nil {
return nil, types.ErrInternal
}
return session, nil
}
func GetHashPassword(password string, salt []byte) []byte {
return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16)
}
func isPasswordValid(password string) bool {
if len(password) < 8 ||
!strings.ContainsAny(password, "0123456789") ||
!strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") ||

View File

@@ -1,8 +1,9 @@
package service
package service_test
import (
"spend-sparrow/db"
"spend-sparrow/mocks"
"spend-sparrow/service"
"spend-sparrow/types"
"strings"
@@ -12,6 +13,17 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
var (
settings = types.Settings{
Port: "",
PrometheusEnabled: false,
BaseUrl: "",
Environment: "test",
Smtp: nil,
}
)
func TestSignUp(t *testing.T) {
@@ -24,11 +36,11 @@ func TestSignUp(t *testing.T) {
mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t)
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!")
assert.Equal(t, ErrInvalidEmail, err)
assert.Equal(t, service.ErrInvalidEmail, err)
})
t.Run("should check for password complexity", func(t *testing.T) {
t.Parallel()
@@ -38,7 +50,7 @@ func TestSignUp(t *testing.T) {
mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t)
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
weakPasswords := []string{
"123!ab", // too short
@@ -49,7 +61,7 @@ func TestSignUp(t *testing.T) {
for _, password := range weakPasswords {
_, err := underTest.SignUp("some@valid.email", password)
assert.Equal(t, ErrInvalidPassword, err)
assert.Equal(t, service.ErrInvalidPassword, err)
}
})
t.Run("should signup correctly", func(t *testing.T) {
@@ -66,17 +78,17 @@ func TestSignUp(t *testing.T) {
salt := []byte("salt")
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
expected := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime)
expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
mockRandom.EXPECT().UUID().Return(userId, nil)
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(expected).Return(nil)
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
actual, err := underTest.SignUp(email, password)
assert.Nil(t, err)
require.NoError(t, err)
assert.Equal(t, expected, actual)
})
@@ -93,7 +105,7 @@ func TestSignUp(t *testing.T) {
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
password := "SomeStrongPassword123!"
salt := []byte("salt")
user := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime)
user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
mockRandom.EXPECT().UUID().Return(user.Id, nil)
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
@@ -101,20 +113,25 @@ func TestSignUp(t *testing.T) {
mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists)
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
_, err := underTest.SignUp(user.Email, password)
assert.Equal(t, ErrAccountExists, err)
assert.Equal(t, service.ErrAccountExists, err)
})
}
func TestSendVerificationMail(t *testing.T) {
t.Parallel()
t.Run("should use stored token and send mail", func(t *testing.T) {
t.Parallel()
token := types.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", types.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
token := types.NewToken(
uuid.New(),
"sessionId",
"someRandomTokenToUse",
types.TokenTypeEmailVerify,
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
tokens := []*types.Token{token}
email := "some@email.de"
@@ -131,7 +148,7 @@ func TestSendVerificationMail(t *testing.T) {
return strings.Contains(message, token.Token)
})).Return()
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
underTest.SendVerificationMail(userId, email)
})

View File

@@ -5,16 +5,21 @@ import (
"regexp"
)
const (
DECIMALS_MULTIPLIER = 100
)
var (
safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9ÄÖÜäöüß,&'" -]+$`)
)
func validateString(value string, fieldName string) error {
if value == "" {
switch {
case value == "":
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest)
} else if !safeInputRegex.MatchString(value) {
case !safeInputRegex.MatchString(value):
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest)
} else {
default:
return nil
}
}

View File

@@ -34,7 +34,19 @@ func (m MailImpl) internalSendMail(to string, subject string, message string) {
auth := smtp.PlainAuth("", s.User, s.Pass, s.Host)
msg := fmt.Sprintf("From: %v <%v>\nTo: %v\nSubject: %v\nMIME-version: 1.0;\nContent-Type: text/html; charset=\"UTF-8\";\n\n%v", s.FromName, s.FromMail, to, subject, message)
msg := fmt.Sprintf(
`From: %v <%v>
To: %v
Subject: %v
MIME-version: 1.0;
Content-Type: text/html; charset="UTF-8";
%v`,
s.FromName,
s.FromMail,
to,
subject,
message)
log.Info("Sending mail to %v", to)
err := smtp.SendMail(s.Host+":"+s.Port, auth, s.FromMail, []string{to}, []byte(msg))

View File

@@ -1,8 +0,0 @@
package service
type MoneyImpl struct {
}
func NewMoneyImpl() *MoneyImpl {
return &MoneyImpl{}
}

View File

@@ -1,80 +0,0 @@
package service
import (
"testing"
)
func TestMoneyCalculation(t *testing.T) {
t.Parallel()
t.Run("should calculate correct oink balance", func(t *testing.T) {
// t.Parallel()
//
// underTest := NewMoneyImpl()
//
// // GIVEN
// timestamp := time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC)
//
// userId := uuid.New()
//
// account := types.Account{
// Id: uuid.New(),
// UserId: userId,
//
// Type: "Bank",
// Name: "Bank",
//
// CurrentBalance: 0,
// LastTransaction: time.Time{},
// OinkBalance: 0,
// }
//
// // The PiggyBank is a fictional account. The money it "holds" is actually in the Account
// piggyBank := types.PiggyBank{
// Id: uuid.New(),
// UserId: userId,
//
// AccountId: account.Id,
// Name: "Car",
//
// CurrentBalance: 0,
// }
//
// savingsPlan := types.SavingsPlan{
// Id: uuid.New(),
// UserId: userId,
// PiggyBankId: piggyBank.Id,
//
// MonthlySaving: 10,
//
// ValidFrom: timestamp,
// }
//
// transaction1 := types.Transaction{
// Id: uuid.New(),
// UserId: userId,
//
// AccountId: account.Id,
//
// Value: 20,
// Timestamp: timestamp,
// }
//
// transaction2 := types.Transaction{
// Id: uuid.New(),
// UserId: userId,
//
// AccountId: account.Id,
// PiggyBankId: &piggyBank.Id,
//
// Value: -1,
// Timestamp: timestamp.Add(1 * time.Hour),
// }
//
// // WHEN
// actual, err := underTest.CalculateAllBalancesInTime(account, piggyBank, savingsPlan, []types.Transaction{transaction1, transaction2})
//
// // THEN
// assert.Nil(t, err)
// assert.ElementsMatch(t, expected, actual)
})
}

View File

@@ -1,6 +1,7 @@
package service
import (
"errors"
"fmt"
"strconv"
"time"
@@ -73,8 +74,10 @@ func (s TransactionImpl) Add(user *types.User, transactionInput types.Transactio
}
r, err := tx.NamedExec(`
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp, party, description, error, created_at, created_by)
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp, :party, :description, :error, :created_at, :created_by)`, transaction)
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
party, description, error, created_at, created_by)
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
:party, :description, :error, :created_at, :created_by)`, transaction)
err = db.TransformAndLogDbError("transaction Insert", r, err)
if err != nil {
return nil, err
@@ -135,7 +138,7 @@ func (s TransactionImpl) Update(user *types.User, input types.TransactionInput)
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Update", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest)
}
return nil, types.ErrInternal
@@ -232,7 +235,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transaction Get", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest)
}
return nil, types.ErrInternal
@@ -259,7 +262,10 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF
OR (? = "false" AND error IS NULL)
)
ORDER BY timestamp DESC`,
user.Id, filter.AccountId, filter.AccountId, filter.TreasureChestId, filter.TreasureChestId, filter.Error, filter.Error, filter.Error)
user.Id,
filter.AccountId, filter.AccountId,
filter.TreasureChestId, filter.TreasureChestId,
filter.Error, filter.Error, filter.Error)
err = db.TransformAndLogDbError("transaction GetAll", nil, err)
if err != nil {
return nil, err
@@ -302,7 +308,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil && err != db.ErrNotFound {
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
}
@@ -314,7 +320,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
WHERE id = ?
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
err = db.TransformAndLogDbError("transaction Delete", r, err)
if err != nil && err != db.ErrNotFound {
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
}
@@ -354,7 +360,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
SET current_balance = 0
WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil && err != db.ErrNotFound {
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
@@ -363,7 +369,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
SET current_balance = 0
WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
if err != nil && err != db.ErrNotFound {
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
@@ -372,7 +378,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
FROM "transaction"
WHERE user_id = ?`, user.Id)
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil && err != db.ErrNotFound {
if err != nil && !errors.Is(err, db.ErrNotFound) {
return err
}
defer func() {
@@ -382,15 +388,15 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
}
}()
transaction := &types.Transaction{}
var transaction types.Transaction
for rows.Next() {
err = rows.StructScan(transaction)
err = rows.StructScan(&transaction)
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
if err != nil {
return err
}
s.updateErrors(transaction)
s.updateErrors(&transaction)
r, err = tx.Exec(`
UPDATE "transaction"
SET error = ?
@@ -424,7 +430,6 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
if err != nil {
return err
}
}
}
@@ -438,7 +443,6 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
}
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.TransactionInput) (*types.Transaction, error) {
var (
id uuid.UUID
accountUuid *uuid.UUID
@@ -484,7 +488,6 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
log.Error("transaction validate: %v", err)
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
}
}
if input.TreasureChestId != "" {
@@ -498,7 +501,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError("transaction validate", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
}
return nil, err
@@ -513,7 +516,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
log.Error("transaction validate: %v", err)
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
}
valueInt := int64(valueFloat * 100)
valueInt := int64(valueFloat * DECIMALS_MULTIPLIER)
timestamp, err := time.Parse("2006-01-02", input.Timestamp)
if err != nil {
@@ -544,6 +547,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
Timestamp: timestamp,
Party: input.Party,
Description: input.Description,
Error: nil,
CreatedAt: createdAt,
CreatedBy: createdBy,
@@ -557,25 +561,26 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
}
func (s TransactionImpl) updateErrors(transaction *types.Transaction) {
error := ""
errorStr := ""
if transaction.Value < 0 {
switch {
case transaction.Value < 0:
if transaction.TreasureChestId == nil {
error = "no treasure chest specified"
errorStr = "no treasure chest specified"
}
} else if transaction.Value > 0 {
case transaction.Value > 0:
if transaction.AccountId == nil && transaction.TreasureChestId == nil {
error = "either an account or a treasure chest needs to be specified"
errorStr = "either an account or a treasure chest needs to be specified"
} else if transaction.AccountId != nil && transaction.TreasureChestId != nil {
error = "positive amounts can only be applied to either an account or a treasure chest"
errorStr = "positive amounts can only be applied to either an account or a treasure chest"
}
} else {
error = "\"value\" needs to be specified"
default:
errorStr = "\"value\" needs to be specified"
}
if error == "" {
if errorStr == "" {
transaction.Error = nil
} else {
transaction.Error = &error
transaction.Error = &errorStr
}
}

View File

@@ -1,6 +1,7 @@
package service
import (
"errors"
"fmt"
"strconv"
"time"
@@ -18,8 +19,8 @@ import (
var (
transactionRecurringMetric = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "spendsparrow_transactionRecurring_recurring_total",
Help: "The total of transactionRecurring recurring operations",
Name: "spendsparrow_transaction_recurring_total",
Help: "The total of transactionRecurring operations",
},
[]string{"operation"},
)
@@ -28,7 +29,6 @@ var (
type TransactionRecurring interface {
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
Get(user *types.User, id string) (*types.TransactionRecurring, error)
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error)
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
Delete(user *types.User, id string) error
@@ -50,7 +50,9 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, settings *
}
}
func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
func (s TransactionRecurringImpl) Add(
user *types.User,
transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("add").Inc()
if user == nil {
@@ -72,8 +74,11 @@ func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInpu
}
r, err := tx.NamedExec(`
INSERT INTO "transaction_recurring" (id, user_id, interval_months, active, party, description, account_id, treasure_chest_id, value, created_at, created_by)
VALUES (:id, :user_id, :interval_months, :active, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`, transactionRecurring)
INSERT INTO "transaction_recurring" (id, user_id, interval_months,
active, party, description, account_id, treasure_chest_id, value, created_at, created_by)
VALUES (:id, :user_id, :interval_months,
:active, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`,
transactionRecurring)
err = db.TransformAndLogDbError("transactionRecurring Insert", r, err)
if err != nil {
return nil, err
@@ -88,7 +93,9 @@ func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInpu
return transactionRecurring, nil
}
func (s TransactionRecurringImpl) Update(user *types.User, input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
func (s TransactionRecurringImpl) Update(
user *types.User,
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("update").Inc()
if user == nil {
return nil, ErrUnauthorized
@@ -112,7 +119,7 @@ func (s TransactionRecurringImpl) Update(user *types.User, input types.Transacti
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest)
}
return nil, types.ErrInternal
@@ -151,31 +158,6 @@ func (s TransactionRecurringImpl) Update(user *types.User, input types.Transacti
return transactionRecurring, nil
}
func (s TransactionRecurringImpl) Get(user *types.User, id string) (*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("get").Inc()
if user == nil {
return nil, ErrUnauthorized
}
uuid, err := uuid.Parse(id)
if err != nil {
log.Error("transactionRecurring get: %v", err)
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
var transactionRecurring types.TransactionRecurring
err = s.db.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("transactionRecurring Get", nil, err)
if err != nil {
if err == db.ErrNotFound {
return nil, fmt.Errorf("transactionRecurring %v not found: %w", id, ErrBadRequest)
}
return nil, types.ErrInternal
}
return &transactionRecurring, nil
}
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
transactionRecurringMetric.WithLabelValues("get_all_by_account").Inc()
if user == nil {
@@ -201,7 +183,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest)
}
return nil, types.ErrInternal
@@ -254,7 +236,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(user *types.User, treasu
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest)
}
return nil, types.ErrInternal
@@ -329,7 +311,6 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
oldTransactionRecurring *types.TransactionRecurring,
userId uuid.UUID,
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
var (
id uuid.UUID
accountUuid *uuid.UUID
@@ -393,7 +374,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
}
return nil, err
@@ -418,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
log.Error("transactionRecurring validate: %v", err)
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
}
valueInt := int64(valueFloat * 100)
valueInt := int64(valueFloat * DECIMALS_MULTIPLIER)
if input.Party != "" {
err = validateString(input.Party, "party")
@@ -444,12 +425,12 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
active := input.Active == "on"
transactionRecurring := types.TransactionRecurring{
Id: id,
UserId: userId,
IntervalMonths: intervalMonths,
Active: active,
LastExecution: nil,
Party: input.Party,
Description: input.Description,

View File

@@ -1,6 +1,7 @@
package service
import (
"errors"
"fmt"
"slices"
@@ -131,7 +132,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
}
return nil, types.ErrInternal
@@ -198,17 +199,17 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
}
treasureChest := &types.TreasureChest{}
err = s.db.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
var treasureChest types.TreasureChest
err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
err = db.TransformAndLogDbError("treasureChest Get", nil, err)
if err != nil {
if err == db.ErrNotFound {
if errors.Is(err, db.ErrNotFound) {
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
}
return nil, types.ErrInternal
}
return treasureChest, nil
return &treasureChest, nil
}
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) {
@@ -259,7 +260,9 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
}
transactionsCount := 0
err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`, user.Id, id)
err = tx.Get(&transactionsCount,
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
user.Id, id)
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
if err != nil {
return err
@@ -284,12 +287,11 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
}
func sortTree(nodes []*types.TreasureChest) []*types.TreasureChest {
var (
roots []*types.TreasureChest
result []*types.TreasureChest
roots []*types.TreasureChest
)
children := make(map[uuid.UUID][]*types.TreasureChest)
result := make([]*types.TreasureChest, 0)
for _, node := range nodes {
if node.ParentId == nil {

View File

@@ -111,6 +111,7 @@ templ AccountItem(account *types.Account) {
hx-target="closest #account"
hx-swap="outerHTML"
class="button button-neglect px-1 flex items-center gap-2"
hx-confirm="Are you sure you want to delete this account?"
>
@svg.Delete()
<span>

View File

@@ -28,7 +28,7 @@ templ Layout(slot templ.Component, user templ.Component, loggedIn bool, path str
<script src="/static/js/toast.js"></script>
<script src="/static/js/time.js"></script>
</head>
<body class="h-screen flex flex-col" hx-headers='{"csrf-token": "CSRF_TOKEN"}'>
<body class="h-screen flex flex-col" hx-headers='{"Csrf-Token": "CSRF_TOKEN"}'>
// Header
<nav class="flex bg-white items-center gap-2 py-1 px-2 h-12 md:gap-10 md:px-10 md:py-2">
<a href="/" class="flex gap-2 mr-20">

View File

@@ -6,13 +6,13 @@ import (
"github.com/google/uuid"
)
// The Account holds money
// The Account holds money.
type Account struct {
Id uuid.UUID
Id uuid.UUID `db:"id"`
UserId uuid.UUID `db:"user_id"`
// Custom Name of the account, e.g. "Bank", "Cash", "Credit Card"
Name string
Name string `db:"name"`
CurrentBalance int64 `db:"current_balance"`
LastTransaction *time.Time `db:"last_transaction"`

View File

@@ -17,7 +17,15 @@ type User struct {
CreateAt time.Time
}
func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User {
func NewUser(
id uuid.UUID,
email string,
emailVerified bool,
emailVerifiedAt *time.Time,
isAdmin bool,
password []byte,
salt []byte,
createAt time.Time) *User {
return &User{
Id: id,
Email: email,
@@ -31,10 +39,10 @@ func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *ti
}
type Session struct {
Id string
UserId uuid.UUID
CreatedAt time.Time
ExpiresAt time.Time
Id string `db:"session_id"`
UserId uuid.UUID `db:"user_id"`
CreatedAt time.Time `db:"created_at"`
ExpiresAt time.Time `db:"expires_at"`
}
func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session {
@@ -63,7 +71,13 @@ var (
TokenTypeCsrf TokenType = "csrf"
)
func NewToken(userId uuid.UUID, sessionId string, token string, tokenType TokenType, createdAt time.Time, expiresAt time.Time) *Token {
func NewToken(
userId uuid.UUID,
sessionId string,
token string,
tokenType TokenType,
createdAt time.Time,
expiresAt time.Time) *Token {
return &Token{
UserId: userId,
SessionId: sessionId,

View File

@@ -1,26 +0,0 @@
package types
import (
"time"
"github.com/google/uuid"
)
// The SavingsPlan is applied every interval to the TreasureChest/Account as a transaction
type SavingsPlan struct {
Id uuid.UUID
UserId uuid.UUID `db:"user_id"`
TreasureChestId uuid.UUID `db:"treasure_chest_id"`
MonthlySaving int64 `db:"monthly_saving"`
ValidFrom time.Time `db:"valid_from"`
/// nil means it is valid indefinitely
ValidTo *time.Time `db:"valid_to"`
CreatedAt time.Time `db:"created_at"`
CreatedBy uuid.UUID `db:"created_by"`
UpdatedAt *time.Time `db:"updated_at"`
UpdatedBy *uuid.UUID `db:"updated_by"`
}

View File

@@ -23,36 +23,9 @@ type SmtpSettings struct {
}
func NewSettingsFromEnv(env func(string) string) *Settings {
var smtp *SmtpSettings
if env("SMTP_ENABLED") == "true" {
smtp = &SmtpSettings{
Host: env("SMTP_HOST"),
Port: env("SMTP_PORT"),
User: env("SMTP_USER"),
Pass: env("SMTP_PASS"),
FromMail: env("SMTP_FROM_MAIL"),
FromName: env("SMTP_FROM_NAME"),
}
if smtp.Host == "" {
log.Fatal("SMTP_HOST must be set")
}
if smtp.Port == "" {
log.Fatal("SMTP_PORT must be set")
}
if smtp.User == "" {
log.Fatal("SMTP_USER must be set")
}
if smtp.Pass == "" {
log.Fatal("SMTP_PASS must be set")
}
if smtp.FromMail == "" {
log.Fatal("SMTP_FROM_MAIL must be set")
}
if smtp.FromName == "" {
log.Fatal("SMTP_FROM_NAME must be set")
}
smtp = getSmtpSettings(env)
}
settings := &Settings{
@@ -78,3 +51,35 @@ func NewSettingsFromEnv(env func(string) string) *Settings {
return settings
}
func getSmtpSettings(env func(string) string) *SmtpSettings {
smtp := SmtpSettings{
Host: env("SMTP_HOST"),
Port: env("SMTP_PORT"),
User: env("SMTP_USER"),
Pass: env("SMTP_PASS"),
FromMail: env("SMTP_FROM_MAIL"),
FromName: env("SMTP_FROM_NAME"),
}
if smtp.Host == "" {
log.Fatal("SMTP_HOST must be set")
}
if smtp.Port == "" {
log.Fatal("SMTP_PORT must be set")
}
if smtp.User == "" {
log.Fatal("SMTP_USER must be set")
}
if smtp.Pass == "" {
log.Fatal("SMTP_PASS must be set")
}
if smtp.FromMail == "" {
log.Fatal("SMTP_FROM_MAIL must be set")
}
if smtp.FromName == "" {
log.Fatal("SMTP_FROM_NAME must be set")
}
return &smtp
}

View File

@@ -14,19 +14,19 @@ import (
// If it becomes necessary to precalculate snapshots for performance reasons, this can be done in the future.
// But the transaction should always be the source of truth.
type Transaction struct {
Id uuid.UUID
Id uuid.UUID `db:"id"`
UserId uuid.UUID `db:"user_id"`
Timestamp time.Time
Party string
Description string
Timestamp time.Time `db:"timestamp"`
Party string `db:"party"`
Description string `db:"description"`
AccountId *uuid.UUID `db:"account_id"`
TreasureChestId *uuid.UUID `db:"treasure_chest_id"`
Value int64
Value int64 `db:"value"`
// If an error is present, then the transaction is not valid and should not be used for calculations.
Error *string
Error *string `db:"error"`
CreatedAt time.Time `db:"created_at"`
// Either a user_id or a transaction_recurring_id
CreatedBy uuid.UUID `db:"created_by"`

View File

@@ -7,19 +7,19 @@ import (
)
type TransactionRecurring struct {
Id uuid.UUID
Id uuid.UUID `db:"id"`
UserId uuid.UUID `db:"user_id"`
IntervalMonths int64 `db:"interval_months"`
LastExecution *time.Time `db:"last_execution"`
Active bool
Active bool `db:"active"`
Party string
Description string
Party string `db:"party"`
Description string `db:"description"`
AccountId *uuid.UUID `db:"account_id"`
TreasureChestId *uuid.UUID `db:"treasure_chest_id"`
Value int64
Value int64 `db:"value"`
CreatedAt time.Time `db:"created_at"`
CreatedBy uuid.UUID `db:"created_by"`

View File

@@ -10,13 +10,13 @@ import (
// The money it "holds" distributed across all accounts
//
// At the time of writing this, linking it to a specific account doesn't really make sense
// Imagne a TreasureChest for free time activities, where some money is spend in cash and some other with credit card
// Imagine a TreasureChest for free time activities, where some money is spend in cash and some other with credit card.
type TreasureChest struct {
Id uuid.UUID
Id uuid.UUID `db:"id"`
ParentId *uuid.UUID `db:"parent_id"`
UserId uuid.UUID `db:"user_id"`
Name string
Name string `db:"name"`
CurrentBalance int64 `db:"current_balance"`

View File

@@ -11,7 +11,7 @@ import (
func TriggerToast(w http.ResponseWriter, r *http.Request, class string, message string) {
if IsHtmx(r) {
w.Header().Set("HX-Trigger", fmt.Sprintf(`{"toast": "%v|%v"}`, class, strings.ReplaceAll(message, `"`, `\"`)))
w.Header().Set("Hx-Trigger", fmt.Sprintf(`{"toast": "%v|%v"}`, class, strings.ReplaceAll(message, `"`, `\"`)))
} else {
log.Error("Trying to trigger toast in non-HTMX request")
}
@@ -24,19 +24,19 @@ func TriggerToastWithStatus(w http.ResponseWriter, r *http.Request, class string
func DoRedirect(w http.ResponseWriter, r *http.Request, url string) {
if IsHtmx(r) {
w.Header().Add("HX-Redirect", url)
w.Header().Add("Hx-Redirect", url)
} else {
http.Redirect(w, r, url, http.StatusSeeOther)
}
}
func WaitMinimumTime[T interface{}](waitTime time.Duration, function func() (T, error)) (T, error) {
func WaitMinimumTime[T interface{}](waitTime time.Duration, f func() (T, error)) (T, error) {
start := time.Now()
result, err := function()
result, err := f()
time.Sleep(waitTime - time.Since(start))
return result, err
}
func IsHtmx(r *http.Request) bool {
return r.Header.Get("HX-Request") == "true"
return r.Header.Get("Hx-Request") == "true"
}