fix: lint errors
This commit was merged in pull request #130.
This commit is contained in:
28
.golangci.yaml
Normal file
28
.golangci.yaml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
version: '2'
|
||||||
|
linters:
|
||||||
|
default: all
|
||||||
|
disable:
|
||||||
|
- wsl
|
||||||
|
- wrapcheck
|
||||||
|
- varnamelen
|
||||||
|
- revive # should probably be enabled
|
||||||
|
- nlreturn
|
||||||
|
- mnd # should probably be enabled
|
||||||
|
- lll # should probably be enabled
|
||||||
|
- ireturn # should probably be enabled
|
||||||
|
- interfacebloat
|
||||||
|
- iface
|
||||||
|
- goconst # should probably be enabled
|
||||||
|
- gocognit # should probably be enabled
|
||||||
|
- gochecknoglobals # should probably be enabled
|
||||||
|
- funlen
|
||||||
|
- maintidx
|
||||||
|
- exhaustruct
|
||||||
|
- dupword # should probably be enabled
|
||||||
|
- dupl # should probably be enabled
|
||||||
|
- depguard
|
||||||
|
- cyclop
|
||||||
|
- contextcheck
|
||||||
|
settings:
|
||||||
|
nestif:
|
||||||
|
min-complexity: 6
|
||||||
39
db/auth.go
39
db/auth.go
@@ -1,6 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"spend-sparrow/log"
|
"spend-sparrow/log"
|
||||||
"spend-sparrow/types"
|
"spend-sparrow/types"
|
||||||
|
|
||||||
@@ -89,7 +90,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
|
|||||||
FROM user
|
FROM user
|
||||||
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
} else {
|
} else {
|
||||||
log.Error("SQL error GetUser: %v", err)
|
log.Error("SQL error GetUser: %v", err)
|
||||||
@@ -116,7 +117,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
|
|||||||
FROM user
|
FROM user
|
||||||
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
} else {
|
} else {
|
||||||
log.Error("SQL error GetUser %v", err)
|
log.Error("SQL error GetUser %v", err)
|
||||||
@@ -128,7 +129,6 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
|
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
|
||||||
|
|
||||||
tx, err := db.db.Begin()
|
tx, err := db.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Could not start transaction: %v", err)
|
log.Error("Could not start transaction: %v", err)
|
||||||
@@ -216,7 +216,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
|
|||||||
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
|
WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
log.Info("Token '%v' not found", token)
|
log.Info("Token '%v' not found", token)
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
} else {
|
} else {
|
||||||
@@ -241,7 +241,6 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
|
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
|
||||||
|
|
||||||
query, err := db.db.Query(`
|
query, err := db.db.Query(`
|
||||||
SELECT token, created_at, expires_at
|
SELECT token, created_at, expires_at
|
||||||
FROM token
|
FROM token
|
||||||
@@ -257,7 +256,6 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
|
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
|
||||||
|
|
||||||
query, err := db.db.Query(`
|
query, err := db.db.Query(`
|
||||||
SELECT token, created_at, expires_at
|
SELECT token, created_at, expires_at
|
||||||
FROM token
|
FROM token
|
||||||
@@ -325,7 +323,6 @@ func (db AuthSqlite) DeleteToken(token string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) InsertSession(session *types.Session) error {
|
func (db AuthSqlite) InsertSession(session *types.Session) error {
|
||||||
|
|
||||||
_, err := db.db.Exec(`
|
_, err := db.db.Exec(`
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
|
VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt)
|
||||||
@@ -339,7 +336,6 @@ func (db AuthSqlite) InsertSession(session *types.Session) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
userId uuid.UUID
|
userId uuid.UUID
|
||||||
createdAt time.Time
|
createdAt time.Time
|
||||||
@@ -360,9 +356,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
|
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
|
||||||
|
var sessions []*types.Session
|
||||||
sessions, err := db.db.Query(`
|
err := db.db.Select(&sessions, `
|
||||||
SELECT session_id, created_at, expires_at
|
SELECT *
|
||||||
FROM session
|
FROM session
|
||||||
WHERE user_id = ?`, userId)
|
WHERE user_id = ?`, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -370,26 +366,7 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
|
|||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
var result []*types.Session
|
return sessions, nil
|
||||||
|
|
||||||
for sessions.Next() {
|
|
||||||
var (
|
|
||||||
sessionId string
|
|
||||||
createdAt time.Time
|
|
||||||
expiresAt time.Time
|
|
||||||
)
|
|
||||||
|
|
||||||
err := sessions.Scan(&sessionId, &createdAt, &expiresAt)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Could not scan session: %v", err)
|
|
||||||
return nil, types.ErrInternal
|
|
||||||
}
|
|
||||||
|
|
||||||
session := types.NewSession(sessionId, userId, createdAt, expiresAt)
|
|
||||||
result = append(result, session)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
|
func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package db
|
package db_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"spend-sparrow/db"
|
||||||
"spend-sparrow/types"
|
"spend-sparrow/types"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -8,26 +9,29 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupDb(t *testing.T) *sqlx.DB {
|
func setupDb(t *testing.T) *sqlx.DB {
|
||||||
db, err := sqlx.Open("sqlite3", ":memory:")
|
t.Helper()
|
||||||
|
|
||||||
|
d, err := sqlx.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error opening database: %v", err)
|
t.Fatalf("Error opening database: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := db.Close()
|
err := d.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
err = RunMigrations(db, "../")
|
err = db.RunMigrations(d, "../")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error running migrations: %v", err)
|
t.Fatalf("Error running migrations: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return db
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUser(t *testing.T) {
|
func TestUser(t *testing.T) {
|
||||||
@@ -35,55 +39,55 @@ func TestUser(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("should insert and get the same", func(t *testing.T) {
|
t.Run("should insert and get the same", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(expected)
|
err := underTest.InsertUser(expected)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := underTest.GetUser(expected.Id)
|
actual, err := underTest.GetUser(expected.Id)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
|
|
||||||
actual, err = underTest.GetUserByEmail(expected.Email)
|
actual, err = underTest.GetUserByEmail(expected.Email)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrNotFound", func(t *testing.T) {
|
t.Run("should return ErrNotFound", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
_, err := underTest.GetUserByEmail("nonExistentEmail")
|
_, err := underTest.GetUserByEmail("nonExistentEmail")
|
||||||
assert.Equal(t, ErrNotFound, err)
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrUserExist", func(t *testing.T) {
|
t.Run("should return ErrUserExist", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
err := underTest.InsertUser(user)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = underTest.InsertUser(user)
|
err = underTest.InsertUser(user)
|
||||||
assert.Equal(t, ErrAlreadyExists, err)
|
assert.Equal(t, db.ErrAlreadyExists, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
||||||
@@ -98,37 +102,37 @@ func TestToken(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("should insert and get the same", func(t *testing.T) {
|
t.Run("should insert and get the same", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
|
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
|
||||||
err := underTest.InsertToken(expected)
|
err := underTest.InsertToken(expected)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual, err := underTest.GetToken(expected.Token)
|
actual, err := underTest.GetToken(expected.Token)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
|
|
||||||
expected.SessionId = ""
|
expected.SessionId = ""
|
||||||
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type)
|
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected}, actuals)
|
assert.Equal(t, []*types.Token{expected}, actuals)
|
||||||
|
|
||||||
expected.SessionId = "sessionId"
|
expected.SessionId = "sessionId"
|
||||||
expected.UserId = uuid.Nil
|
expected.UserId = uuid.Nil
|
||||||
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type)
|
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected}, actuals)
|
assert.Equal(t, []*types.Token{expected}, actuals)
|
||||||
})
|
})
|
||||||
t.Run("should insert and return multiple tokens", func(t *testing.T) {
|
t.Run("should insert and return multiple tokens", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
@@ -137,14 +141,14 @@ func TestToken(t *testing.T) {
|
|||||||
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
|
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
|
||||||
err := underTest.InsertToken(expected1)
|
err := underTest.InsertToken(expected1)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
err = underTest.InsertToken(expected2)
|
err = underTest.InsertToken(expected2)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected1.UserId = uuid.Nil
|
expected1.UserId = uuid.Nil
|
||||||
expected2.UserId = uuid.Nil
|
expected2.UserId = uuid.Nil
|
||||||
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type)
|
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
||||||
|
|
||||||
expected1.SessionId = ""
|
expected1.SessionId = ""
|
||||||
@@ -152,46 +156,45 @@ func TestToken(t *testing.T) {
|
|||||||
expected1.UserId = userId
|
expected1.UserId = userId
|
||||||
expected2.UserId = userId
|
expected2.UserId = userId
|
||||||
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type)
|
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
||||||
|
|
||||||
})
|
})
|
||||||
t.Run("should return ErrNotFound", func(t *testing.T) {
|
t.Run("should return ErrNotFound", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
_, err := underTest.GetToken("nonExistent")
|
_, err := underTest.GetToken("nonExistent")
|
||||||
assert.Equal(t, ErrNotFound, err)
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
|
|
||||||
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify)
|
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify)
|
||||||
assert.Equal(t, ErrNotFound, err)
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
|
|
||||||
_, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify)
|
_, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify)
|
||||||
assert.Equal(t, ErrNotFound, err)
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
|
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
err := underTest.InsertUser(user)
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = underTest.InsertUser(user)
|
err = underTest.InsertUser(user)
|
||||||
assert.Equal(t, ErrAlreadyExists, err)
|
assert.Equal(t, db.ErrAlreadyExists, err)
|
||||||
})
|
})
|
||||||
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
db := setupDb(t)
|
d := setupDb(t)
|
||||||
|
|
||||||
underTest := AuthSqlite{db: db}
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ var (
|
|||||||
|
|
||||||
func TransformAndLogDbError(module string, r sql.Result, err error) error {
|
func TransformAndLogDbError(module string, r sql.Result, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return ErrNotFound
|
return ErrNotFound
|
||||||
}
|
}
|
||||||
log.Error("%v: %v", module, err)
|
log.Error("%v: %v", module, err)
|
||||||
|
|||||||
@@ -77,7 +77,6 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
|
|||||||
|
|
||||||
func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) {
|
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) {
|
||||||
session := middleware.GetSession(r)
|
session := middleware.GetSession(r)
|
||||||
email := r.FormValue("email")
|
email := r.FormValue("email")
|
||||||
@@ -95,7 +94,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrInvalidCredentials {
|
if errors.Is(err, service.ErrInvalidCredentials) {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Invalid email or password", http.StatusUnauthorized)
|
utils.TriggerToastWithStatus(w, r, "error", "Invalid email or password", http.StatusUnauthorized)
|
||||||
} else {
|
} else {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError)
|
utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError)
|
||||||
@@ -166,7 +165,6 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc {
|
|||||||
|
|
||||||
func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
token := r.URL.Query().Get("token")
|
token := r.URL.Query().Get("token")
|
||||||
|
|
||||||
err := handler.service.VerifyUserEmail(token)
|
err := handler.service.VerifyUserEmail(token)
|
||||||
@@ -203,13 +201,14 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, types.ErrInternal) {
|
switch {
|
||||||
|
case errors.Is(err, types.ErrInternal):
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError)
|
utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
} else if errors.Is(err, service.ErrInvalidEmail) {
|
case errors.Is(err, service.ErrInvalidEmail):
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "The email provided is invalid", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", "The email provided is invalid", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
} else if errors.Is(err, service.ErrInvalidPassword) {
|
case errors.Is(err, service.ErrInvalidPassword):
|
||||||
utils.TriggerToastWithStatus(w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -272,7 +271,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
|
|||||||
|
|
||||||
err := handler.service.DeleteAccount(user, password)
|
err := handler.service.DeleteAccount(user, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrInvalidCredentials {
|
if errors.Is(err, service.ErrInvalidCredentials) {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest)
|
||||||
} else {
|
} else {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
||||||
@@ -286,7 +285,6 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc {
|
|||||||
|
|
||||||
func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
|
func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
isPasswordReset := r.URL.Query().Has("token")
|
isPasswordReset := r.URL.Query().Has("token")
|
||||||
|
|
||||||
user := middleware.GetUser(r)
|
user := middleware.GetUser(r)
|
||||||
@@ -303,7 +301,6 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc {
|
|||||||
|
|
||||||
func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
|
func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
session := middleware.GetSession(r)
|
session := middleware.GetSession(r)
|
||||||
user := middleware.GetUser(r)
|
user := middleware.GetUser(r)
|
||||||
if session == nil || user == nil {
|
if session == nil || user == nil {
|
||||||
@@ -326,7 +323,6 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
|
|||||||
|
|
||||||
func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
|
func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
user := middleware.GetUser(r)
|
user := middleware.GetUser(r)
|
||||||
if user != nil {
|
if user != nil {
|
||||||
utils.DoRedirect(w, r, "/")
|
utils.DoRedirect(w, r, "/")
|
||||||
@@ -340,7 +336,6 @@ func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc {
|
|||||||
|
|
||||||
func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
|
func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
email := r.FormValue("email")
|
email := r.FormValue("email")
|
||||||
if email == "" {
|
if email == "" {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Please enter an email", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", "Please enter an email", http.StatusBadRequest)
|
||||||
@@ -362,7 +357,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc {
|
|||||||
|
|
||||||
func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
|
func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
pageUrl, err := url.Parse(r.Header.Get("HX-Current-URL"))
|
pageUrl, err := url.Parse(r.Header.Get("Hx-Current-Url"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Could not get current URL: %v", err)
|
log.Error("Could not get current URL: %v", err)
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
|||||||
@@ -10,13 +10,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func handleError(w http.ResponseWriter, r *http.Request, err error) {
|
func handleError(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
if errors.Is(err, service.ErrUnauthorized) {
|
switch {
|
||||||
|
case errors.Is(err, service.ErrUnauthorized):
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized)
|
utils.TriggerToastWithStatus(w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
} else if errors.Is(err, service.ErrBadRequest) {
|
case errors.Is(err, service.ErrBadRequest):
|
||||||
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
} else if errors.Is(err, db.ErrNotFound) {
|
case errors.Is(err, db.ErrNotFound):
|
||||||
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound)
|
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ var UserKey ContextKey = "user"
|
|||||||
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
sessionId := getSessionID(r)
|
sessionId := getSessionID(r)
|
||||||
session, user, _ := service.SignInSession(sessionId)
|
session, user, _ := service.SignInSession(sessionId)
|
||||||
|
|
||||||
@@ -49,7 +48,12 @@ func GetUser(r *http.Request) *types.User {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return obj.(*types.User)
|
user, ok := obj.(*types.User)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSession(r *http.Request) *types.Session {
|
func GetSession(r *http.Request) *types.Session {
|
||||||
@@ -58,7 +62,12 @@ func GetSession(r *http.Request) *types.Session {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return obj.(*types.Session)
|
session, ok := obj.(*types.Session)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSessionID(r *http.Request) string {
|
func getSessionID(r *http.Request) string {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
|
|
||||||
func CacheControl(next http.Handler) http.Handler {
|
func CacheControl(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
shouldCache := strings.HasPrefix(r.URL.Path, "/static")
|
shouldCache := strings.HasPrefix(r.URL.Path, "/static")
|
||||||
|
|
||||||
if !shouldCache {
|
if !shouldCache {
|
||||||
|
|||||||
@@ -37,19 +37,17 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
|
|||||||
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
|
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
session := GetSession(r)
|
session := GetSession(r)
|
||||||
|
|
||||||
if r.Method == http.MethodPost ||
|
if r.Method == http.MethodPost ||
|
||||||
r.Method == http.MethodPut ||
|
r.Method == http.MethodPut ||
|
||||||
r.Method == http.MethodDelete ||
|
r.Method == http.MethodDelete ||
|
||||||
r.Method == http.MethodPatch {
|
r.Method == http.MethodPatch {
|
||||||
|
csrfToken := r.Header.Get("Csrf-Token")
|
||||||
csrfToken := r.Header.Get("csrf-token")
|
|
||||||
|
|
||||||
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
|
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
|
||||||
log.Info("CSRF-Token \"%s\" not correct", csrfToken)
|
log.Info("CSRF-Token \"%s\" not correct", csrfToken)
|
||||||
if r.Header.Get("HX-Request") == "true" {
|
if r.Header.Get("Hx-Request") == "true" {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
|
||||||
} else {
|
} else {
|
||||||
http.Error(w, "CSRF-Token not correct", http.StatusBadRequest)
|
http.Error(w, "CSRF-Token not correct", http.StatusBadRequest)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -32,8 +33,7 @@ func Gzip(next http.Handler) http.Handler {
|
|||||||
next.ServeHTTP(wrapper, r)
|
next.ServeHTTP(wrapper, r)
|
||||||
|
|
||||||
err := gz.Close()
|
err := gz.Close()
|
||||||
if err != nil && err != http.ErrBodyNotAllowed {
|
if err != nil && !errors.Is(err, http.ErrBodyNotAllowed) {
|
||||||
// if err != nil {
|
|
||||||
log.Error("Gzip: could not close Writer: %v", err)
|
log.Error("Gzip: could not close Writer: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Handler {
|
func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Handler {
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
@@ -30,7 +29,7 @@ func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Han
|
|||||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||||
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
|
w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload")
|
||||||
|
|
||||||
if r.Method == "OPTIONS" {
|
if r.Method == http.MethodOptions {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ package middleware
|
|||||||
|
|
||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
||||||
// Chain list of handlers together
|
// Chain list of handlers together.
|
||||||
func Wrapper(next http.Handler, handlers ...func(http.Handler) http.Handler) http.Handler {
|
func Wrapper(next http.Handler, handlers ...func(http.Handler) http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
lastHandler := next
|
lastHandler := next
|
||||||
for i := 0; i < len(handlers); i++ {
|
for _, handler := range handlers {
|
||||||
lastHandler = handlers[i](lastHandler)
|
lastHandler = handler(lastHandler)
|
||||||
}
|
}
|
||||||
lastHandler.ServeHTTP(w, r)
|
lastHandler.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (render *Render) getUserComp(user *types.User) templ.Component {
|
func (render *Render) getUserComp(user *types.User) templ.Component {
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
return auth.UserComp(user.Email)
|
return auth.UserComp(user.Email)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -106,7 +106,6 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) {
|
func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) {
|
||||||
|
|
||||||
var transactionsRecurring []*types.TransactionRecurring
|
var transactionsRecurring []*types.TransactionRecurring
|
||||||
var err error
|
var err error
|
||||||
if accountId == "" && treasureChestId == "" {
|
if accountId == "" && treasureChestId == "" {
|
||||||
|
|||||||
18
main.go
18
main.go
@@ -1,6 +1,8 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"spend-sparrow/db"
|
"spend-sparrow/db"
|
||||||
"spend-sparrow/handler"
|
"spend-sparrow/handler"
|
||||||
"spend-sparrow/handler/middleware"
|
"spend-sparrow/handler/middleware"
|
||||||
@@ -37,10 +39,14 @@ func main() {
|
|||||||
log.Fatal("Could not close Database data.db: %v", err)
|
log.Fatal("Could not close Database data.db: %v", err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
run(context.Background(), db, os.Getenv)
|
err = run(context.Background(), db, os.Getenv)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error running server: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
|
func run(ctx context.Context, database *sqlx.DB, env func(string) string) error {
|
||||||
ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
|
ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -52,7 +58,7 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
|
|||||||
// init db
|
// init db
|
||||||
err := db.RunMigrations(database, "")
|
err := db.RunMigrations(database, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Could not run migrations: %v", err)
|
return fmt.Errorf("could not run migrations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// init servers
|
// init servers
|
||||||
@@ -61,6 +67,7 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
|
|||||||
prometheusServer := &http.Server{
|
prometheusServer := &http.Server{
|
||||||
Addr: ":8081",
|
Addr: ":8081",
|
||||||
Handler: promhttp.Handler(),
|
Handler: promhttp.Handler(),
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
go startServer(prometheusServer)
|
go startServer(prometheusServer)
|
||||||
}
|
}
|
||||||
@@ -68,6 +75,7 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
|
|||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: ":" + serverSettings.Port,
|
Addr: ":" + serverSettings.Port,
|
||||||
Handler: createHandler(database, serverSettings),
|
Handler: createHandler(database, serverSettings),
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
go startServer(httpServer)
|
go startServer(httpServer)
|
||||||
|
|
||||||
@@ -77,11 +85,13 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) {
|
|||||||
go shutdownServer(httpServer, ctx, &wg)
|
go shutdownServer(httpServer, ctx, &wg)
|
||||||
go shutdownServer(prometheusServer, ctx, &wg)
|
go shutdownServer(prometheusServer, ctx, &wg)
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startServer(s *http.Server) {
|
func startServer(s *http.Server) {
|
||||||
log.Info("Starting server on %q", s.Addr)
|
log.Info("Starting server on %q", s.Addr)
|
||||||
if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Error("error listening and serving: %v", err)
|
log.Error("error listening and serving: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
1140
main_test.go
1140
main_test.go
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"spend-sparrow/db"
|
"spend-sparrow/db"
|
||||||
@@ -119,7 +120,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type
|
|||||||
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("account Update", nil, err)
|
err = db.TransformAndLogDbError("account Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
|
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
@@ -164,8 +165,8 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
|||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
account := &types.Account{}
|
var account types.Account
|
||||||
err = s.db.Get(account, `
|
err = s.db.Get(&account, `
|
||||||
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("account Get", nil, err)
|
err = db.TransformAndLogDbError("account Get", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -173,7 +174,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
||||||
|
|||||||
134
service/auth.go
134
service/auth.go
@@ -94,30 +94,6 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st
|
|||||||
return session, user, nil
|
return session, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
|
|
||||||
if session == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := service.db.DeleteSession(session.Id)
|
|
||||||
if err != nil {
|
|
||||||
return types.ErrInternal
|
|
||||||
}
|
|
||||||
|
|
||||||
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
|
|
||||||
if err != nil {
|
|
||||||
return types.ErrInternal
|
|
||||||
}
|
|
||||||
for _, token := range tokens {
|
|
||||||
err = service.db.DeleteToken(token.Token)
|
|
||||||
if err != nil {
|
|
||||||
return types.ErrInternal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
|
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
|
||||||
if sessionId == "" {
|
if sessionId == "" {
|
||||||
return nil, nil, ErrSessionIdInvalid
|
return nil, nil, ErrSessionIdInvalid
|
||||||
@@ -155,30 +131,6 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
|
|||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
|
|
||||||
sessionId, err := service.random.String(32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, types.ErrInternal
|
|
||||||
}
|
|
||||||
|
|
||||||
err = service.db.DeleteOldSessions(userId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, types.ErrInternal
|
|
||||||
}
|
|
||||||
|
|
||||||
createAt := service.clock.Now()
|
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
|
||||||
|
|
||||||
session := types.NewSession(sessionId, userId, createAt, expiresAt)
|
|
||||||
|
|
||||||
err = service.db.InsertSession(session)
|
|
||||||
if err != nil {
|
|
||||||
return nil, types.ErrInternal
|
|
||||||
}
|
|
||||||
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
|
func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
|
||||||
_, err := mail.ParseAddress(email)
|
_, err := mail.ParseAddress(email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -205,7 +157,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
|
|||||||
|
|
||||||
err = service.db.InsertUser(user)
|
err = service.db.InsertUser(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrAlreadyExists {
|
if errors.Is(err, db.ErrAlreadyExists) {
|
||||||
return nil, ErrAccountExists
|
return nil, ErrAccountExists
|
||||||
} else {
|
} else {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
@@ -216,9 +168,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
||||||
|
|
||||||
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
|
tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
|
||||||
if err != nil && err != db.ErrNotFound {
|
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,7 +185,13 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token = types.NewToken(userId, "", newTokenStr, types.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour))
|
token = types.NewToken(
|
||||||
|
userId,
|
||||||
|
"",
|
||||||
|
newTokenStr,
|
||||||
|
types.TokenTypeEmailVerify,
|
||||||
|
service.clock.Now(),
|
||||||
|
service.clock.Now().Add(24*time.Hour))
|
||||||
|
|
||||||
err = service.db.InsertToken(token)
|
err = service.db.InsertToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -253,7 +210,6 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
|
func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
|
||||||
|
|
||||||
if tokenStr == "" {
|
if tokenStr == "" {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -291,12 +247,10 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignOut(sessionId string) error {
|
func (service AuthImpl) SignOut(sessionId string) error {
|
||||||
|
|
||||||
return service.db.DeleteSession(sessionId)
|
return service.db.DeleteSession(sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
||||||
|
|
||||||
userDb, err := service.db.GetUser(user.Id)
|
userDb, err := service.db.GetUser(user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
@@ -318,7 +272,6 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
|
func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
|
||||||
|
|
||||||
if !isPasswordValid(newPass) {
|
if !isPasswordValid(newPass) {
|
||||||
return ErrInvalidPassword
|
return ErrInvalidPassword
|
||||||
}
|
}
|
||||||
@@ -365,14 +318,20 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
|
|||||||
|
|
||||||
user, err := service.db.GetUserByEmail(email)
|
user, err := service.db.GetUserByEmail(email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
token := types.NewToken(user.Id, "", tokenStr, types.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute))
|
token := types.NewToken(
|
||||||
|
user.Id,
|
||||||
|
"",
|
||||||
|
tokenStr,
|
||||||
|
types.TokenTypePasswordReset,
|
||||||
|
service.clock.Now(),
|
||||||
|
service.clock.Now().Add(15*time.Minute))
|
||||||
|
|
||||||
err = service.db.InsertToken(token)
|
err = service.db.InsertToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -391,7 +350,6 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
|
||||||
|
|
||||||
if !isPasswordValid(newPass) {
|
if !isPasswordValid(newPass) {
|
||||||
return ErrInvalidPassword
|
return ErrInvalidPassword
|
||||||
}
|
}
|
||||||
@@ -449,7 +407,6 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
|
|||||||
if token.Type != types.TokenTypeCsrf ||
|
if token.Type != types.TokenTypeCsrf ||
|
||||||
token.SessionId != sessionId ||
|
token.SessionId != sessionId ||
|
||||||
token.ExpiresAt.Before(service.clock.Now()) {
|
token.ExpiresAt.Before(service.clock.Now()) {
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,7 +429,13 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
|
|||||||
return "", types.ErrInternal
|
return "", types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
token := types.NewToken(session.UserId, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(8*time.Hour))
|
token := types.NewToken(
|
||||||
|
session.UserId,
|
||||||
|
session.Id,
|
||||||
|
tokenStr,
|
||||||
|
types.TokenTypeCsrf,
|
||||||
|
service.clock.Now(),
|
||||||
|
service.clock.Now().Add(8*time.Hour))
|
||||||
err = service.db.InsertToken(token)
|
err = service.db.InsertToken(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", types.ErrInternal
|
return "", types.ErrInternal
|
||||||
@@ -483,12 +446,59 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
|
|||||||
return tokenStr, nil
|
return tokenStr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
|
||||||
|
if session == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.db.DeleteSession(session.Id)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
for _, token := range tokens {
|
||||||
|
err = service.db.DeleteToken(token.Token)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
|
||||||
|
sessionId, err := service.random.String(32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrInternal
|
||||||
|
}
|
||||||
|
|
||||||
|
err = service.db.DeleteOldSessions(userId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrInternal
|
||||||
|
}
|
||||||
|
|
||||||
|
createAt := service.clock.Now()
|
||||||
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
session := types.NewSession(sessionId, userId, createAt, expiresAt)
|
||||||
|
|
||||||
|
err = service.db.InsertSession(session)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrInternal
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
func GetHashPassword(password string, salt []byte) []byte {
|
func GetHashPassword(password string, salt []byte) []byte {
|
||||||
return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16)
|
return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPasswordValid(password string) bool {
|
func isPasswordValid(password string) bool {
|
||||||
|
|
||||||
if len(password) < 8 ||
|
if len(password) < 8 ||
|
||||||
!strings.ContainsAny(password, "0123456789") ||
|
!strings.ContainsAny(password, "0123456789") ||
|
||||||
!strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") ||
|
!strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") ||
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
package service
|
package service_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"spend-sparrow/db"
|
"spend-sparrow/db"
|
||||||
"spend-sparrow/mocks"
|
"spend-sparrow/mocks"
|
||||||
|
"spend-sparrow/service"
|
||||||
"spend-sparrow/types"
|
"spend-sparrow/types"
|
||||||
|
|
||||||
"strings"
|
"strings"
|
||||||
@@ -12,6 +13,17 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
settings = types.Settings{
|
||||||
|
Port: "",
|
||||||
|
PrometheusEnabled: false,
|
||||||
|
BaseUrl: "",
|
||||||
|
Environment: "test",
|
||||||
|
Smtp: nil,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSignUp(t *testing.T) {
|
func TestSignUp(t *testing.T) {
|
||||||
@@ -24,11 +36,11 @@ func TestSignUp(t *testing.T) {
|
|||||||
mockClock := mocks.NewMockClock(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
mockMail := mocks.NewMockMail(t)
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
_, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!")
|
_, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!")
|
||||||
|
|
||||||
assert.Equal(t, ErrInvalidEmail, err)
|
assert.Equal(t, service.ErrInvalidEmail, err)
|
||||||
})
|
})
|
||||||
t.Run("should check for password complexity", func(t *testing.T) {
|
t.Run("should check for password complexity", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
@@ -38,7 +50,7 @@ func TestSignUp(t *testing.T) {
|
|||||||
mockClock := mocks.NewMockClock(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
mockMail := mocks.NewMockMail(t)
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
weakPasswords := []string{
|
weakPasswords := []string{
|
||||||
"123!ab", // too short
|
"123!ab", // too short
|
||||||
@@ -49,7 +61,7 @@ func TestSignUp(t *testing.T) {
|
|||||||
|
|
||||||
for _, password := range weakPasswords {
|
for _, password := range weakPasswords {
|
||||||
_, err := underTest.SignUp("some@valid.email", password)
|
_, err := underTest.SignUp("some@valid.email", password)
|
||||||
assert.Equal(t, ErrInvalidPassword, err)
|
assert.Equal(t, service.ErrInvalidPassword, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("should signup correctly", func(t *testing.T) {
|
t.Run("should signup correctly", func(t *testing.T) {
|
||||||
@@ -66,17 +78,17 @@ func TestSignUp(t *testing.T) {
|
|||||||
salt := []byte("salt")
|
salt := []byte("salt")
|
||||||
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
expected := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime)
|
expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
|
||||||
|
|
||||||
mockRandom.EXPECT().UUID().Return(userId, nil)
|
mockRandom.EXPECT().UUID().Return(userId, nil)
|
||||||
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
||||||
mockClock.EXPECT().Now().Return(createTime)
|
mockClock.EXPECT().Now().Return(createTime)
|
||||||
mockAuthDb.EXPECT().InsertUser(expected).Return(nil)
|
mockAuthDb.EXPECT().InsertUser(expected).Return(nil)
|
||||||
|
|
||||||
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
actual, err := underTest.SignUp(email, password)
|
actual, err := underTest.SignUp(email, password)
|
||||||
|
|
||||||
assert.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, expected, actual)
|
assert.Equal(t, expected, actual)
|
||||||
})
|
})
|
||||||
@@ -93,7 +105,7 @@ func TestSignUp(t *testing.T) {
|
|||||||
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
password := "SomeStrongPassword123!"
|
password := "SomeStrongPassword123!"
|
||||||
salt := []byte("salt")
|
salt := []byte("salt")
|
||||||
user := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime)
|
user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
|
||||||
|
|
||||||
mockRandom.EXPECT().UUID().Return(user.Id, nil)
|
mockRandom.EXPECT().UUID().Return(user.Id, nil)
|
||||||
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
||||||
@@ -101,20 +113,25 @@ func TestSignUp(t *testing.T) {
|
|||||||
|
|
||||||
mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists)
|
mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists)
|
||||||
|
|
||||||
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
_, err := underTest.SignUp(user.Email, password)
|
_, err := underTest.SignUp(user.Email, password)
|
||||||
assert.Equal(t, ErrAccountExists, err)
|
assert.Equal(t, service.ErrAccountExists, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendVerificationMail(t *testing.T) {
|
func TestSendVerificationMail(t *testing.T) {
|
||||||
|
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("should use stored token and send mail", func(t *testing.T) {
|
t.Run("should use stored token and send mail", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
token := types.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", types.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
|
token := types.NewToken(
|
||||||
|
uuid.New(),
|
||||||
|
"sessionId",
|
||||||
|
"someRandomTokenToUse",
|
||||||
|
types.TokenTypeEmailVerify,
|
||||||
|
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
|
||||||
tokens := []*types.Token{token}
|
tokens := []*types.Token{token}
|
||||||
|
|
||||||
email := "some@email.de"
|
email := "some@email.de"
|
||||||
@@ -131,7 +148,7 @@ func TestSendVerificationMail(t *testing.T) {
|
|||||||
return strings.Contains(message, token.Token)
|
return strings.Contains(message, token.Token)
|
||||||
})).Return()
|
})).Return()
|
||||||
|
|
||||||
underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
underTest.SendVerificationMail(userId, email)
|
underTest.SendVerificationMail(userId, email)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -5,16 +5,21 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DECIMALS_MULTIPLIER = 100
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9ÄÖÜäöüß,&'" -]+$`)
|
safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9ÄÖÜäöüß,&'" -]+$`)
|
||||||
)
|
)
|
||||||
|
|
||||||
func validateString(value string, fieldName string) error {
|
func validateString(value string, fieldName string) error {
|
||||||
if value == "" {
|
switch {
|
||||||
|
case value == "":
|
||||||
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest)
|
return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest)
|
||||||
} else if !safeInputRegex.MatchString(value) {
|
case !safeInputRegex.MatchString(value):
|
||||||
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest)
|
return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest)
|
||||||
} else {
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,19 @@ func (m MailImpl) internalSendMail(to string, subject string, message string) {
|
|||||||
|
|
||||||
auth := smtp.PlainAuth("", s.User, s.Pass, s.Host)
|
auth := smtp.PlainAuth("", s.User, s.Pass, s.Host)
|
||||||
|
|
||||||
msg := fmt.Sprintf("From: %v <%v>\nTo: %v\nSubject: %v\nMIME-version: 1.0;\nContent-Type: text/html; charset=\"UTF-8\";\n\n%v", s.FromName, s.FromMail, to, subject, message)
|
msg := fmt.Sprintf(
|
||||||
|
`From: %v <%v>
|
||||||
|
To: %v
|
||||||
|
Subject: %v
|
||||||
|
MIME-version: 1.0;
|
||||||
|
Content-Type: text/html; charset="UTF-8";
|
||||||
|
|
||||||
|
%v`,
|
||||||
|
s.FromName,
|
||||||
|
s.FromMail,
|
||||||
|
to,
|
||||||
|
subject,
|
||||||
|
message)
|
||||||
|
|
||||||
log.Info("Sending mail to %v", to)
|
log.Info("Sending mail to %v", to)
|
||||||
err := smtp.SendMail(s.Host+":"+s.Port, auth, s.FromMail, []string{to}, []byte(msg))
|
err := smtp.SendMail(s.Host+":"+s.Port, auth, s.FromMail, []string{to}, []byte(msg))
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
type MoneyImpl struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMoneyImpl() *MoneyImpl {
|
|
||||||
return &MoneyImpl{}
|
|
||||||
}
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMoneyCalculation(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
t.Run("should calculate correct oink balance", func(t *testing.T) {
|
|
||||||
// t.Parallel()
|
|
||||||
//
|
|
||||||
// underTest := NewMoneyImpl()
|
|
||||||
//
|
|
||||||
// // GIVEN
|
|
||||||
// timestamp := time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC)
|
|
||||||
//
|
|
||||||
// userId := uuid.New()
|
|
||||||
//
|
|
||||||
// account := types.Account{
|
|
||||||
// Id: uuid.New(),
|
|
||||||
// UserId: userId,
|
|
||||||
//
|
|
||||||
// Type: "Bank",
|
|
||||||
// Name: "Bank",
|
|
||||||
//
|
|
||||||
// CurrentBalance: 0,
|
|
||||||
// LastTransaction: time.Time{},
|
|
||||||
// OinkBalance: 0,
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // The PiggyBank is a fictional account. The money it "holds" is actually in the Account
|
|
||||||
// piggyBank := types.PiggyBank{
|
|
||||||
// Id: uuid.New(),
|
|
||||||
// UserId: userId,
|
|
||||||
//
|
|
||||||
// AccountId: account.Id,
|
|
||||||
// Name: "Car",
|
|
||||||
//
|
|
||||||
// CurrentBalance: 0,
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// savingsPlan := types.SavingsPlan{
|
|
||||||
// Id: uuid.New(),
|
|
||||||
// UserId: userId,
|
|
||||||
// PiggyBankId: piggyBank.Id,
|
|
||||||
//
|
|
||||||
// MonthlySaving: 10,
|
|
||||||
//
|
|
||||||
// ValidFrom: timestamp,
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// transaction1 := types.Transaction{
|
|
||||||
// Id: uuid.New(),
|
|
||||||
// UserId: userId,
|
|
||||||
//
|
|
||||||
// AccountId: account.Id,
|
|
||||||
//
|
|
||||||
// Value: 20,
|
|
||||||
// Timestamp: timestamp,
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// transaction2 := types.Transaction{
|
|
||||||
// Id: uuid.New(),
|
|
||||||
// UserId: userId,
|
|
||||||
//
|
|
||||||
// AccountId: account.Id,
|
|
||||||
// PiggyBankId: &piggyBank.Id,
|
|
||||||
//
|
|
||||||
// Value: -1,
|
|
||||||
// Timestamp: timestamp.Add(1 * time.Hour),
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // WHEN
|
|
||||||
// actual, err := underTest.CalculateAllBalancesInTime(account, piggyBank, savingsPlan, []types.Transaction{transaction1, transaction2})
|
|
||||||
//
|
|
||||||
// // THEN
|
|
||||||
// assert.Nil(t, err)
|
|
||||||
// assert.ElementsMatch(t, expected, actual)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@@ -73,8 +74,10 @@ func (s TransactionImpl) Add(user *types.User, transactionInput types.Transactio
|
|||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.NamedExec(`
|
r, err := tx.NamedExec(`
|
||||||
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp, party, description, error, created_at, created_by)
|
INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp,
|
||||||
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp, :party, :description, :error, :created_at, :created_by)`, transaction)
|
party, description, error, created_at, created_by)
|
||||||
|
VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp,
|
||||||
|
:party, :description, :error, :created_at, :created_by)`, transaction)
|
||||||
err = db.TransformAndLogDbError("transaction Insert", r, err)
|
err = db.TransformAndLogDbError("transaction Insert", r, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -135,7 +138,7 @@ func (s TransactionImpl) Update(user *types.User, input types.TransactionInput)
|
|||||||
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("transaction Update", nil, err)
|
err = db.TransformAndLogDbError("transaction Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest)
|
return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
@@ -232,7 +235,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e
|
|||||||
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("transaction Get", nil, err)
|
err = db.TransformAndLogDbError("transaction Get", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest)
|
return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
@@ -259,7 +262,10 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF
|
|||||||
OR (? = "false" AND error IS NULL)
|
OR (? = "false" AND error IS NULL)
|
||||||
)
|
)
|
||||||
ORDER BY timestamp DESC`,
|
ORDER BY timestamp DESC`,
|
||||||
user.Id, filter.AccountId, filter.AccountId, filter.TreasureChestId, filter.TreasureChestId, filter.Error, filter.Error, filter.Error)
|
user.Id,
|
||||||
|
filter.AccountId, filter.AccountId,
|
||||||
|
filter.TreasureChestId, filter.TreasureChestId,
|
||||||
|
filter.Error, filter.Error, filter.Error)
|
||||||
err = db.TransformAndLogDbError("transaction GetAll", nil, err)
|
err = db.TransformAndLogDbError("transaction GetAll", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -302,7 +308,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
|||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id)
|
||||||
err = db.TransformAndLogDbError("transaction Delete", r, err)
|
err = db.TransformAndLogDbError("transaction Delete", r, err)
|
||||||
if err != nil && err != db.ErrNotFound {
|
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -314,7 +320,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error {
|
|||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id)
|
||||||
err = db.TransformAndLogDbError("transaction Delete", r, err)
|
err = db.TransformAndLogDbError("transaction Delete", r, err)
|
||||||
if err != nil && err != db.ErrNotFound {
|
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -354,7 +360,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
SET current_balance = 0
|
SET current_balance = 0
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
|
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
|
||||||
if err != nil && err != db.ErrNotFound {
|
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,7 +369,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
SET current_balance = 0
|
SET current_balance = 0
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
|
err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err)
|
||||||
if err != nil && err != db.ErrNotFound {
|
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,7 +378,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
FROM "transaction"
|
FROM "transaction"
|
||||||
WHERE user_id = ?`, user.Id)
|
WHERE user_id = ?`, user.Id)
|
||||||
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
|
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
|
||||||
if err != nil && err != db.ErrNotFound {
|
if err != nil && !errors.Is(err, db.ErrNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -382,15 +388,15 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
transaction := &types.Transaction{}
|
var transaction types.Transaction
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
err = rows.StructScan(transaction)
|
err = rows.StructScan(&transaction)
|
||||||
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
|
err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.updateErrors(transaction)
|
s.updateErrors(&transaction)
|
||||||
r, err = tx.Exec(`
|
r, err = tx.Exec(`
|
||||||
UPDATE "transaction"
|
UPDATE "transaction"
|
||||||
SET error = ?
|
SET error = ?
|
||||||
@@ -424,7 +430,6 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -438,7 +443,6 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.TransactionInput) (*types.Transaction, error) {
|
func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.TransactionInput) (*types.Transaction, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
accountUuid *uuid.UUID
|
accountUuid *uuid.UUID
|
||||||
@@ -484,7 +488,6 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
|||||||
log.Error("transaction validate: %v", err)
|
log.Error("transaction validate: %v", err)
|
||||||
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
|
return nil, fmt.Errorf("account not found: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.TreasureChestId != "" {
|
if input.TreasureChestId != "" {
|
||||||
@@ -498,7 +501,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
|||||||
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
||||||
err = db.TransformAndLogDbError("transaction validate", nil, err)
|
err = db.TransformAndLogDbError("transaction validate", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
|
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -513,7 +516,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
|||||||
log.Error("transaction validate: %v", err)
|
log.Error("transaction validate: %v", err)
|
||||||
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
valueInt := int64(valueFloat * 100)
|
valueInt := int64(valueFloat * DECIMALS_MULTIPLIER)
|
||||||
|
|
||||||
timestamp, err := time.Parse("2006-01-02", input.Timestamp)
|
timestamp, err := time.Parse("2006-01-02", input.Timestamp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -544,6 +547,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
|||||||
Timestamp: timestamp,
|
Timestamp: timestamp,
|
||||||
Party: input.Party,
|
Party: input.Party,
|
||||||
Description: input.Description,
|
Description: input.Description,
|
||||||
|
Error: nil,
|
||||||
|
|
||||||
CreatedAt: createdAt,
|
CreatedAt: createdAt,
|
||||||
CreatedBy: createdBy,
|
CreatedBy: createdBy,
|
||||||
@@ -557,25 +561,26 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionImpl) updateErrors(transaction *types.Transaction) {
|
func (s TransactionImpl) updateErrors(transaction *types.Transaction) {
|
||||||
error := ""
|
errorStr := ""
|
||||||
|
|
||||||
if transaction.Value < 0 {
|
switch {
|
||||||
|
case transaction.Value < 0:
|
||||||
if transaction.TreasureChestId == nil {
|
if transaction.TreasureChestId == nil {
|
||||||
error = "no treasure chest specified"
|
errorStr = "no treasure chest specified"
|
||||||
}
|
}
|
||||||
} else if transaction.Value > 0 {
|
case transaction.Value > 0:
|
||||||
if transaction.AccountId == nil && transaction.TreasureChestId == nil {
|
if transaction.AccountId == nil && transaction.TreasureChestId == nil {
|
||||||
error = "either an account or a treasure chest needs to be specified"
|
errorStr = "either an account or a treasure chest needs to be specified"
|
||||||
} else if transaction.AccountId != nil && transaction.TreasureChestId != nil {
|
} else if transaction.AccountId != nil && transaction.TreasureChestId != nil {
|
||||||
error = "positive amounts can only be applied to either an account or a treasure chest"
|
errorStr = "positive amounts can only be applied to either an account or a treasure chest"
|
||||||
}
|
}
|
||||||
} else {
|
default:
|
||||||
error = "\"value\" needs to be specified"
|
errorStr = "\"value\" needs to be specified"
|
||||||
}
|
}
|
||||||
|
|
||||||
if error == "" {
|
if errorStr == "" {
|
||||||
transaction.Error = nil
|
transaction.Error = nil
|
||||||
} else {
|
} else {
|
||||||
transaction.Error = &error
|
transaction.Error = &errorStr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,8 +19,8 @@ import (
|
|||||||
var (
|
var (
|
||||||
transactionRecurringMetric = promauto.NewCounterVec(
|
transactionRecurringMetric = promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: "spendsparrow_transactionRecurring_recurring_total",
|
Name: "spendsparrow_transaction_recurring_total",
|
||||||
Help: "The total of transactionRecurring recurring operations",
|
Help: "The total of transactionRecurring operations",
|
||||||
},
|
},
|
||||||
[]string{"operation"},
|
[]string{"operation"},
|
||||||
)
|
)
|
||||||
@@ -28,7 +29,6 @@ var (
|
|||||||
type TransactionRecurring interface {
|
type TransactionRecurring interface {
|
||||||
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||||
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error)
|
||||||
Get(user *types.User, id string) (*types.TransactionRecurring, error)
|
|
||||||
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error)
|
GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error)
|
||||||
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error)
|
||||||
Delete(user *types.User, id string) error
|
Delete(user *types.User, id string) error
|
||||||
@@ -50,7 +50,9 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, settings *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
|
func (s TransactionRecurringImpl) Add(
|
||||||
|
user *types.User,
|
||||||
|
transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
|
||||||
transactionRecurringMetric.WithLabelValues("add").Inc()
|
transactionRecurringMetric.WithLabelValues("add").Inc()
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
@@ -72,8 +74,11 @@ func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
r, err := tx.NamedExec(`
|
r, err := tx.NamedExec(`
|
||||||
INSERT INTO "transaction_recurring" (id, user_id, interval_months, active, party, description, account_id, treasure_chest_id, value, created_at, created_by)
|
INSERT INTO "transaction_recurring" (id, user_id, interval_months,
|
||||||
VALUES (:id, :user_id, :interval_months, :active, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`, transactionRecurring)
|
active, party, description, account_id, treasure_chest_id, value, created_at, created_by)
|
||||||
|
VALUES (:id, :user_id, :interval_months,
|
||||||
|
:active, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`,
|
||||||
|
transactionRecurring)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Insert", r, err)
|
err = db.TransformAndLogDbError("transactionRecurring Insert", r, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -88,7 +93,9 @@ func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInpu
|
|||||||
return transactionRecurring, nil
|
return transactionRecurring, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Update(user *types.User, input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
|
func (s TransactionRecurringImpl) Update(
|
||||||
|
user *types.User,
|
||||||
|
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
|
||||||
transactionRecurringMetric.WithLabelValues("update").Inc()
|
transactionRecurringMetric.WithLabelValues("update").Inc()
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, ErrUnauthorized
|
return nil, ErrUnauthorized
|
||||||
@@ -112,7 +119,7 @@ func (s TransactionRecurringImpl) Update(user *types.User, input types.Transacti
|
|||||||
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest)
|
return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
@@ -151,31 +158,6 @@ func (s TransactionRecurringImpl) Update(user *types.User, input types.Transacti
|
|||||||
return transactionRecurring, nil
|
return transactionRecurring, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) Get(user *types.User, id string) (*types.TransactionRecurring, error) {
|
|
||||||
transactionRecurringMetric.WithLabelValues("get").Inc()
|
|
||||||
|
|
||||||
if user == nil {
|
|
||||||
return nil, ErrUnauthorized
|
|
||||||
}
|
|
||||||
uuid, err := uuid.Parse(id)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("transactionRecurring get: %v", err)
|
|
||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
var transactionRecurring types.TransactionRecurring
|
|
||||||
err = s.db.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
|
||||||
err = db.TransformAndLogDbError("transactionRecurring Get", nil, err)
|
|
||||||
if err != nil {
|
|
||||||
if err == db.ErrNotFound {
|
|
||||||
return nil, fmt.Errorf("transactionRecurring %v not found: %w", id, ErrBadRequest)
|
|
||||||
}
|
|
||||||
return nil, types.ErrInternal
|
|
||||||
}
|
|
||||||
|
|
||||||
return &transactionRecurring, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) {
|
||||||
transactionRecurringMetric.WithLabelValues("get_all_by_account").Inc()
|
transactionRecurringMetric.WithLabelValues("get_all_by_account").Inc()
|
||||||
if user == nil {
|
if user == nil {
|
||||||
@@ -201,7 +183,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st
|
|||||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest)
|
return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
@@ -254,7 +236,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(user *types.User, treasu
|
|||||||
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest)
|
return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
@@ -329,7 +311,6 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
oldTransactionRecurring *types.TransactionRecurring,
|
oldTransactionRecurring *types.TransactionRecurring,
|
||||||
userId uuid.UUID,
|
userId uuid.UUID,
|
||||||
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
|
input types.TransactionRecurringInput) (*types.TransactionRecurring, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
accountUuid *uuid.UUID
|
accountUuid *uuid.UUID
|
||||||
@@ -393,7 +374,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId)
|
||||||
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
|
err = db.TransformAndLogDbError("transactionRecurring validate", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
|
return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -418,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
log.Error("transactionRecurring validate: %v", err)
|
log.Error("transactionRecurring validate: %v", err)
|
||||||
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
valueInt := int64(valueFloat * 100)
|
valueInt := int64(valueFloat * DECIMALS_MULTIPLIER)
|
||||||
|
|
||||||
if input.Party != "" {
|
if input.Party != "" {
|
||||||
err = validateString(input.Party, "party")
|
err = validateString(input.Party, "party")
|
||||||
@@ -444,12 +425,12 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring(
|
|||||||
active := input.Active == "on"
|
active := input.Active == "on"
|
||||||
|
|
||||||
transactionRecurring := types.TransactionRecurring{
|
transactionRecurring := types.TransactionRecurring{
|
||||||
|
|
||||||
Id: id,
|
Id: id,
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
|
|
||||||
IntervalMonths: intervalMonths,
|
IntervalMonths: intervalMonths,
|
||||||
Active: active,
|
Active: active,
|
||||||
|
LastExecution: nil,
|
||||||
|
|
||||||
Party: input.Party,
|
Party: input.Party,
|
||||||
Description: input.Description,
|
Description: input.Description,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -131,7 +132,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string
|
|||||||
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Update", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
|
return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
@@ -198,17 +199,17 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes
|
|||||||
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
treasureChest := &types.TreasureChest{}
|
var treasureChest types.TreasureChest
|
||||||
err = s.db.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid)
|
||||||
err = db.TransformAndLogDbError("treasureChest Get", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Get", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == db.ErrNotFound {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
|
return nil, fmt.Errorf("treasureChest %v not found: %w", id, err)
|
||||||
}
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return treasureChest, nil
|
return &treasureChest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) {
|
func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) {
|
||||||
@@ -259,7 +260,9 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
transactionsCount := 0
|
transactionsCount := 0
|
||||||
err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`, user.Id, id)
|
err = tx.Get(&transactionsCount,
|
||||||
|
`SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`,
|
||||||
|
user.Id, id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -284,12 +287,11 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func sortTree(nodes []*types.TreasureChest) []*types.TreasureChest {
|
func sortTree(nodes []*types.TreasureChest) []*types.TreasureChest {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
roots []*types.TreasureChest
|
roots []*types.TreasureChest
|
||||||
result []*types.TreasureChest
|
|
||||||
)
|
)
|
||||||
children := make(map[uuid.UUID][]*types.TreasureChest)
|
children := make(map[uuid.UUID][]*types.TreasureChest)
|
||||||
|
result := make([]*types.TreasureChest, 0)
|
||||||
|
|
||||||
for _, node := range nodes {
|
for _, node := range nodes {
|
||||||
if node.ParentId == nil {
|
if node.ParentId == nil {
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ templ AccountItem(account *types.Account) {
|
|||||||
hx-target="closest #account"
|
hx-target="closest #account"
|
||||||
hx-swap="outerHTML"
|
hx-swap="outerHTML"
|
||||||
class="button button-neglect px-1 flex items-center gap-2"
|
class="button button-neglect px-1 flex items-center gap-2"
|
||||||
|
hx-confirm="Are you sure you want to delete this account?"
|
||||||
>
|
>
|
||||||
@svg.Delete()
|
@svg.Delete()
|
||||||
<span>
|
<span>
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ templ Layout(slot templ.Component, user templ.Component, loggedIn bool, path str
|
|||||||
<script src="/static/js/toast.js"></script>
|
<script src="/static/js/toast.js"></script>
|
||||||
<script src="/static/js/time.js"></script>
|
<script src="/static/js/time.js"></script>
|
||||||
</head>
|
</head>
|
||||||
<body class="h-screen flex flex-col" hx-headers='{"csrf-token": "CSRF_TOKEN"}'>
|
<body class="h-screen flex flex-col" hx-headers='{"Csrf-Token": "CSRF_TOKEN"}'>
|
||||||
// Header
|
// Header
|
||||||
<nav class="flex bg-white items-center gap-2 py-1 px-2 h-12 md:gap-10 md:px-10 md:py-2">
|
<nav class="flex bg-white items-center gap-2 py-1 px-2 h-12 md:gap-10 md:px-10 md:py-2">
|
||||||
<a href="/" class="flex gap-2 mr-20">
|
<a href="/" class="flex gap-2 mr-20">
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// The Account holds money
|
// The Account holds money.
|
||||||
type Account struct {
|
type Account struct {
|
||||||
Id uuid.UUID
|
Id uuid.UUID `db:"id"`
|
||||||
UserId uuid.UUID `db:"user_id"`
|
UserId uuid.UUID `db:"user_id"`
|
||||||
|
|
||||||
// Custom Name of the account, e.g. "Bank", "Cash", "Credit Card"
|
// Custom Name of the account, e.g. "Bank", "Cash", "Credit Card"
|
||||||
Name string
|
Name string `db:"name"`
|
||||||
|
|
||||||
CurrentBalance int64 `db:"current_balance"`
|
CurrentBalance int64 `db:"current_balance"`
|
||||||
LastTransaction *time.Time `db:"last_transaction"`
|
LastTransaction *time.Time `db:"last_transaction"`
|
||||||
|
|||||||
@@ -17,7 +17,15 @@ type User struct {
|
|||||||
CreateAt time.Time
|
CreateAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User {
|
func NewUser(
|
||||||
|
id uuid.UUID,
|
||||||
|
email string,
|
||||||
|
emailVerified bool,
|
||||||
|
emailVerifiedAt *time.Time,
|
||||||
|
isAdmin bool,
|
||||||
|
password []byte,
|
||||||
|
salt []byte,
|
||||||
|
createAt time.Time) *User {
|
||||||
return &User{
|
return &User{
|
||||||
Id: id,
|
Id: id,
|
||||||
Email: email,
|
Email: email,
|
||||||
@@ -31,10 +39,10 @@ func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *ti
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
Id string
|
Id string `db:"session_id"`
|
||||||
UserId uuid.UUID
|
UserId uuid.UUID `db:"user_id"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time `db:"created_at"`
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time `db:"expires_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session {
|
func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session {
|
||||||
@@ -63,7 +71,13 @@ var (
|
|||||||
TokenTypeCsrf TokenType = "csrf"
|
TokenTypeCsrf TokenType = "csrf"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewToken(userId uuid.UUID, sessionId string, token string, tokenType TokenType, createdAt time.Time, expiresAt time.Time) *Token {
|
func NewToken(
|
||||||
|
userId uuid.UUID,
|
||||||
|
sessionId string,
|
||||||
|
token string,
|
||||||
|
tokenType TokenType,
|
||||||
|
createdAt time.Time,
|
||||||
|
expiresAt time.Time) *Token {
|
||||||
return &Token{
|
return &Token{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
SessionId: sessionId,
|
SessionId: sessionId,
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package types
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The SavingsPlan is applied every interval to the TreasureChest/Account as a transaction
|
|
||||||
type SavingsPlan struct {
|
|
||||||
Id uuid.UUID
|
|
||||||
UserId uuid.UUID `db:"user_id"`
|
|
||||||
|
|
||||||
TreasureChestId uuid.UUID `db:"treasure_chest_id"`
|
|
||||||
|
|
||||||
MonthlySaving int64 `db:"monthly_saving"`
|
|
||||||
|
|
||||||
ValidFrom time.Time `db:"valid_from"`
|
|
||||||
/// nil means it is valid indefinitely
|
|
||||||
ValidTo *time.Time `db:"valid_to"`
|
|
||||||
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
|
||||||
CreatedBy uuid.UUID `db:"created_by"`
|
|
||||||
UpdatedAt *time.Time `db:"updated_at"`
|
|
||||||
UpdatedBy *uuid.UUID `db:"updated_by"`
|
|
||||||
}
|
|
||||||
@@ -23,10 +23,37 @@ type SmtpSettings struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewSettingsFromEnv(env func(string) string) *Settings {
|
func NewSettingsFromEnv(env func(string) string) *Settings {
|
||||||
|
|
||||||
var smtp *SmtpSettings
|
var smtp *SmtpSettings
|
||||||
if env("SMTP_ENABLED") == "true" {
|
if env("SMTP_ENABLED") == "true" {
|
||||||
smtp = &SmtpSettings{
|
smtp = getSmtpSettings(env)
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := &Settings{
|
||||||
|
Port: env("PORT"),
|
||||||
|
PrometheusEnabled: env("PROMETHEUS_ENABLED") == "true",
|
||||||
|
BaseUrl: env("BASE_URL"),
|
||||||
|
Environment: env("ENVIRONMENT"),
|
||||||
|
Smtp: smtp,
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings.BaseUrl == "" {
|
||||||
|
log.Fatal("BASE_URL must be set")
|
||||||
|
}
|
||||||
|
if settings.Port == "" {
|
||||||
|
log.Fatal("PORT must be set")
|
||||||
|
}
|
||||||
|
if settings.Environment == "" {
|
||||||
|
log.Fatal("ENVIRONMENT must be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("BASE_URL is %q", settings.BaseUrl)
|
||||||
|
log.Info("ENVIRONMENT is %q", settings.Environment)
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSmtpSettings(env func(string) string) *SmtpSettings {
|
||||||
|
smtp := SmtpSettings{
|
||||||
Host: env("SMTP_HOST"),
|
Host: env("SMTP_HOST"),
|
||||||
Port: env("SMTP_PORT"),
|
Port: env("SMTP_PORT"),
|
||||||
User: env("SMTP_USER"),
|
User: env("SMTP_USER"),
|
||||||
@@ -53,28 +80,6 @@ func NewSettingsFromEnv(env func(string) string) *Settings {
|
|||||||
if smtp.FromName == "" {
|
if smtp.FromName == "" {
|
||||||
log.Fatal("SMTP_FROM_NAME must be set")
|
log.Fatal("SMTP_FROM_NAME must be set")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
settings := &Settings{
|
return &smtp
|
||||||
Port: env("PORT"),
|
|
||||||
PrometheusEnabled: env("PROMETHEUS_ENABLED") == "true",
|
|
||||||
BaseUrl: env("BASE_URL"),
|
|
||||||
Environment: env("ENVIRONMENT"),
|
|
||||||
Smtp: smtp,
|
|
||||||
}
|
|
||||||
|
|
||||||
if settings.BaseUrl == "" {
|
|
||||||
log.Fatal("BASE_URL must be set")
|
|
||||||
}
|
|
||||||
if settings.Port == "" {
|
|
||||||
log.Fatal("PORT must be set")
|
|
||||||
}
|
|
||||||
if settings.Environment == "" {
|
|
||||||
log.Fatal("ENVIRONMENT must be set")
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("BASE_URL is %q", settings.BaseUrl)
|
|
||||||
log.Info("ENVIRONMENT is %q", settings.Environment)
|
|
||||||
|
|
||||||
return settings
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,19 +14,19 @@ import (
|
|||||||
// If it becomes necessary to precalculate snapshots for performance reasons, this can be done in the future.
|
// If it becomes necessary to precalculate snapshots for performance reasons, this can be done in the future.
|
||||||
// But the transaction should always be the source of truth.
|
// But the transaction should always be the source of truth.
|
||||||
type Transaction struct {
|
type Transaction struct {
|
||||||
Id uuid.UUID
|
Id uuid.UUID `db:"id"`
|
||||||
UserId uuid.UUID `db:"user_id"`
|
UserId uuid.UUID `db:"user_id"`
|
||||||
|
|
||||||
Timestamp time.Time
|
Timestamp time.Time `db:"timestamp"`
|
||||||
Party string
|
Party string `db:"party"`
|
||||||
Description string
|
Description string `db:"description"`
|
||||||
|
|
||||||
AccountId *uuid.UUID `db:"account_id"`
|
AccountId *uuid.UUID `db:"account_id"`
|
||||||
TreasureChestId *uuid.UUID `db:"treasure_chest_id"`
|
TreasureChestId *uuid.UUID `db:"treasure_chest_id"`
|
||||||
Value int64
|
Value int64 `db:"value"`
|
||||||
|
|
||||||
// If an error is present, then the transaction is not valid and should not be used for calculations.
|
// If an error is present, then the transaction is not valid and should not be used for calculations.
|
||||||
Error *string
|
Error *string `db:"error"`
|
||||||
CreatedAt time.Time `db:"created_at"`
|
CreatedAt time.Time `db:"created_at"`
|
||||||
// Either a user_id or a transaction_recurring_id
|
// Either a user_id or a transaction_recurring_id
|
||||||
CreatedBy uuid.UUID `db:"created_by"`
|
CreatedBy uuid.UUID `db:"created_by"`
|
||||||
|
|||||||
@@ -7,19 +7,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TransactionRecurring struct {
|
type TransactionRecurring struct {
|
||||||
Id uuid.UUID
|
Id uuid.UUID `db:"id"`
|
||||||
UserId uuid.UUID `db:"user_id"`
|
UserId uuid.UUID `db:"user_id"`
|
||||||
|
|
||||||
IntervalMonths int64 `db:"interval_months"`
|
IntervalMonths int64 `db:"interval_months"`
|
||||||
LastExecution *time.Time `db:"last_execution"`
|
LastExecution *time.Time `db:"last_execution"`
|
||||||
Active bool
|
Active bool `db:"active"`
|
||||||
|
|
||||||
Party string
|
Party string `db:"party"`
|
||||||
Description string
|
Description string `db:"description"`
|
||||||
|
|
||||||
AccountId *uuid.UUID `db:"account_id"`
|
AccountId *uuid.UUID `db:"account_id"`
|
||||||
TreasureChestId *uuid.UUID `db:"treasure_chest_id"`
|
TreasureChestId *uuid.UUID `db:"treasure_chest_id"`
|
||||||
Value int64
|
Value int64 `db:"value"`
|
||||||
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
CreatedAt time.Time `db:"created_at"`
|
||||||
CreatedBy uuid.UUID `db:"created_by"`
|
CreatedBy uuid.UUID `db:"created_by"`
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ import (
|
|||||||
// The money it "holds" distributed across all accounts
|
// The money it "holds" distributed across all accounts
|
||||||
//
|
//
|
||||||
// At the time of writing this, linking it to a specific account doesn't really make sense
|
// At the time of writing this, linking it to a specific account doesn't really make sense
|
||||||
// Imagne a TreasureChest for free time activities, where some money is spend in cash and some other with credit card
|
// Imagine a TreasureChest for free time activities, where some money is spend in cash and some other with credit card.
|
||||||
type TreasureChest struct {
|
type TreasureChest struct {
|
||||||
Id uuid.UUID
|
Id uuid.UUID `db:"id"`
|
||||||
ParentId *uuid.UUID `db:"parent_id"`
|
ParentId *uuid.UUID `db:"parent_id"`
|
||||||
UserId uuid.UUID `db:"user_id"`
|
UserId uuid.UUID `db:"user_id"`
|
||||||
|
|
||||||
Name string
|
Name string `db:"name"`
|
||||||
|
|
||||||
CurrentBalance int64 `db:"current_balance"`
|
CurrentBalance int64 `db:"current_balance"`
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
func TriggerToast(w http.ResponseWriter, r *http.Request, class string, message string) {
|
func TriggerToast(w http.ResponseWriter, r *http.Request, class string, message string) {
|
||||||
if IsHtmx(r) {
|
if IsHtmx(r) {
|
||||||
w.Header().Set("HX-Trigger", fmt.Sprintf(`{"toast": "%v|%v"}`, class, strings.ReplaceAll(message, `"`, `\"`)))
|
w.Header().Set("Hx-Trigger", fmt.Sprintf(`{"toast": "%v|%v"}`, class, strings.ReplaceAll(message, `"`, `\"`)))
|
||||||
} else {
|
} else {
|
||||||
log.Error("Trying to trigger toast in non-HTMX request")
|
log.Error("Trying to trigger toast in non-HTMX request")
|
||||||
}
|
}
|
||||||
@@ -24,19 +24,19 @@ func TriggerToastWithStatus(w http.ResponseWriter, r *http.Request, class string
|
|||||||
|
|
||||||
func DoRedirect(w http.ResponseWriter, r *http.Request, url string) {
|
func DoRedirect(w http.ResponseWriter, r *http.Request, url string) {
|
||||||
if IsHtmx(r) {
|
if IsHtmx(r) {
|
||||||
w.Header().Add("HX-Redirect", url)
|
w.Header().Add("Hx-Redirect", url)
|
||||||
} else {
|
} else {
|
||||||
http.Redirect(w, r, url, http.StatusSeeOther)
|
http.Redirect(w, r, url, http.StatusSeeOther)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WaitMinimumTime[T interface{}](waitTime time.Duration, function func() (T, error)) (T, error) {
|
func WaitMinimumTime[T interface{}](waitTime time.Duration, f func() (T, error)) (T, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
result, err := function()
|
result, err := f()
|
||||||
time.Sleep(waitTime - time.Since(start))
|
time.Sleep(waitTime - time.Since(start))
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsHtmx(r *http.Request) bool {
|
func IsHtmx(r *http.Request) bool {
|
||||||
return r.Header.Get("HX-Request") == "true"
|
return r.Header.Get("Hx-Request") == "true"
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user