feat(transaction-recurring): #135 prohibit deletion of treasure chests if referenced
This commit was merged in pull request #139.
This commit is contained in:
@@ -23,6 +23,7 @@ linters:
|
|||||||
- depguard
|
- depguard
|
||||||
- cyclop
|
- cyclop
|
||||||
- contextcheck
|
- contextcheck
|
||||||
|
- bodyclose # i don't care in the tests, the implementation itself doesn't do http requests
|
||||||
settings:
|
settings:
|
||||||
nestif:
|
nestif:
|
||||||
min-complexity: 6
|
min-complexity: 6
|
||||||
|
|||||||
@@ -268,6 +268,18 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error {
|
|||||||
return fmt.Errorf("treasure chest has transactions: %w", ErrBadRequest)
|
return fmt.Errorf("treasure chest has transactions: %w", ErrBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
recurringCount := 0
|
||||||
|
err = tx.Get(&recurringCount, `
|
||||||
|
SELECT COUNT(*) FROM transaction_recurring WHERE user_id = ? AND treasure_chest_id = ?`,
|
||||||
|
user.Id, id)
|
||||||
|
err = db.TransformAndLogDbError("treasureChest Delete", nil, err)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if recurringCount > 0 {
|
||||||
|
return fmt.Errorf("cannot delete treasure chest with existing recurring transactions: %w", ErrBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
r, err := tx.Exec(`DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
r, err := tx.Exec(`DELETE FROM treasure_chest WHERE id = ? AND user_id = ?`, id, user.Id)
|
||||||
err = db.TransformAndLogDbError("treasureChest Delete", r, err)
|
err = db.TransformAndLogDbError("treasureChest Delete", r, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
package service_test
|
|
||||||
|
|
||||||
// import (
|
|
||||||
// "spned-sparrow"
|
|
||||||
// )
|
|
||||||
//
|
|
||||||
// func TestTreasureChestProhibitDeleteIfTransactionRecurringExists(t *testing.T) {
|
|
||||||
// service := main.Setup
|
|
||||||
//
|
|
||||||
// }
|
|
||||||
@@ -1,154 +0,0 @@
|
|||||||
package test_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"spend-sparrow/internal/db"
|
|
||||||
"spend-sparrow/internal/service"
|
|
||||||
"spend-sparrow/internal/types"
|
|
||||||
"spend-sparrow/mocks"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
settings = types.Settings{
|
|
||||||
Port: "",
|
|
||||||
PrometheusEnabled: false,
|
|
||||||
BaseUrl: "",
|
|
||||||
Environment: "test",
|
|
||||||
Smtp: nil,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSignUp(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
t.Run("should check for correct email address", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
|
||||||
mockClock := mocks.NewMockClock(t)
|
|
||||||
mockMail := mocks.NewMockMail(t)
|
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
|
||||||
|
|
||||||
_, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!")
|
|
||||||
|
|
||||||
assert.Equal(t, service.ErrInvalidEmail, err)
|
|
||||||
})
|
|
||||||
t.Run("should check for password complexity", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
|
||||||
mockClock := mocks.NewMockClock(t)
|
|
||||||
mockMail := mocks.NewMockMail(t)
|
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
|
||||||
|
|
||||||
weakPasswords := []string{
|
|
||||||
"123!ab", // too short
|
|
||||||
"no_upper_case_123",
|
|
||||||
"NO_LOWER_CASE_123",
|
|
||||||
"noSpecialChar123",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, password := range weakPasswords {
|
|
||||||
_, err := underTest.SignUp("some@valid.email", password)
|
|
||||||
assert.Equal(t, service.ErrInvalidPassword, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("should signup correctly", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
|
||||||
mockClock := mocks.NewMockClock(t)
|
|
||||||
mockMail := mocks.NewMockMail(t)
|
|
||||||
|
|
||||||
userId := uuid.New()
|
|
||||||
email := "mail@mail.de"
|
|
||||||
password := "SomeStrongPassword123!"
|
|
||||||
salt := []byte("salt")
|
|
||||||
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
|
||||||
|
|
||||||
expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
|
|
||||||
|
|
||||||
mockRandom.EXPECT().UUID().Return(userId, nil)
|
|
||||||
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
|
||||||
mockClock.EXPECT().Now().Return(createTime)
|
|
||||||
mockAuthDb.EXPECT().InsertUser(expected).Return(nil)
|
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
|
||||||
actual, err := underTest.SignUp(email, password)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, expected, actual)
|
|
||||||
})
|
|
||||||
t.Run("should return ErrAccountExists", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
|
||||||
mockClock := mocks.NewMockClock(t)
|
|
||||||
mockMail := mocks.NewMockMail(t)
|
|
||||||
|
|
||||||
userId := uuid.New()
|
|
||||||
email := "some@valid.email"
|
|
||||||
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
|
||||||
password := "SomeStrongPassword123!"
|
|
||||||
salt := []byte("salt")
|
|
||||||
user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
|
|
||||||
|
|
||||||
mockRandom.EXPECT().UUID().Return(user.Id, nil)
|
|
||||||
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
|
||||||
mockClock.EXPECT().Now().Return(createTime)
|
|
||||||
|
|
||||||
mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists)
|
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
|
||||||
|
|
||||||
_, err := underTest.SignUp(user.Email, password)
|
|
||||||
assert.Equal(t, service.ErrAccountExists, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSendVerificationMail(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
t.Run("should use stored token and send mail", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
token := types.NewToken(
|
|
||||||
uuid.New(),
|
|
||||||
"sessionId",
|
|
||||||
"someRandomTokenToUse",
|
|
||||||
types.TokenTypeEmailVerify,
|
|
||||||
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
|
|
||||||
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
|
|
||||||
tokens := []*types.Token{token}
|
|
||||||
|
|
||||||
email := "some@email.de"
|
|
||||||
userId := uuid.New()
|
|
||||||
|
|
||||||
mockAuthDb := mocks.NewMockAuth(t)
|
|
||||||
mockRandom := mocks.NewMockRandom(t)
|
|
||||||
mockClock := mocks.NewMockClock(t)
|
|
||||||
mockMail := mocks.NewMockMail(t)
|
|
||||||
|
|
||||||
mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify).Return(tokens, nil)
|
|
||||||
|
|
||||||
mockMail.EXPECT().SendMail(email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
|
|
||||||
return strings.Contains(message, token.Token)
|
|
||||||
})).Return()
|
|
||||||
|
|
||||||
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
|
||||||
|
|
||||||
underTest.SendVerificationMail(userId, email)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
205
test/auth_it_test.go
Normal file
205
test/auth_it_test.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package test_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"spend-sparrow/internal/db"
|
||||||
|
"spend-sparrow/internal/types"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupDb(t *testing.T) *sqlx.DB {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
d, err := sqlx.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error opening database: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := d.Close()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
err = db.RunMigrations(d, "../")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error running migrations: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUser(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("should insert and get the same", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
|
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
|
err := underTest.InsertUser(expected)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
actual, err := underTest.GetUser(expected.Id)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, actual)
|
||||||
|
|
||||||
|
actual, err = underTest.GetUserByEmail(expected.Email)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, actual)
|
||||||
|
})
|
||||||
|
t.Run("should return ErrNotFound", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
_, err := underTest.GetUserByEmail("nonExistentEmail")
|
||||||
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
|
})
|
||||||
|
t.Run("should return ErrUserExist", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
|
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
|
err := underTest.InsertUser(user)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = underTest.InsertUser(user)
|
||||||
|
assert.Equal(t, db.ErrAlreadyExists, err)
|
||||||
|
})
|
||||||
|
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
|
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
||||||
|
|
||||||
|
err := underTest.InsertUser(user)
|
||||||
|
assert.Equal(t, types.ErrInternal, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("should insert and get the same", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
|
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
|
||||||
|
err := underTest.InsertToken(expected)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
actual, err := underTest.GetToken(expected.Token)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, expected, actual)
|
||||||
|
|
||||||
|
expected.SessionId = ""
|
||||||
|
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []*types.Token{expected}, actuals)
|
||||||
|
|
||||||
|
expected.SessionId = "sessionId"
|
||||||
|
expected.UserId = uuid.Nil
|
||||||
|
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []*types.Token{expected}, actuals)
|
||||||
|
})
|
||||||
|
t.Run("should insert and return multiple tokens", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
|
expiresAt := createAt.Add(24 * time.Hour)
|
||||||
|
userId := uuid.New()
|
||||||
|
expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
|
||||||
|
|
||||||
|
err := underTest.InsertToken(expected1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = underTest.InsertToken(expected2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expected1.UserId = uuid.Nil
|
||||||
|
expected2.UserId = uuid.Nil
|
||||||
|
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
||||||
|
|
||||||
|
expected1.SessionId = ""
|
||||||
|
expected2.SessionId = ""
|
||||||
|
expected1.UserId = userId
|
||||||
|
expected2.UserId = userId
|
||||||
|
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
||||||
|
})
|
||||||
|
t.Run("should return ErrNotFound", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
_, err := underTest.GetToken("nonExistent")
|
||||||
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
|
|
||||||
|
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify)
|
||||||
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
|
|
||||||
|
_, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify)
|
||||||
|
assert.Equal(t, db.ErrNotFound, err)
|
||||||
|
})
|
||||||
|
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
||||||
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
|
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
||||||
|
|
||||||
|
err := underTest.InsertUser(user)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = underTest.InsertUser(user)
|
||||||
|
assert.Equal(t, db.ErrAlreadyExists, err)
|
||||||
|
})
|
||||||
|
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
d := setupDb(t)
|
||||||
|
|
||||||
|
underTest := db.NewAuthSqlite(d)
|
||||||
|
|
||||||
|
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
||||||
|
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
||||||
|
|
||||||
|
err := underTest.InsertUser(user)
|
||||||
|
assert.Equal(t, types.ErrInternal, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -2,204 +2,153 @@ package test_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"spend-sparrow/internal/db"
|
"spend-sparrow/internal/db"
|
||||||
|
"spend-sparrow/internal/service"
|
||||||
"spend-sparrow/internal/types"
|
"spend-sparrow/internal/types"
|
||||||
|
"spend-sparrow/mocks"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupDb(t *testing.T) *sqlx.DB {
|
var (
|
||||||
t.Helper()
|
settings = types.Settings{
|
||||||
|
Port: "",
|
||||||
|
PrometheusEnabled: false,
|
||||||
|
BaseUrl: "",
|
||||||
|
Environment: "test",
|
||||||
|
Smtp: nil,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
d, err := sqlx.Open("sqlite3", ":memory:")
|
func TestSignUp(t *testing.T) {
|
||||||
if err != nil {
|
t.Parallel()
|
||||||
t.Fatalf("Error opening database: %v", err)
|
t.Run("should check for correct email address", func(t *testing.T) {
|
||||||
}
|
t.Parallel()
|
||||||
t.Cleanup(func() {
|
|
||||||
err := d.Close()
|
mockAuthDb := mocks.NewMockAuth(t)
|
||||||
if err != nil {
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
panic(err)
|
mockClock := mocks.NewMockClock(t)
|
||||||
}
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
|
_, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!")
|
||||||
|
|
||||||
|
assert.Equal(t, service.ErrInvalidEmail, err)
|
||||||
})
|
})
|
||||||
|
t.Run("should check for password complexity", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
err = db.RunMigrations(d, "../")
|
mockAuthDb := mocks.NewMockAuth(t)
|
||||||
if err != nil {
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
t.Fatalf("Error running migrations: %v", err)
|
mockClock := mocks.NewMockClock(t)
|
||||||
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
|
weakPasswords := []string{
|
||||||
|
"123!ab", // too short
|
||||||
|
"no_upper_case_123",
|
||||||
|
"NO_LOWER_CASE_123",
|
||||||
|
"noSpecialChar123",
|
||||||
}
|
}
|
||||||
|
|
||||||
return d
|
for _, password := range weakPasswords {
|
||||||
}
|
_, err := underTest.SignUp("some@valid.email", password)
|
||||||
|
assert.Equal(t, service.ErrInvalidPassword, err)
|
||||||
func TestUser(t *testing.T) {
|
}
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("should insert and get the same", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
d := setupDb(t)
|
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
|
||||||
expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
|
||||||
|
|
||||||
err := underTest.InsertUser(expected)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
actual, err := underTest.GetUser(expected.Id)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, actual)
|
|
||||||
|
|
||||||
actual, err = underTest.GetUserByEmail(expected.Email)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, actual)
|
|
||||||
})
|
})
|
||||||
t.Run("should return ErrNotFound", func(t *testing.T) {
|
t.Run("should signup correctly", func(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
d := setupDb(t)
|
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
|
||||||
|
|
||||||
_, err := underTest.GetUserByEmail("nonExistentEmail")
|
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
|
||||||
})
|
|
||||||
t.Run("should return ErrUserExist", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
d := setupDb(t)
|
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = underTest.InsertUser(user)
|
|
||||||
assert.Equal(t, db.ErrAlreadyExists, err)
|
|
||||||
})
|
|
||||||
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
d := setupDb(t)
|
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
|
||||||
assert.Equal(t, types.ErrInternal, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestToken(t *testing.T) {
|
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
t.Run("should insert and get the same", func(t *testing.T) {
|
mockAuthDb := mocks.NewMockAuth(t)
|
||||||
t.Parallel()
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
d := setupDb(t)
|
mockClock := mocks.NewMockClock(t)
|
||||||
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
|
||||||
expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
|
|
||||||
|
|
||||||
err := underTest.InsertToken(expected)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
actual, err := underTest.GetToken(expected.Token)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, expected, actual)
|
|
||||||
|
|
||||||
expected.SessionId = ""
|
|
||||||
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []*types.Token{expected}, actuals)
|
|
||||||
|
|
||||||
expected.SessionId = "sessionId"
|
|
||||||
expected.UserId = uuid.Nil
|
|
||||||
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []*types.Token{expected}, actuals)
|
|
||||||
})
|
|
||||||
t.Run("should insert and return multiple tokens", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
d := setupDb(t)
|
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
|
||||||
|
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
|
||||||
expiresAt := createAt.Add(24 * time.Hour)
|
|
||||||
userId := uuid.New()
|
userId := uuid.New()
|
||||||
expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt)
|
email := "mail@mail.de"
|
||||||
expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
|
password := "SomeStrongPassword123!"
|
||||||
|
salt := []byte("salt")
|
||||||
|
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
|
||||||
|
|
||||||
|
mockRandom.EXPECT().UUID().Return(userId, nil)
|
||||||
|
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
||||||
|
mockClock.EXPECT().Now().Return(createTime)
|
||||||
|
mockAuthDb.EXPECT().InsertUser(expected).Return(nil)
|
||||||
|
|
||||||
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
actual, err := underTest.SignUp(email, password)
|
||||||
|
|
||||||
err := underTest.InsertToken(expected1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = underTest.InsertToken(expected2)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
expected1.UserId = uuid.Nil
|
assert.Equal(t, expected, actual)
|
||||||
expected2.UserId = uuid.Nil
|
|
||||||
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
|
||||||
|
|
||||||
expected1.SessionId = ""
|
|
||||||
expected2.SessionId = ""
|
|
||||||
expected1.UserId = userId
|
|
||||||
expected2.UserId = userId
|
|
||||||
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
|
|
||||||
})
|
})
|
||||||
t.Run("should return ErrNotFound", func(t *testing.T) {
|
t.Run("should return ErrAccountExists", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
d := setupDb(t)
|
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
mockAuthDb := mocks.NewMockAuth(t)
|
||||||
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
|
mockClock := mocks.NewMockClock(t)
|
||||||
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
_, err := underTest.GetToken("nonExistent")
|
userId := uuid.New()
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
email := "some@valid.email"
|
||||||
|
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
password := "SomeStrongPassword123!"
|
||||||
|
salt := []byte("salt")
|
||||||
|
user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime)
|
||||||
|
|
||||||
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify)
|
mockRandom.EXPECT().UUID().Return(user.Id, nil)
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
|
||||||
|
mockClock.EXPECT().Now().Return(createTime)
|
||||||
|
|
||||||
_, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify)
|
mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists)
|
||||||
assert.Equal(t, db.ErrNotFound, err)
|
|
||||||
})
|
|
||||||
t.Run("should return ErrAlreadyExists", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
d := setupDb(t)
|
|
||||||
|
|
||||||
underTest := db.NewAuthSqlite(d)
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
|
_, err := underTest.SignUp(user.Email, password)
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
assert.Equal(t, service.ErrAccountExists, err)
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
|
})
|
||||||
|
}
|
||||||
err := underTest.InsertUser(user)
|
|
||||||
require.NoError(t, err)
|
func TestSendVerificationMail(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
err = underTest.InsertUser(user)
|
t.Run("should use stored token and send mail", func(t *testing.T) {
|
||||||
assert.Equal(t, db.ErrAlreadyExists, err)
|
t.Parallel()
|
||||||
})
|
|
||||||
t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) {
|
token := types.NewToken(
|
||||||
t.Parallel()
|
uuid.New(),
|
||||||
d := setupDb(t)
|
"sessionId",
|
||||||
|
"someRandomTokenToUse",
|
||||||
underTest := db.NewAuthSqlite(d)
|
types.TokenTypeEmailVerify,
|
||||||
|
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||||
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
|
time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
|
||||||
user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
|
tokens := []*types.Token{token}
|
||||||
|
|
||||||
err := underTest.InsertUser(user)
|
email := "some@email.de"
|
||||||
assert.Equal(t, types.ErrInternal, err)
|
userId := uuid.New()
|
||||||
|
|
||||||
|
mockAuthDb := mocks.NewMockAuth(t)
|
||||||
|
mockRandom := mocks.NewMockRandom(t)
|
||||||
|
mockClock := mocks.NewMockClock(t)
|
||||||
|
mockMail := mocks.NewMockMail(t)
|
||||||
|
|
||||||
|
mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify).Return(tokens, nil)
|
||||||
|
|
||||||
|
mockMail.EXPECT().SendMail(email, "Welcome to spend-sparrow", mock.MatchedBy(func(message string) bool {
|
||||||
|
return strings.Contains(message, token.Token)
|
||||||
|
})).Return()
|
||||||
|
|
||||||
|
underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings)
|
||||||
|
|
||||||
|
underTest.SendVerificationMail(userId, email)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
252
test/it_test.go
Normal file
252
test/it_test.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
package test_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"spend-sparrow/internal"
|
||||||
|
"spend-sparrow/internal/service"
|
||||||
|
"spend-sparrow/internal/types"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/html"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
httpClient = http.Client{
|
||||||
|
// Disable redirect following
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
port atomic.Int64
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupIntegrationTest(t *testing.T) (*sqlx.DB, string, context.Context) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx, done := context.WithCancel(context.Background())
|
||||||
|
t.Cleanup(done)
|
||||||
|
|
||||||
|
db, err := sqlx.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not open Database data.db: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := db.Close()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
testPort := port.Add(1)
|
||||||
|
testPort += 1024
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = internal.Run(ctx, db, "../", getEnv(testPort))
|
||||||
|
}()
|
||||||
|
|
||||||
|
basePath := "http://localhost:" + strconv.Itoa(int(testPort))
|
||||||
|
|
||||||
|
err = waitForReady(t, ctx, 5*time.Second, basePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return db, basePath, ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEnv(port int64) func(string) string {
|
||||||
|
return func(key string) string {
|
||||||
|
switch key {
|
||||||
|
case "PORT":
|
||||||
|
return strconv.Itoa(int(port))
|
||||||
|
case "SMTP_ENABLED":
|
||||||
|
return "false"
|
||||||
|
case "PROMETHEUS_ENABLED":
|
||||||
|
return "false"
|
||||||
|
case "BASE_URL":
|
||||||
|
return "http://localhost:" + strconv.Itoa(int(port))
|
||||||
|
case "ENVIRONMENT":
|
||||||
|
return "test"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForReady calls the specified endpoint until it gets a 200
|
||||||
|
// response or until the context is cancelled or the timeout is
|
||||||
|
// reached.
|
||||||
|
func waitForReady(
|
||||||
|
t *testing.T,
|
||||||
|
ctx context.Context,
|
||||||
|
timeout time.Duration,
|
||||||
|
endpoint string,
|
||||||
|
) error {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
client := http.Client{}
|
||||||
|
startTime := time.Now()
|
||||||
|
for {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err == nil && resp.StatusCode == http.StatusOK {
|
||||||
|
return resp.Body.Close()
|
||||||
|
} else if err == nil {
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
if time.Since(startTime) >= timeout {
|
||||||
|
t.Fatal("timeout reached while waiting for endpoint")
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
// wait a little while between checks
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findCsrfToken(t *testing.T, data *html.Node) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
token := getTokenAttribute(t, data)
|
||||||
|
if token != "" {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.FirstChild != nil {
|
||||||
|
if token = findCsrfToken(t, data.FirstChild); token != "" {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if data.NextSibling != nil {
|
||||||
|
if token = findCsrfToken(t, data.NextSibling); token != "" {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTokenAttribute(t *testing.T, data *html.Node) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for _, attr := range data.Attr {
|
||||||
|
if attr.Key == "hx-headers" {
|
||||||
|
var data map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(attr.Val), &data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
result, ok := data["Csrf-Token"].(string)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBody(t *testing.T, body io.ReadCloser) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, string, string) {
|
||||||
|
t.Helper()
|
||||||
|
userId := uuid.New()
|
||||||
|
sessionId := "session-id" + add
|
||||||
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
csrfToken := "my-verifying-token" + add
|
||||||
|
email := add + "mail@mail.de"
|
||||||
|
|
||||||
|
_, err := db.Exec(`
|
||||||
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
|
VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = db.Exec(`
|
||||||
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = db.Exec(`
|
||||||
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
|
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return userId, csrfToken, sessionId
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAnonymousSession(t *testing.T, ctx context.Context, basePath string) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
html, err := html.Parse(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return findCsrfToken(t, html), findCookie(t, resp).Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func findCookie(t *testing.T, resp *http.Response) *http.Cookie {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
for _, cookie := range resp.Cookies() {
|
||||||
|
if cookie.Name == "id" {
|
||||||
|
return cookie
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func doAuthenticatedRequest(
|
||||||
|
t *testing.T,
|
||||||
|
ctx context.Context,
|
||||||
|
method string,
|
||||||
|
path string,
|
||||||
|
formData url.Values,
|
||||||
|
csrfToken string,
|
||||||
|
sessionId string,
|
||||||
|
) *http.Response {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, path, strings.NewReader(formData.Encode()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req.Header.Set("Csrf-Token", csrfToken)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Hx-Request", "true")
|
||||||
|
req.AddCookie(&http.Cookie{Name: "id", Value: sessionId})
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
}
|
||||||
@@ -1,39 +1,20 @@
|
|||||||
package test_test
|
package test_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"spend-sparrow/internal/service"
|
||||||
|
"spend-sparrow/internal/types"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"spend-sparrow/internal"
|
|
||||||
"spend-sparrow/internal/service"
|
|
||||||
"spend-sparrow/internal/types"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/net/html"
|
"golang.org/x/net/html"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
httpClient = http.Client{
|
|
||||||
// Disable redirect following
|
|
||||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
|
||||||
return http.ErrUseLastResponse
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
port atomic.Int64
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIntegrationSecurityHeader(t *testing.T) {
|
func TestIntegrationSecurityHeader(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
t.Run("should keep caching for static content", func(t *testing.T) {
|
t.Run("should keep caching for static content", func(t *testing.T) {
|
||||||
@@ -1906,196 +1887,3 @@ func TestIntegrationAccount(t *testing.T) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, string, string) {
|
|
||||||
t.Helper()
|
|
||||||
userId := uuid.New()
|
|
||||||
sessionId := "session-id" + add
|
|
||||||
pass := service.GetHashPassword("password", []byte("salt"))
|
|
||||||
csrfToken := "my-verifying-token" + add
|
|
||||||
email := add + "mail@mail.de"
|
|
||||||
|
|
||||||
_, err := db.Exec(`
|
|
||||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
|
||||||
VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = db.Exec(`
|
|
||||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
|
||||||
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, err = db.Exec(`
|
|
||||||
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
|
||||||
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return userId, csrfToken, sessionId
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAnonymousSession(t *testing.T, ctx context.Context, basePath string) (string, string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
||||||
html, err := html.Parse(resp.Body)
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return findCsrfToken(t, html), findCookie(t, resp).Value
|
|
||||||
}
|
|
||||||
|
|
||||||
func findCookie(t *testing.T, resp *http.Response) *http.Cookie {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
for _, cookie := range resp.Cookies() {
|
|
||||||
if cookie.Name == "id" {
|
|
||||||
return cookie
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupIntegrationTest(t *testing.T) (*sqlx.DB, string, context.Context) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ctx, done := context.WithCancel(context.Background())
|
|
||||||
t.Cleanup(done)
|
|
||||||
|
|
||||||
db, err := sqlx.Open("sqlite3", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Could not open Database data.db: %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := db.Close()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
testPort := port.Add(1)
|
|
||||||
testPort += 1024
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = internal.Run(ctx, db, "../", getEnv(testPort))
|
|
||||||
}()
|
|
||||||
|
|
||||||
basePath := "http://localhost:" + strconv.Itoa(int(testPort))
|
|
||||||
|
|
||||||
err = waitForReady(t, ctx, 5*time.Second, basePath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return db, basePath, ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
func getEnv(port int64) func(string) string {
|
|
||||||
return func(key string) string {
|
|
||||||
switch key {
|
|
||||||
case "PORT":
|
|
||||||
return strconv.Itoa(int(port))
|
|
||||||
case "SMTP_ENABLED":
|
|
||||||
return "false"
|
|
||||||
case "PROMETHEUS_ENABLED":
|
|
||||||
return "false"
|
|
||||||
case "BASE_URL":
|
|
||||||
return "http://localhost:" + strconv.Itoa(int(port))
|
|
||||||
case "ENVIRONMENT":
|
|
||||||
return "test"
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitForReady calls the specified endpoint until it gets a 200
|
|
||||||
// response or until the context is cancelled or the timeout is
|
|
||||||
// reached.
|
|
||||||
func waitForReady(
|
|
||||||
t *testing.T,
|
|
||||||
ctx context.Context,
|
|
||||||
timeout time.Duration,
|
|
||||||
endpoint string,
|
|
||||||
) error {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
client := http.Client{}
|
|
||||||
startTime := time.Now()
|
|
||||||
for {
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err == nil && resp.StatusCode == http.StatusOK {
|
|
||||||
return resp.Body.Close()
|
|
||||||
} else if err == nil {
|
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
default:
|
|
||||||
if time.Since(startTime) >= timeout {
|
|
||||||
t.Fatal("timeout reached while waiting for endpoint")
|
|
||||||
return types.ErrInternal
|
|
||||||
}
|
|
||||||
// wait a little while between checks
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func findCsrfToken(t *testing.T, data *html.Node) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
token := getTokenAttribute(t, data)
|
|
||||||
if token != "" {
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.FirstChild != nil {
|
|
||||||
if token = findCsrfToken(t, data.FirstChild); token != "" {
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if data.NextSibling != nil {
|
|
||||||
if token = findCsrfToken(t, data.NextSibling); token != "" {
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTokenAttribute(t *testing.T, data *html.Node) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
for _, attr := range data.Attr {
|
|
||||||
if attr.Key == "hx-headers" {
|
|
||||||
var data map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(attr.Val), &data)
|
|
||||||
require.NoError(t, err)
|
|
||||||
result, ok := data["Csrf-Token"].(string)
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func readBody(t *testing.T, body io.ReadCloser) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
data, err := io.ReadAll(body)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return string(data)
|
|
||||||
}
|
|
||||||
51
test/treasure_chest_it_test.go
Normal file
51
test/treasure_chest_it_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package test_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTreasureChestShouldNotDeleteIfTransactionRecurringExists(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, baseUrl, ctx := setupIntegrationTest(t)
|
||||||
|
_, csrfToken, sessionId := createValidUserSession(t, db, baseUrl)
|
||||||
|
|
||||||
|
formData := url.Values{
|
||||||
|
"name": {"Test Treasure Chest Parent"},
|
||||||
|
}
|
||||||
|
resp := doAuthenticatedRequest(t, ctx, http.MethodPost, baseUrl+"/treasurechest/new", formData, csrfToken, sessionId)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var parentId string
|
||||||
|
err := db.Get(&parentId, "SELECT id FROM treasure_chest")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
formData = url.Values{
|
||||||
|
"name": {"Test Treasure Chest Child"},
|
||||||
|
"parent-id": {parentId},
|
||||||
|
}
|
||||||
|
resp = doAuthenticatedRequest(t, ctx, http.MethodPost, baseUrl+"/treasurechest/new", formData, csrfToken, sessionId)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var childId string
|
||||||
|
err = db.Get(&childId, "SELECT id FROM treasure_chest WHERE parent_id = ?", parentId)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
formData = url.Values{
|
||||||
|
"treasure-chest-id": {childId},
|
||||||
|
"value": {"100"},
|
||||||
|
"interval-months": {"1"},
|
||||||
|
"party": {"Test Party"},
|
||||||
|
}
|
||||||
|
resp = doAuthenticatedRequest(t, ctx, http.MethodPost, baseUrl+"/transaction-recurring/new", formData, csrfToken, sessionId)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
resp = doAuthenticatedRequest(t, ctx, http.MethodDelete, baseUrl+"/treasurechest/"+childId, nil, csrfToken, sessionId)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
assert.Contains(t, resp.Header.Get("Hx-Trigger"), "cannot delete treasure chest with existing recurring transactions")
|
||||||
|
}
|
||||||
@@ -1 +0,0 @@
|
|||||||
package test_test
|
|
||||||
Reference in New Issue
Block a user