From 128a2fc4d78d04bc14d7fb2d74c783102a124e6a Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Sun, 25 May 2025 16:36:30 +0200 Subject: [PATCH] fix: lint errors --- .golangci.yaml | 28 + db/auth.go | 39 +- db/auth_test.go | 89 +- db/error.go | 2 +- handler/auth.go | 19 +- handler/error.go | 7 +- handler/middleware/authenticate.go | 15 +- handler/middleware/cache_control.go | 1 - .../middleware/cross_site_request_forgery.go | 6 +- handler/middleware/gzip.go | 4 +- handler/middleware/security_headers.go | 3 +- handler/middleware/wrapper.go | 6 +- handler/render.go | 1 - handler/transaction_recurring.go | 1 - main.go | 26 +- main_test.go | 1144 +++++++++-------- service/account.go | 9 +- service/auth.go | 134 +- service/auth_test.go | 45 +- service/default.go | 11 +- service/mail.go | 14 +- service/money.go | 8 - service/money_test.go | 80 -- service/transaction.go | 61 +- service/transaction_recurring.go | 59 +- service/treasure_chest.go | 20 +- template/account/account.templ | 1 + template/layout.templ | 2 +- types/account.go | 6 +- types/auth.go | 26 +- types/savings_plan.go | 26 - types/settings.go | 61 +- types/transaction.go | 12 +- types/transaction_recurring.go | 10 +- types/treasure_chest.go | 6 +- utils/http.go | 10 +- 36 files changed, 1024 insertions(+), 968 deletions(-) create mode 100644 .golangci.yaml delete mode 100644 service/money.go delete mode 100644 service/money_test.go delete mode 100644 types/savings_plan.go diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..68bd4c2 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,28 @@ +version: '2' +linters: + default: all + disable: + - wsl + - wrapcheck + - varnamelen + - revive # should probably be enabled + - nlreturn + - mnd # should probably be enabled + - lll # should probably be enabled + - ireturn # should probably be enabled + - interfacebloat + - iface + - goconst # should probably be enabled + - gocognit # should probably be enabled + - gochecknoglobals # should probably be enabled + - funlen + - maintidx + - exhaustruct + - dupword # should probably be enabled + - dupl # should probably be enabled + - depguard + - cyclop + - contextcheck + settings: + nestif: + min-complexity: 6 diff --git a/db/auth.go b/db/auth.go index 8923c38..d83f659 100644 --- a/db/auth.go +++ b/db/auth.go @@ -1,6 +1,7 @@ package db import ( + "errors" "spend-sparrow/log" "spend-sparrow/types" @@ -89,7 +90,7 @@ func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) { FROM user WHERE email = ?`, email).Scan(&userId, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } else { log.Error("SQL error GetUser: %v", err) @@ -116,7 +117,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) { FROM user WHERE user_id = ?`, userId).Scan(&email, &emailVerified, &emailVerifiedAt, &password, &salt, &createdAt) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } else { log.Error("SQL error GetUser %v", err) @@ -128,7 +129,6 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) { } func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { - tx, err := db.db.Begin() if err != nil { log.Error("Could not start transaction: %v", err) @@ -216,7 +216,7 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) { WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr) if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { log.Info("Token '%v' not found", token) return nil, ErrNotFound } else { @@ -241,7 +241,6 @@ func (db AuthSqlite) GetToken(token string) (*types.Token, error) { } func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) { - query, err := db.db.Query(` SELECT token, created_at, expires_at FROM token @@ -257,7 +256,6 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types. } func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) { - query, err := db.db.Query(` SELECT token, created_at, expires_at FROM token @@ -325,7 +323,6 @@ func (db AuthSqlite) DeleteToken(token string) error { } func (db AuthSqlite) InsertSession(session *types.Session) error { - _, err := db.db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)`, session.Id, session.UserId, session.CreatedAt, session.ExpiresAt) @@ -339,7 +336,6 @@ func (db AuthSqlite) InsertSession(session *types.Session) error { } func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { - var ( userId uuid.UUID createdAt time.Time @@ -360,9 +356,9 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { } func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) { - - sessions, err := db.db.Query(` - SELECT session_id, created_at, expires_at + var sessions []*types.Session + err := db.db.Select(&sessions, ` + SELECT * FROM session WHERE user_id = ?`, userId) if err != nil { @@ -370,26 +366,7 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) { return nil, types.ErrInternal } - var result []*types.Session - - for sessions.Next() { - var ( - sessionId string - createdAt time.Time - expiresAt time.Time - ) - - err := sessions.Scan(&sessionId, &createdAt, &expiresAt) - if err != nil { - log.Error("Could not scan session: %v", err) - return nil, types.ErrInternal - } - - session := types.NewSession(sessionId, userId, createdAt, expiresAt) - result = append(result, session) - } - - return result, nil + return sessions, nil } func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error { diff --git a/db/auth_test.go b/db/auth_test.go index 04b505e..4e87066 100644 --- a/db/auth_test.go +++ b/db/auth_test.go @@ -1,6 +1,7 @@ -package db +package db_test import ( + "spend-sparrow/db" "spend-sparrow/types" "testing" "time" @@ -8,26 +9,29 @@ import ( "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func setupDb(t *testing.T) *sqlx.DB { - db, err := sqlx.Open("sqlite3", ":memory:") + t.Helper() + + d, err := sqlx.Open("sqlite3", ":memory:") if err != nil { t.Fatalf("Error opening database: %v", err) } t.Cleanup(func() { - err := db.Close() + err := d.Close() if err != nil { panic(err) } }) - err = RunMigrations(db, "../") + err = db.RunMigrations(d, "../") if err != nil { t.Fatalf("Error running migrations: %v", err) } - return db + return d } func TestUser(t *testing.T) { @@ -35,55 +39,55 @@ func TestUser(t *testing.T) { t.Run("should insert and get the same", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(expected) - assert.Nil(t, err) + require.NoError(t, err) actual, err := underTest.GetUser(expected.Id) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, expected, actual) actual, err = underTest.GetUserByEmail(expected.Email) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, expected, actual) }) t.Run("should return ErrNotFound", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) _, err := underTest.GetUserByEmail("nonExistentEmail") - assert.Equal(t, ErrNotFound, err) + assert.Equal(t, db.ErrNotFound, err) }) t.Run("should return ErrUserExist", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(user) - assert.Nil(t, err) + require.NoError(t, err) err = underTest.InsertUser(user) - assert.Equal(t, ErrAlreadyExists, err) + assert.Equal(t, db.ErrAlreadyExists, err) }) t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) @@ -98,37 +102,37 @@ func TestToken(t *testing.T) { t.Run("should insert and get the same", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) expiresAt := createAt.Add(24 * time.Hour) expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt) err := underTest.InsertToken(expected) - assert.Nil(t, err) + require.NoError(t, err) actual, err := underTest.GetToken(expected.Token) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, expected, actual) expected.SessionId = "" actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, []*types.Token{expected}, actuals) expected.SessionId = "sessionId" expected.UserId = uuid.Nil actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, []*types.Token{expected}, actuals) }) t.Run("should insert and return multiple tokens", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) expiresAt := createAt.Add(24 * time.Hour) @@ -137,14 +141,14 @@ func TestToken(t *testing.T) { expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt) err := underTest.InsertToken(expected1) - assert.Nil(t, err) + require.NoError(t, err) err = underTest.InsertToken(expected2) - assert.Nil(t, err) + require.NoError(t, err) expected1.UserId = uuid.Nil expected2.UserId = uuid.Nil actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, []*types.Token{expected1, expected2}, actuals) expected1.SessionId = "" @@ -152,46 +156,45 @@ func TestToken(t *testing.T) { expected1.UserId = userId expected2.UserId = userId actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, []*types.Token{expected1, expected2}, actuals) - }) t.Run("should return ErrNotFound", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) _, err := underTest.GetToken("nonExistent") - assert.Equal(t, ErrNotFound, err) + assert.Equal(t, db.ErrNotFound, err) _, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify) - assert.Equal(t, ErrNotFound, err) + assert.Equal(t, db.ErrNotFound, err) _, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify) - assert.Equal(t, ErrNotFound, err) + assert.Equal(t, db.ErrNotFound, err) }) t.Run("should return ErrAlreadyExists", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(user) - assert.Nil(t, err) + require.NoError(t, err) err = underTest.InsertUser(user) - assert.Equal(t, ErrAlreadyExists, err) + assert.Equal(t, db.ErrAlreadyExists, err) }) t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Parallel() - db := setupDb(t) + d := setupDb(t) - underTest := AuthSqlite{db: db} + underTest := db.NewAuthSqlite(d) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) diff --git a/db/error.go b/db/error.go index f83ba68..194b0e8 100644 --- a/db/error.go +++ b/db/error.go @@ -14,7 +14,7 @@ var ( func TransformAndLogDbError(module string, r sql.Result, err error) error { if err != nil { - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return ErrNotFound } log.Error("%v: %v", module, err) diff --git a/handler/auth.go b/handler/auth.go index e656671..88f99cc 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -77,7 +77,6 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc { func (handler AuthImpl) handleSignIn() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) { session := middleware.GetSession(r) email := r.FormValue("email") @@ -95,7 +94,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc { }) if err != nil { - if err == service.ErrInvalidCredentials { + if errors.Is(err, service.ErrInvalidCredentials) { utils.TriggerToastWithStatus(w, r, "error", "Invalid email or password", http.StatusUnauthorized) } else { utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError) @@ -166,7 +165,6 @@ func (handler AuthImpl) handleVerifyResendComp() http.HandlerFunc { func (handler AuthImpl) handleSignUpVerifyResponsePage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - token := r.URL.Query().Get("token") err := handler.service.VerifyUserEmail(token) @@ -203,13 +201,14 @@ func (handler AuthImpl) handleSignUp() http.HandlerFunc { }) if err != nil { - if errors.Is(err, types.ErrInternal) { + switch { + case errors.Is(err, types.ErrInternal): utils.TriggerToastWithStatus(w, r, "error", "An error occurred", http.StatusInternalServerError) return - } else if errors.Is(err, service.ErrInvalidEmail) { + case errors.Is(err, service.ErrInvalidEmail): utils.TriggerToastWithStatus(w, r, "error", "The email provided is invalid", http.StatusBadRequest) return - } else if errors.Is(err, service.ErrInvalidPassword) { + case errors.Is(err, service.ErrInvalidPassword): utils.TriggerToastWithStatus(w, r, "error", service.ErrInvalidPassword.Error(), http.StatusBadRequest) return } @@ -272,7 +271,7 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { err := handler.service.DeleteAccount(user, password) if err != nil { - if err == service.ErrInvalidCredentials { + if errors.Is(err, service.ErrInvalidCredentials) { utils.TriggerToastWithStatus(w, r, "error", "Password not correct", http.StatusBadRequest) } else { utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) @@ -286,7 +285,6 @@ func (handler AuthImpl) handleDeleteAccountComp() http.HandlerFunc { func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - isPasswordReset := r.URL.Query().Has("token") user := middleware.GetUser(r) @@ -303,7 +301,6 @@ func (handler AuthImpl) handleChangePasswordPage() http.HandlerFunc { func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - session := middleware.GetSession(r) user := middleware.GetUser(r) if session == nil || user == nil { @@ -326,7 +323,6 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user := middleware.GetUser(r) if user != nil { utils.DoRedirect(w, r, "/") @@ -340,7 +336,6 @@ func (handler AuthImpl) handleForgotPasswordPage() http.HandlerFunc { func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - email := r.FormValue("email") if email == "" { utils.TriggerToastWithStatus(w, r, "error", "Please enter an email", http.StatusBadRequest) @@ -362,7 +357,7 @@ func (handler AuthImpl) handleForgotPasswordComp() http.HandlerFunc { func (handler AuthImpl) handleForgotPasswordResponseComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - pageUrl, err := url.Parse(r.Header.Get("HX-Current-URL")) + pageUrl, err := url.Parse(r.Header.Get("Hx-Current-Url")) if err != nil { log.Error("Could not get current URL: %v", err) utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError) diff --git a/handler/error.go b/handler/error.go index 959291a..b5fd1d9 100644 --- a/handler/error.go +++ b/handler/error.go @@ -10,13 +10,14 @@ import ( ) func handleError(w http.ResponseWriter, r *http.Request, err error) { - if errors.Is(err, service.ErrUnauthorized) { + switch { + case errors.Is(err, service.ErrUnauthorized): utils.TriggerToastWithStatus(w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized) return - } else if errors.Is(err, service.ErrBadRequest) { + case errors.Is(err, service.ErrBadRequest): utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest) return - } else if errors.Is(err, db.ErrNotFound) { + case errors.Is(err, db.ErrNotFound): utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusNotFound) return } diff --git a/handler/middleware/authenticate.go b/handler/middleware/authenticate.go index 48cb934..2742f49 100644 --- a/handler/middleware/authenticate.go +++ b/handler/middleware/authenticate.go @@ -16,7 +16,6 @@ var UserKey ContextKey = "user" func Authenticate(service service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sessionId := getSessionID(r) session, user, _ := service.SignInSession(sessionId) @@ -49,7 +48,12 @@ func GetUser(r *http.Request) *types.User { return nil } - return obj.(*types.User) + user, ok := obj.(*types.User) + if !ok { + return nil + } + + return user } func GetSession(r *http.Request) *types.Session { @@ -58,7 +62,12 @@ func GetSession(r *http.Request) *types.Session { return nil } - return obj.(*types.Session) + session, ok := obj.(*types.Session) + if !ok { + return nil + } + + return session } func getSessionID(r *http.Request) string { diff --git a/handler/middleware/cache_control.go b/handler/middleware/cache_control.go index 32d1c56..f6bc7a0 100644 --- a/handler/middleware/cache_control.go +++ b/handler/middleware/cache_control.go @@ -7,7 +7,6 @@ import ( func CacheControl(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - shouldCache := strings.HasPrefix(r.URL.Path, "/static") if !shouldCache { diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 90b5981..d165978 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -37,19 +37,17 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) { func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := GetSession(r) if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodDelete || r.Method == http.MethodPatch { - - csrfToken := r.Header.Get("csrf-token") + csrfToken := r.Header.Get("Csrf-Token") if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { log.Info("CSRF-Token \"%s\" not correct", csrfToken) - if r.Header.Get("HX-Request") == "true" { + if r.Header.Get("Hx-Request") == "true" { utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest) } else { http.Error(w, "CSRF-Token not correct", http.StatusBadRequest) diff --git a/handler/middleware/gzip.go b/handler/middleware/gzip.go index 77f5e45..1d73f8e 100644 --- a/handler/middleware/gzip.go +++ b/handler/middleware/gzip.go @@ -2,6 +2,7 @@ package middleware import ( "compress/gzip" + "errors" "io" "net/http" "strings" @@ -32,8 +33,7 @@ func Gzip(next http.Handler) http.Handler { next.ServeHTTP(wrapper, r) err := gz.Close() - if err != nil && err != http.ErrBodyNotAllowed { - // if err != nil { + if err != nil && !errors.Is(err, http.ErrBodyNotAllowed) { log.Error("Gzip: could not close Writer: %v", err) } }) diff --git a/handler/middleware/security_headers.go b/handler/middleware/security_headers.go index c015e22..948a159 100644 --- a/handler/middleware/security_headers.go +++ b/handler/middleware/security_headers.go @@ -7,7 +7,6 @@ import ( ) func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Content-Type-Options", "nosniff") @@ -30,7 +29,7 @@ func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Han w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") w.Header().Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains; preload") - if r.Method == "OPTIONS" { + if r.Method == http.MethodOptions { w.WriteHeader(http.StatusOK) return } diff --git a/handler/middleware/wrapper.go b/handler/middleware/wrapper.go index cd5f9af..d4fb102 100644 --- a/handler/middleware/wrapper.go +++ b/handler/middleware/wrapper.go @@ -2,12 +2,12 @@ package middleware import "net/http" -// Chain list of handlers together +// Chain list of handlers together. func Wrapper(next http.Handler, handlers ...func(http.Handler) http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { lastHandler := next - for i := 0; i < len(handlers); i++ { - lastHandler = handlers[i](lastHandler) + for _, handler := range handlers { + lastHandler = handler(lastHandler) } lastHandler.ServeHTTP(w, r) }) diff --git a/handler/render.go b/handler/render.go index 969f79a..7a447a1 100644 --- a/handler/render.go +++ b/handler/render.go @@ -44,7 +44,6 @@ func (render *Render) RenderLayoutWithStatus(r *http.Request, w http.ResponseWri } func (render *Render) getUserComp(user *types.User) templ.Component { - if user != nil { return auth.UserComp(user.Email) } else { diff --git a/handler/transaction_recurring.go b/handler/transaction_recurring.go index 437e31c..d56a5ee 100644 --- a/handler/transaction_recurring.go +++ b/handler/transaction_recurring.go @@ -106,7 +106,6 @@ func (h TransactionRecurringImpl) handleDeleteTransactionRecurring() http.Handle } func (h TransactionRecurringImpl) renderItems(w http.ResponseWriter, r *http.Request, user *types.User, id, accountId, treasureChestId string) { - var transactionsRecurring []*types.TransactionRecurring var err error if accountId == "" && treasureChestId == "" { diff --git a/main.go b/main.go index 062fbab..818a1f7 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,8 @@ package main import ( + "errors" + "fmt" "spend-sparrow/db" "spend-sparrow/handler" "spend-sparrow/handler/middleware" @@ -37,10 +39,14 @@ func main() { log.Fatal("Could not close Database data.db: %v", err) }() - run(context.Background(), db, os.Getenv) + err = run(context.Background(), db, os.Getenv) + if err != nil { + log.Error("Error running server: %v", err) + return + } } -func run(ctx context.Context, database *sqlx.DB, env func(string) string) { +func run(ctx context.Context, database *sqlx.DB, env func(string) string) error { ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) defer cancel() @@ -52,22 +58,24 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) { // init db err := db.RunMigrations(database, "") if err != nil { - log.Fatal("Could not run migrations: %v", err) + return fmt.Errorf("could not run migrations: %w", err) } // init servers var prometheusServer *http.Server if serverSettings.PrometheusEnabled { prometheusServer := &http.Server{ - Addr: ":8081", - Handler: promhttp.Handler(), + Addr: ":8081", + Handler: promhttp.Handler(), + ReadHeaderTimeout: 10 * time.Second, } go startServer(prometheusServer) } httpServer := &http.Server{ - Addr: ":" + serverSettings.Port, - Handler: createHandler(database, serverSettings), + Addr: ":" + serverSettings.Port, + Handler: createHandler(database, serverSettings), + ReadHeaderTimeout: 10 * time.Second, } go startServer(httpServer) @@ -77,11 +85,13 @@ func run(ctx context.Context, database *sqlx.DB, env func(string) string) { go shutdownServer(httpServer, ctx, &wg) go shutdownServer(prometheusServer, ctx, &wg) wg.Wait() + + return nil } func startServer(s *http.Server) { log.Info("Starting server on %q", s.Addr) - if err := s.ListenAndServe(); err != nil && err != http.ErrServerClosed { + if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Error("error listening and serving: %v", err) } } diff --git a/main_test.go b/main_test.go index 832cf1d..b87ba03 100644 --- a/main_test.go +++ b/main_test.go @@ -3,10 +3,10 @@ package main import ( "context" "encoding/json" - "fmt" "io" "net/http" "net/url" + "strconv" "strings" "sync/atomic" "testing" @@ -18,6 +18,7 @@ import ( "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/html" ) @@ -29,7 +30,7 @@ var ( }, } - port atomic.Int32 + port atomic.Int64 ) func TestIntegrationSecurityHeader(t *testing.T) { @@ -39,25 +40,27 @@ func TestIntegrationSecurityHeader(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/static/favicon.svg", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/static/favicon.svg", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() cacheControl := resp.Header.Get("Cache-Control") - assert.Equal(t, "", cacheControl) + assert.Empty(t, cacheControl) }) t.Run("should disable caching for dynamic content", func(t *testing.T) { t.Parallel() _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath, nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath, nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() cacheControl := resp.Header.Get("Cache-Control") assert.Equal(t, "no-cache, no-store, must-revalidate", cacheControl) @@ -67,11 +70,12 @@ func TestIntegrationSecurityHeader(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath, nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath, nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() value := resp.Header.Get("X-Content-Type-Options") assert.Equal(t, "nosniff", value) @@ -116,6 +120,7 @@ func TestIntegrationAuth(t *testing.T) { t.Parallel() t.Run("SignIn", func(t *testing.T) { + t.Parallel() t.Run(`should redirect to "/" if user is already signed in`, func(t *testing.T) { t.Parallel() @@ -128,17 +133,18 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/", resp.Header.Get("Location")) }) @@ -153,19 +159,20 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) formData := url.Values{ "email": {"mail@mail.de"}, "password": {"password"}, } - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("csrf-token", "invalid-csrf-token") + req.Header.Set("Csrf-Token", "invalid-csrf-token") resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) }) t.Run(`should fail with invalid username`, func(t *testing.T) { @@ -179,34 +186,37 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() + csrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", csrfToken) - session := findCookie(resp, "id") + assert.NotEmpty(t, csrfToken) + session := findCookie(t, resp) formData := url.Values{ "email": {"invalid@mail.de"}, "password": {"password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+session.Value) - req.Header.Set("csrf-token", csrfToken) - req.Header.Set("HX-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) + req.Header.Set("Hx-Request", "true") resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Contains(t, resp.Header.Get("HX-Trigger"), "email or password") + assert.Contains(t, resp.Header.Get("Hx-Trigger"), "email or password") }) t.Run(`should fail with invalid password`, func(t *testing.T) { t.Parallel() @@ -219,34 +229,36 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() csrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", csrfToken) - session := findCookie(resp, "id") + assert.NotEmpty(t, csrfToken) + session := findCookie(t, resp) formData := url.Values{ "email": {"mail@mail.de"}, "password": {"invalid-password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+session.Value) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Contains(t, resp.Header.Get("HX-Trigger"), "email or password") + assert.Contains(t, resp.Header.Get("Hx-Trigger"), "email or password") }) t.Run("should return secure cookie with NEW session-id", func(t *testing.T) { t.Parallel() @@ -255,41 +267,40 @@ func TestIntegrationAuth(t *testing.T) { pass := service.GetHashPassword("password", []byte("salt")) _, err := db.Exec(` - INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) - VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) - assert.Nil(t, err) + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) - + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSession := findCookie(resp, "id") + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSession := findCookie(t, resp) assert.NotNil(t, anonymousSession) formData := url.Values{ "email": {"mail@mail.de"}, "password": {"password"}, } - - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Csrf-Token", anonymousCsrfToken) req.Header.Set("Cookie", "id="+anonymousSession.Value) resp, err = httpClient.Do(req) - assert.Nil(t, err) - + require.NoError(t, err) assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + _ = resp.Body.Close() - cookie := findCookie(resp, "id") + cookie := findCookie(t, resp) assert.NotNil(t, cookie) assert.Equal(t, http.SameSiteStrictMode, cookie.SameSite, "Cookie is not secure") assert.True(t, cookie.HttpOnly, "Cookie is not secure") @@ -306,18 +317,19 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) // Everythings correct - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) body, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, body) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSession := findCookie(resp, "id") + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSession := findCookie(t, resp) assert.NotNil(t, anonymousSession) formData := url.Values{ @@ -325,33 +337,35 @@ func TestIntegrationAuth(t *testing.T) { "password": {"password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+anonymousSession.Value) - req.Header.Set("csrf-token", anonymousCsrfToken) - req.Header.Set("HX-Request", "true") + req.Header.Set("Csrf-Token", anonymousCsrfToken) + req.Header.Set("Hx-Request", "true") timeStart := time.Now() resp, err = httpClient.Do(req) timeEnd := time.Now() - assert.Nil(t, err) + require.NoError(t, err) if timeEnd.Sub(timeStart) > 260*time.Millisecond || timeEnd.Sub(timeStart) < 250*time.Millisecond { t.Fail() t.Logf("Time did not match: %v", timeEnd.Sub(timeStart)) } assert.Equal(t, http.StatusOK, resp.StatusCode) + _ = resp.Body.Close() - //Wrong password - req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + // Wrong password + req, err = http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) body, err = html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken = findCsrfToken(t, body) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSession = findCookie(resp, "id") + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSession = findCookie(t, resp) assert.NotNil(t, anonymousSession) formData = url.Values{ @@ -359,33 +373,35 @@ func TestIntegrationAuth(t *testing.T) { "password": {"wrong-password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+anonymousSession.Value) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", anonymousCsrfToken) timeStart = time.Now() resp, err = httpClient.Do(req) timeEnd = time.Now() - assert.Nil(t, err) + require.NoError(t, err) if timeEnd.Sub(timeStart) > 260*time.Millisecond || timeEnd.Sub(timeStart) <= 250*time.Millisecond { t.Fail() t.Logf("Time did not match: %v", timeEnd.Sub(timeStart)) } assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + _ = resp.Body.Close() - //Wrong username - req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + // Wrong username + req, err = http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) body, err = html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken = findCsrfToken(t, body) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSession = findCookie(resp, "id") + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSession = findCookie(t, resp) assert.NotNil(t, anonymousSession) formData = url.Values{ @@ -393,22 +409,23 @@ func TestIntegrationAuth(t *testing.T) { "password": {"password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+anonymousSession.Value) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", anonymousCsrfToken) timeStart = time.Now() resp, err = httpClient.Do(req) timeEnd = time.Now() - assert.Nil(t, err) + require.NoError(t, err) if timeEnd.Sub(timeStart) > 260*time.Millisecond || timeEnd.Sub(timeStart) <= 250*time.Millisecond { t.Fail() t.Logf("Time did not match: %v", timeEnd.Sub(timeStart)) } assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + _ = resp.Body.Close() }) t.Run("should create new session and invalidate old one (session fixation prevention)", func(t *testing.T) { t.Parallel() @@ -419,18 +436,19 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSession := findCookie(resp, "id") + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSession := findCookie(t, resp) assert.NotNil(t, anonymousSession) formData := url.Values{ @@ -438,23 +456,24 @@ func TestIntegrationAuth(t *testing.T) { "password": {"password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+anonymousSession.Value) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", anonymousCsrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) var rows int err = db.QueryRow("SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) err = db.QueryRow("SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) }) }) @@ -471,17 +490,18 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/", resp.Header.Get("Location")) @@ -495,48 +515,51 @@ func TestIntegrationAuth(t *testing.T) { "email": {"mail@mail.de"}, "password": {"password"}, } - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", "invalid-csrf-token") + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", "invalid-csrf-token") resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - assert.Contains(t, resp.Header.Get("HX-Trigger"), "CSRF") + assert.Contains(t, resp.Header.Get("Hx-Trigger"), "CSRF") }) t.Run(`should fail if password is insecure`, func(t *testing.T) { t.Parallel() _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signup", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) body, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, body) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSession := findCookie(resp, "id") + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSession := findCookie(t, resp) assert.NotNil(t, anonymousSession) formData := url.Values{ "email": {"mail@mail.de"}, "password": {"insecure-password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signup", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signup", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("HX-Request", "true") + req.Header.Set("Hx-Request", "true") req.Header.Set("Cookie", "id="+anonymousSession.Value) - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Csrf-Token", anonymousCsrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - assert.Contains(t, resp.Header.Get("HX-Trigger"), "password") + assert.Contains(t, resp.Header.Get("Hx-Trigger"), "password") }) t.Run(`should say "verification mail send" if user already exists within ~250 ms`, func(t *testing.T) { t.Parallel() @@ -546,85 +569,89 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, uuid.New(), service.GetHashPassword("password", []byte("salt")), []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signup", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) body, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, body) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSession := findCookie(resp, "id") + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSession := findCookie(t, resp) assert.NotNil(t, anonymousSession) formData := url.Values{ "email": {"mail@mail.de"}, "password": {"secure-Password!1"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signup", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signup", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("HX-Request", "true") + req.Header.Set("Hx-Request", "true") req.Header.Set("Cookie", "id="+anonymousSession.Value) - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Csrf-Token", anonymousCsrfToken) timeStart := time.Now() resp, err = httpClient.Do(req) timeEnd := time.Now() - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() timeTaken := timeEnd.Sub(timeStart) assert.LessOrEqual(t, timeTaken, 260*time.Millisecond) assert.GreaterOrEqual(t, timeTaken, 250*time.Millisecond) assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Contains(t, resp.Header.Get("HX-Trigger"), "An activation link has been send to your email") + assert.Contains(t, resp.Header.Get("Hx-Trigger"), "An activation link has been send to your email") }) t.Run(`should say "verification mail send" within ~250 ms`, func(t *testing.T) { t.Parallel() db, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signup", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signup", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) body, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, body) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSession := findCookie(resp, "id") + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSession := findCookie(t, resp) assert.NotNil(t, anonymousSession) formData := url.Values{ "email": {"mail@mail.de"}, "password": {"secure-Password!1"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signup", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signup", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("HX-Request", "true") + req.Header.Set("Hx-Request", "true") req.Header.Set("Cookie", "id="+anonymousSession.Value) - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Csrf-Token", anonymousCsrfToken) timeStart := time.Now() resp, err = httpClient.Do(req) timeEnd := time.Now() - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() timeTaken := timeEnd.Sub(timeStart) assert.LessOrEqual(t, timeTaken, 260*time.Millisecond) assert.GreaterOrEqual(t, timeTaken, 250*time.Millisecond) assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Contains(t, resp.Header.Get("HX-Trigger"), "An activation link has been send to your email") + assert.Contains(t, resp.Header.Get("Hx-Trigger"), "An activation link has been send to your email") var rows int err = db.QueryRow("SELECT COUNT(*) FROM user WHERE email = ? AND email_verified = FALSE", "mail@mail.de").Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) var token string err = db.QueryRow("SELECT t.token FROM token t INNER JOIN user u ON u.user_id = t.user_id WHERE u.email = ? AND t.type = ?", "mail@mail.de", types.TokenTypeEmailVerify).Scan(&token) - assert.Nil(t, err) - assert.NotEqual(t, "", token) + require.NoError(t, err) + assert.NotEmpty(t, token) }) }) t.Run("SignUpVerification", func(t *testing.T) { @@ -638,18 +665,19 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/verify-email?token=invalid-token", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token=invalid-token", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) t.Run(`should fail verifying email with outdated token`, func(t *testing.T) { @@ -663,22 +691,23 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO token (token, user_id, type, created_at, expires_at) VALUES (?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, types.TokenTypeEmailVerify) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/verify-email?token="+token, nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = FALSE", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) t.Run(`should verify email with correct token`, func(t *testing.T) { @@ -692,22 +721,23 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, "", ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, types.TokenTypeEmailVerify) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/verify-email?token="+token, nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/verify-email?token="+token, nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) var rows int err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND email_verified = TRUE", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) }) @@ -717,11 +747,12 @@ func TestIntegrationAuth(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/sign-out", nil) - assert.Nil(t, err) - req.Header.Set("csrf-token", "invalid-csrf-token") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/sign-out", nil) + require.NoError(t, err) + req.Header.Set("Csrf-Token", "invalid-csrf-token") resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) }) @@ -737,41 +768,43 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) var csrfToken string err = db.QueryRow("SELECT token FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypeCsrf).Scan(&csrfToken) - assert.Nil(t, err) + require.NoError(t, err) - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signout", nil) - assert.Nil(t, err) - req.Header.Set("csrf-token", csrfToken) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/signout", nil) + require.NoError(t, err) + req.Header.Set("Csrf-Token", csrfToken) req.Header.Set("Cookie", "id="+sessionId) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/", resp.Header.Get("Location")) - cookie := findCookie(resp, "id") + cookie := findCookie(t, resp) assert.NotNil(t, cookie) - assert.Equal(t, "", cookie.Value) + assert.Empty(t, cookie.Value) assert.Equal(t, -1, cookie.MaxAge) var rows int err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) }) }) @@ -781,10 +814,11 @@ func TestIntegrationAuth(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/delete-account", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/delete-account", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) @@ -794,10 +828,11 @@ func TestIntegrationAuth(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/delete-account", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/delete-account", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) }) @@ -813,35 +848,37 @@ func TestIntegrationAuth(t *testing.T) { VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/delete-account", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/delete-account", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() csrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", csrfToken) + assert.NotEmpty(t, csrfToken) formData := url.Values{ "password": {"wrong-password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/delete-account", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/delete-account", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) }) @@ -857,24 +894,25 @@ func TestIntegrationAuth(t *testing.T) { VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) formData := url.Values{ "password": {"password"}, } - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/delete-account", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/delete-account", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", "wrong-csrf-token") + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", "wrong-csrf-token") resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) }) @@ -883,80 +921,84 @@ func TestIntegrationAuth(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) - userId, csrfToken, sessionId := createValidUserSession(t, db, ctx, basePath, "") + userId, csrfToken, sessionId := createValidUserSession(t, db, "") formData := url.Values{ "name": {"Name"}, } - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/account/new", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("csrf-token", csrfToken) - req.Header.Set("HX-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) + req.Header.Set("Hx-Request", "true") resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) formData = url.Values{ "name": {"Name"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/treasurechest/new", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/treasurechest/new", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) formData = url.Values{ "timestamp": {"2006-01-02"}, "value": {"100.00"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/transaction/new", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/transaction/new", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) formData = url.Values{ "password": {"password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/delete-account", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/delete-account", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) var rows int err = db.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ?", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ?", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) err = db.QueryRow("SELECT COUNT(*) FROM account WHERE user_id = ?", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) err = db.QueryRow("SELECT COUNT(*) FROM treasure_chest WHERE user_id = ?", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) err = db.QueryRow("SELECT COUNT(*) FROM \"transaction\" WHERE user_id = ?", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) }) }) @@ -967,10 +1009,11 @@ func TestIntegrationAuth(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/change-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/change-password", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) @@ -980,29 +1023,31 @@ func TestIntegrationAuth(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", anonymousCsrfToken) - anonymousSessionId := findCookie(resp, "id").Value - assert.NotEqual(t, "", anonymousSessionId) + assert.NotEmpty(t, anonymousCsrfToken) + anonymousSessionId := findCookie(t, resp).Value + assert.NotEmpty(t, anonymousSessionId) formData := url.Values{ "current-password": {"password"}, "new-password": {"MyNewSecurePassword1!"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+anonymousSessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", anonymousCsrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) }) @@ -1018,31 +1063,32 @@ func TestIntegrationAuth(t *testing.T) { VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) formData := url.Values{ "current-password": {"password"}, "new-password": {"MyNewSecurePassword1!"}, } - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", "invalid-csrf-token") + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", "invalid-csrf-token") resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) t.Run("should fail if current password does not match", func(t *testing.T) { @@ -1057,40 +1103,42 @@ func TestIntegrationAuth(t *testing.T) { VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/change-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/change-password", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() csrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", csrfToken) + assert.NotEmpty(t, csrfToken) formData := url.Values{ "current-password": {"wrong-password"}, "new-password": {"MyNewSecurePassword1!"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) t.Run("should fail if new password is insecure", func(t *testing.T) { @@ -1105,40 +1153,42 @@ func TestIntegrationAuth(t *testing.T) { VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/change-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/change-password", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() csrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", csrfToken) + assert.NotEmpty(t, csrfToken) formData := url.Values{ "current-password": {"password"}, "new-password": {"insecure-password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) t.Run("should change password and invalidate all other user sessions", func(t *testing.T) { @@ -1154,65 +1204,67 @@ func TestIntegrationAuth(t *testing.T) { VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) sessionId := "session-id" - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES ("second", ?, datetime(), datetime("now", "+1 day"))`, userId) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES ("other", ?, datetime(), datetime("now", "+1 day"))`, userIdOther) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/change-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/change-password", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() csrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", csrfToken) + assert.NotEmpty(t, csrfToken) formData := url.Values{ "current-password": {"password"}, "new-password": {"MyNewSecurePassword1!"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/change-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) pass = service.GetHashPassword("MyNewSecurePassword1!", []byte("salt")) var rows int err = db.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) var sessionIds []string sessions, err := db.Query(`SELECT session_id FROM session WHERE NOT user_id = ? ORDER BY session_id`, uuid.Nil) - assert.Nil(t, err) + require.NoError(t, err) for sessions.Next() { var sessionId string err = sessions.Scan(&sessionId) - assert.Nil(t, err) + require.NoError(t, err) sessionIds = append(sessionIds, sessionId) } - assert.Equal(t, 2, len(sessionIds)) + assert.Len(t, sessionIds, 2) assert.Equal(t, "other", sessionIds[0]) assert.Equal(t, "session-id", sessionIds[1]) }) @@ -1229,19 +1281,20 @@ func TestIntegrationAuth(t *testing.T) { _, err := d.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) sessionId := "session-id" _, err = d.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/forgot-password", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/", resp.Header.Get("Location")) @@ -1256,32 +1309,34 @@ func TestIntegrationAuth(t *testing.T) { _, err := d.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/forgot-password", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) - anonymousSessionId := findCookie(resp, "id").Value - assert.NotEqual(t, "", anonymousSessionId) + anonymousSessionId := findCookie(t, resp).Value + assert.NotEmpty(t, anonymousSessionId) formData := url.Values{ "email": {"mail@mail.de"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", "invalid-csrf-token") + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", "invalid-csrf-token") resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = d.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) }) t.Run(`should fail but respond with uniform message`, func(t *testing.T) { @@ -1289,33 +1344,35 @@ func TestIntegrationAuth(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/forgot-password", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - anonymousSessionId := findCookie(resp, "id").Value - assert.NotEqual(t, "", anonymousSessionId) + anonymousSessionId := findCookie(t, resp).Value + assert.NotEmpty(t, anonymousSessionId) body, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, body) formData := url.Values{ "email": {"non-existent@mail.de"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("HX-Request", "true") + req.Header.Set("Hx-Request", "true") req.Header.Set("Cookie", "id="+anonymousSessionId) - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Csrf-Token", anonymousCsrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) msg := "If the address exists, an email has been sent." - assert.Contains(t, resp.Header.Get("HX-Trigger"), msg) + assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg) }) t.Run(`should generate token and respond with uniform message`, func(t *testing.T) { t.Parallel() @@ -1327,39 +1384,41 @@ func TestIntegrationAuth(t *testing.T) { _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", TRUE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/forgot-password", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - anonymousSessionId := findCookie(resp, "id").Value - assert.NotEqual(t, "", anonymousSessionId) + anonymousSessionId := findCookie(t, resp).Value + assert.NotEmpty(t, anonymousSessionId) body, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, body) formData := url.Values{ "email": {"mail@mail.de"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("HX-Request", "true") + req.Header.Set("Hx-Request", "true") req.Header.Set("Cookie", "id="+anonymousSessionId) - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Csrf-Token", anonymousCsrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) msg := "If the address exists, an email has been sent." - assert.Contains(t, resp.Header.Get("HX-Trigger"), msg) + assert.Contains(t, resp.Header.Get("Hx-Trigger"), msg) var rows int err = db.QueryRow("SELECT COUNT(*) FROM token WHERE user_id = ? AND type = ?", userId, types.TokenTypePasswordReset).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) }) @@ -1375,35 +1434,37 @@ func TestIntegrationAuth(t *testing.T) { _, err := d.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/forgot-password", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) - anonymousSessionId := findCookie(resp, "id").Value + require.NoError(t, err) + anonymousSessionId := findCookie(t, resp).Value html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", anonymousCsrfToken) + assert.NotEmpty(t, anonymousCsrfToken) formData := url.Values{ "new-password": {"MyNewSecurePassword1!"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+anonymousSessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("HX-Current-URL", basePath+"/auth/change-password?token=invalidToken") - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Hx-Current-Url", basePath+"/auth/change-password?token=invalidToken") + req.Header.Set("Csrf-Token", anonymousCsrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) t.Run(`should fail if token is outdated`, func(t *testing.T) { @@ -1416,41 +1477,43 @@ func TestIntegrationAuth(t *testing.T) { _, err := d.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/forgot-password", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) - anonymousSessionId := findCookie(resp, "id").Value + require.NoError(t, err) + anonymousSessionId := findCookie(t, resp).Value html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", anonymousCsrfToken) + assert.NotEmpty(t, anonymousCsrfToken) token := "password-reset-token" _, err = d.Exec(` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, ?, ?, datetime("now", "-16 minute"), datetime("now", "-1 minute"))`, token, userId, "", types.TokenTypePasswordReset) - assert.Nil(t, err) + require.NoError(t, err) formData := url.Values{ "new-password": {"MyNewSecurePassword1!"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+anonymousSessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("HX-Current-URL", basePath+"/auth/change-password?token="+url.QueryEscape(token)) - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Hx-Current-Url", basePath+"/auth/change-password?token="+url.QueryEscape(token)) + req.Header.Set("Csrf-Token", anonymousCsrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) t.Run(`should fail if password is insecure`, func(t *testing.T) { @@ -1463,41 +1526,43 @@ func TestIntegrationAuth(t *testing.T) { _, err := d.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/forgot-password", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) - anonymousSessionId := findCookie(resp, "id").Value + require.NoError(t, err) + anonymousSessionId := findCookie(t, resp).Value html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() anonymousCsrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", anonymousCsrfToken) + assert.NotEmpty(t, anonymousCsrfToken) token := "password-reset-token" _, err = d.Exec(` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, ?, ?, datetime("now"), datetime("now", "+15 minute"))`, token, userId, "", types.TokenTypePasswordReset) - assert.Nil(t, err) + require.NoError(t, err) formData := url.Values{ "new-password": {"insecure-password"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+anonymousSessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("HX-Current-URL", basePath+"/auth/change-password?token="+url.QueryEscape(token)) - req.Header.Set("csrf-token", anonymousCsrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Hx-Current-Url", basePath+"/auth/change-password?token="+url.QueryEscape(token)) + req.Header.Set("Csrf-Token", anonymousCsrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusBadRequest, resp.StatusCode) var rows int err = d.QueryRow("SELECT COUNT(*) FROM user WHERE user_id = ? AND password = ?", userId, pass).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 1, rows) }) t.Run("should change password and invalidate ALL sessions", func(t *testing.T) { @@ -1511,56 +1576,58 @@ func TestIntegrationAuth(t *testing.T) { INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = d.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES ("session-id", ?, datetime(), datetime("now", "+1 day"))`, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/forgot-password", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/forgot-password", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) - - sessionId := findCookie(resp, "id").Value + require.NoError(t, err) + sessionId := findCookie(t, resp).Value html, err := html.Parse(resp.Body) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() csrfToken := findCsrfToken(t, html) - assert.NotEqual(t, "", csrfToken) + assert.NotEmpty(t, csrfToken) formData := url.Values{ "email": {"mail@mail.de"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/forgot-password", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) var token string err = d.QueryRow("SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token) - assert.Nil(t, err) + require.NoError(t, err) formData = url.Values{ "new-password": {"MyNewSecurePassword1!"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/api/auth/forgot-password-actual", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("HX-Request", "true") - req.Header.Set("csrf-token", csrfToken) - req.Header.Set("HX-Current-URL", basePath+"/auth/change-password?token="+url.QueryEscape(token)) + req.Header.Set("Hx-Request", "true") + req.Header.Set("Csrf-Token", csrfToken) + req.Header.Set("Hx-Current-Url", basePath+"/auth/change-password?token="+url.QueryEscape(token)) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) sessions, err := d.Query("SELECT session_id FROM session WHERE user_id = ?", userId) - assert.Nil(t, err) + require.NoError(t, err) assert.False(t, sessions.Next()) }) }) @@ -1577,25 +1644,25 @@ func TestIntegrationAuth(t *testing.T) { _, err := d.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = d.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath, nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath, nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) - - newSession := findCookie(resp, "id") + require.NoError(t, err) + _ = resp.Body.Close() + newSession := findCookie(t, resp) assert.NotNil(t, newSession) assert.NotEqual(t, sessionId, newSession.Value) var rows int err = d.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 0, rows) }) t.Run("should create anonymous session", func(t *testing.T) { @@ -1603,14 +1670,15 @@ func TestIntegrationAuth(t *testing.T) { _, basePath, ctx := setupIntegrationTest(t) - req, err := http.NewRequestWithContext(ctx, "GET", basePath, nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath, nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() - newSession := findCookie(resp, "id") + newSession := findCookie(t, resp) assert.NotNil(t, newSession) - assert.NotEqual(t, "", newSession.Value) + assert.NotEmpty(t, newSession.Value) }) t.Run("should not have access to user information with outdated session", func(t *testing.T) { t.Parallel() @@ -1623,17 +1691,18 @@ func TestIntegrationAuth(t *testing.T) { _, err := d.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = d.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/account", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/account", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) @@ -1645,6 +1714,7 @@ func TestIntegrationAccount(t *testing.T) { t.Parallel() t.Run("SignIn", func(t *testing.T) { + t.Parallel() t.Run(`should throw unauthorized if try to getAll, get, edit, insert or delete`, func(t *testing.T) { t.Parallel() @@ -1652,40 +1722,44 @@ func TestIntegrationAccount(t *testing.T) { csrfToken, sessionId := createAnonymousSession(t, ctx, basePath) - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/account", nil) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/account", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) - req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/some-id", nil) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/account/some-id", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) formData := url.Values{ "name": {"name"}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/account/some-id", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/account/some-id", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) - req, err = http.NewRequestWithContext(ctx, "DELETE", basePath+"/account/some-id", nil) - assert.Nil(t, err) - req.Header.Set("csrf-token", csrfToken) + req, err = http.NewRequestWithContext(ctx, http.MethodDelete, basePath+"/account/some-id", nil) + require.NoError(t, err) + req.Header.Set("Csrf-Token", csrfToken) req.Header.Set("Cookie", "id="+sessionId) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusSeeOther, resp.StatusCode) assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) }) @@ -1694,105 +1768,112 @@ func TestIntegrationAccount(t *testing.T) { db, basePath, ctx := setupIntegrationTest(t) - _, csrfToken, sessionId := createValidUserSession(t, db, ctx, basePath, "") + _, csrfToken, sessionId := createValidUserSession(t, db, "") // Insert expectedName := "My great Account" formData := url.Values{ "name": {expectedName}, } - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/account/new", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Csrf-Token", csrfToken) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Contains(t, readBody(t, resp.Body), expectedName) + _ = resp.Body.Close() var id uuid.UUID err = db.Get(&id, "SELECT id FROM account") - assert.Nil(t, err) + require.NoError(t, err) // Update expectedNewName := "My new Account" formData = url.Values{ "name": {expectedNewName}, } - req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/account/"+id.String(), strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/account/"+id.String(), strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Contains(t, readBody(t, resp.Body), expectedNewName) + _ = resp.Body.Close() // Get - req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/"+id.String(), nil) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/account/"+id.String(), nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Contains(t, readBody(t, resp.Body), expectedNewName) + _ = resp.Body.Close() // Delete - req, err = http.NewRequestWithContext(ctx, "DELETE", basePath+"/account/"+id.String(), strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodDelete, basePath+"/account/"+id.String(), strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Csrf-Token", csrfToken) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) // Get (not found) - req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/"+id.String(), nil) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/account/"+id.String(), nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusNotFound, resp.StatusCode) assert.NotContains(t, readBody(t, resp.Body), expectedNewName) + _ = resp.Body.Close() }) t.Run(`should not be able to see other users content`, func(t *testing.T) { t.Parallel() db, basePath, ctx := setupIntegrationTest(t) - _, csrfToken1, sessionId1 := createValidUserSession(t, db, ctx, basePath, "1") - _, _, sessionId2 := createValidUserSession(t, db, ctx, basePath, "2") + _, csrfToken1, sessionId1 := createValidUserSession(t, db, "1") + _, _, sessionId2 := createValidUserSession(t, db, "2") expectedName1 := "Account 1" formData := url.Values{ "name": {expectedName1}, } - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/account/new", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId1) - req.Header.Set("csrf-token", csrfToken1) + req.Header.Set("Csrf-Token", csrfToken1) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) - req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account", nil) - assert.Nil(t, err) + req, err = http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/account", nil) + require.NoError(t, err) req.Header.Set("Cookie", "id="+sessionId2) resp, err = httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) assert.NotContains(t, expectedName1, readBody(t, resp.Body)) + _ = resp.Body.Close() }) t.Run(`should prohibit special characters in name`, func(t *testing.T) { t.Parallel() db, basePath, ctx := setupIntegrationTest(t) - _, csrfToken, sessionId := createValidUserSession(t, db, ctx, basePath, "") + _, csrfToken, sessionId := createValidUserSession(t, db, "") data := map[string]int{ "<": 400, @@ -1807,26 +1888,26 @@ func TestIntegrationAccount(t *testing.T) { } for name, status := range data { - formData := url.Values{ "name": {name}, } - req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode())) - assert.Nil(t, err) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, basePath+"/account/new", strings.NewReader(formData.Encode())) + require.NoError(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", "id="+sessionId) - req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Csrf-Token", csrfToken) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) + _ = resp.Body.Close() assert.Equal(t, status, resp.StatusCode, "for name: "+name) } - }) }) } -func createValidUserSession(t *testing.T, db *sqlx.DB, ctx context.Context, basePath string, add string) (uuid.UUID, string, string) { +func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, string, string) { + t.Helper() userId := uuid.New() sessionId := "session-id" + add pass := service.GetHashPassword("password", []byte("salt")) @@ -1836,35 +1917,40 @@ func createValidUserSession(t *testing.T, db *sqlx.DB, ctx context.Context, base _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt")) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) - assert.Nil(t, err) + require.NoError(t, err) _, err = db.Exec(` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf) - assert.Nil(t, err) + require.NoError(t, err) return userId, csrfToken, sessionId } func createAnonymousSession(t *testing.T, ctx context.Context, basePath string) (string, string) { - req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) - assert.Nil(t, err) + t.Helper() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) + require.NoError(t, err) resp, err := httpClient.Do(req) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) html, err := html.Parse(resp.Body) - assert.Nil(t, err) + _ = resp.Body.Close() + require.NoError(t, err) - return findCsrfToken(t, html), findCookie(resp, "id").Value + return findCsrfToken(t, html), findCookie(t, resp).Value } -func findCookie(resp *http.Response, name string) *http.Cookie { +func findCookie(t *testing.T, resp *http.Response) *http.Cookie { + t.Helper() + for _, cookie := range resp.Cookies() { - if cookie.Name == name { + if cookie.Name == "id" { return cookie } } @@ -1872,7 +1958,9 @@ func findCookie(resp *http.Response, name string) *http.Cookie { return nil } -func setupIntegrationTest(t *testing.T) (db *sqlx.DB, basePath string, ctx context.Context) { +func setupIntegrationTest(t *testing.T) (*sqlx.DB, string, context.Context) { + t.Helper() + ctx, done := context.WithCancel(context.Background()) t.Cleanup(done) @@ -1890,27 +1978,29 @@ func setupIntegrationTest(t *testing.T) (db *sqlx.DB, basePath string, ctx conte testPort := port.Add(1) testPort += 1024 - go run(ctx, db, getEnv(testPort)) + go func() { + _ = run(ctx, db, getEnv(testPort)) + }() - basePath = "http://localhost:" + fmt.Sprint(testPort) + basePath := "http://localhost:" + strconv.Itoa(int(testPort)) - err = waitForReady(ctx, 5*time.Second, basePath, t) - assert.Nil(t, err) + err = waitForReady(t, ctx, 5*time.Second, basePath) + require.NoError(t, err) return db, basePath, ctx } -func getEnv(port int32) func(string) string { +func getEnv(port int64) func(string) string { return func(key string) string { switch key { case "PORT": - return fmt.Sprint(port) + return strconv.Itoa(int(port)) case "SMTP_ENABLED": return "false" case "PROMETHEUS_ENABLED": return "false" case "BASE_URL": - return "http://localhost:" + fmt.Sprint(port) + return "http://localhost:" + strconv.Itoa(int(port)) case "ENVIRONMENT": return "test" default: @@ -1923,16 +2013,18 @@ func getEnv(port int32) func(string) string { // response or until the context is cancelled or the timeout is // reached. func waitForReady( + t *testing.T, ctx context.Context, timeout time.Duration, endpoint string, - t *testing.T, ) error { + t.Helper() + client := http.Client{} startTime := time.Now() for { req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - assert.Nil(t, err) + require.NoError(t, err) resp, err := client.Do(req) if err == nil && resp.StatusCode == http.StatusOK { @@ -1959,6 +2051,8 @@ func waitForReady( } func findCsrfToken(t *testing.T, data *html.Node) string { + t.Helper() + token := getTokenAttribute(t, data) if token != "" { return token @@ -1979,22 +2073,28 @@ func findCsrfToken(t *testing.T, data *html.Node) string { } func getTokenAttribute(t *testing.T, data *html.Node) string { + t.Helper() + for _, attr := range data.Attr { if attr.Key == "hx-headers" { var data map[string]interface{} err := json.Unmarshal([]byte(attr.Val), &data) - assert.Nil(t, err) - return data["csrf-token"].(string) + require.NoError(t, err) + result, ok := data["Csrf-Token"].(string) + if !ok { + return "" + } + return result } } return "" } func readBody(t *testing.T, body io.ReadCloser) string { + t.Helper() + data, err := io.ReadAll(body) - assert.Nil(t, err) - err = body.Close() - assert.Nil(t, err) + require.NoError(t, err) return string(data) } diff --git a/service/account.go b/service/account.go index 3e5ca5d..087a4e1 100644 --- a/service/account.go +++ b/service/account.go @@ -1,6 +1,7 @@ package service import ( + "errors" "fmt" "spend-sparrow/db" @@ -119,7 +120,7 @@ func (s AccountImpl) UpdateName(user *types.User, id string, name string) (*type err = tx.Get(&account, `SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("account Update", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest) } return nil, types.ErrInternal @@ -164,8 +165,8 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - account := &types.Account{} - err = s.db.Get(account, ` + var account types.Account + err = s.db.Get(&account, ` SELECT * FROM account WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("account Get", nil, err) if err != nil { @@ -173,7 +174,7 @@ func (s AccountImpl) Get(user *types.User, id string) (*types.Account, error) { return nil, err } - return account, nil + return &account, nil } func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) { diff --git a/service/auth.go b/service/auth.go index 27ba7c9..77408cb 100644 --- a/service/auth.go +++ b/service/auth.go @@ -94,30 +94,6 @@ func (service AuthImpl) SignIn(session *types.Session, email string, password st return session, user, nil } -func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error { - if session == nil { - return nil - } - - err := service.db.DeleteSession(session.Id) - if err != nil { - return types.ErrInternal - } - - tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf) - if err != nil { - return types.ErrInternal - } - for _, token := range tokens { - err = service.db.DeleteToken(token.Token) - if err != nil { - return types.ErrInternal - } - } - - return nil -} - func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) { if sessionId == "" { return nil, nil, ErrSessionIdInvalid @@ -155,30 +131,6 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) { return session, nil } -func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) { - sessionId, err := service.random.String(32) - if err != nil { - return nil, types.ErrInternal - } - - err = service.db.DeleteOldSessions(userId) - if err != nil { - return nil, types.ErrInternal - } - - createAt := service.clock.Now() - expiresAt := createAt.Add(24 * time.Hour) - - session := types.NewSession(sessionId, userId, createAt, expiresAt) - - err = service.db.InsertSession(session) - if err != nil { - return nil, types.ErrInternal - } - - return session, nil -} - func (service AuthImpl) SignUp(email string, password string) (*types.User, error) { _, err := mail.ParseAddress(email) if err != nil { @@ -205,7 +157,7 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro err = service.db.InsertUser(user) if err != nil { - if err == db.ErrAlreadyExists { + if errors.Is(err, db.ErrAlreadyExists) { return nil, ErrAccountExists } else { return nil, types.ErrInternal @@ -216,9 +168,8 @@ func (service AuthImpl) SignUp(email string, password string) (*types.User, erro } func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { - tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify) - if err != nil && err != db.ErrNotFound { + if err != nil && !errors.Is(err, db.ErrNotFound) { return } @@ -234,7 +185,13 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { return } - token = types.NewToken(userId, "", newTokenStr, types.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) + token = types.NewToken( + userId, + "", + newTokenStr, + types.TokenTypeEmailVerify, + service.clock.Now(), + service.clock.Now().Add(24*time.Hour)) err = service.db.InsertToken(token) if err != nil { @@ -253,7 +210,6 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { } func (service AuthImpl) VerifyUserEmail(tokenStr string) error { - if tokenStr == "" { return types.ErrInternal } @@ -291,12 +247,10 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error { } func (service AuthImpl) SignOut(sessionId string) error { - return service.db.DeleteSession(sessionId) } func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error { - userDb, err := service.db.GetUser(user.Id) if err != nil { return types.ErrInternal @@ -318,7 +272,6 @@ func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error { } func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error { - if !isPasswordValid(newPass) { return ErrInvalidPassword } @@ -365,14 +318,20 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error { user, err := service.db.GetUserByEmail(email) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil } else { return types.ErrInternal } } - token := types.NewToken(user.Id, "", tokenStr, types.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute)) + token := types.NewToken( + user.Id, + "", + tokenStr, + types.TokenTypePasswordReset, + service.clock.Now(), + service.clock.Now().Add(15*time.Minute)) err = service.db.InsertToken(token) if err != nil { @@ -391,7 +350,6 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error { } func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { - if !isPasswordValid(newPass) { return ErrInvalidPassword } @@ -449,7 +407,6 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool if token.Type != types.TokenTypeCsrf || token.SessionId != sessionId || token.ExpiresAt.Before(service.clock.Now()) { - return false } @@ -472,7 +429,13 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) { return "", types.ErrInternal } - token := types.NewToken(session.UserId, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(8*time.Hour)) + token := types.NewToken( + session.UserId, + session.Id, + tokenStr, + types.TokenTypeCsrf, + service.clock.Now(), + service.clock.Now().Add(8*time.Hour)) err = service.db.InsertToken(token) if err != nil { return "", types.ErrInternal @@ -483,12 +446,59 @@ func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) { return tokenStr, nil } +func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error { + if session == nil { + return nil + } + + err := service.db.DeleteSession(session.Id) + if err != nil { + return types.ErrInternal + } + + tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf) + if err != nil { + return types.ErrInternal + } + for _, token := range tokens { + err = service.db.DeleteToken(token.Token) + if err != nil { + return types.ErrInternal + } + } + + return nil +} + +func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) { + sessionId, err := service.random.String(32) + if err != nil { + return nil, types.ErrInternal + } + + err = service.db.DeleteOldSessions(userId) + if err != nil { + return nil, types.ErrInternal + } + + createAt := service.clock.Now() + expiresAt := createAt.Add(24 * time.Hour) + + session := types.NewSession(sessionId, userId, createAt, expiresAt) + + err = service.db.InsertSession(session) + if err != nil { + return nil, types.ErrInternal + } + + return session, nil +} + func GetHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) } func isPasswordValid(password string) bool { - if len(password) < 8 || !strings.ContainsAny(password, "0123456789") || !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || diff --git a/service/auth_test.go b/service/auth_test.go index aab512a..d69330d 100644 --- a/service/auth_test.go +++ b/service/auth_test.go @@ -1,8 +1,9 @@ -package service +package service_test import ( "spend-sparrow/db" "spend-sparrow/mocks" + "spend-sparrow/service" "spend-sparrow/types" "strings" @@ -12,6 +13,17 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +var ( + settings = types.Settings{ + Port: "", + PrometheusEnabled: false, + BaseUrl: "", + Environment: "test", + Smtp: nil, + } ) func TestSignUp(t *testing.T) { @@ -24,11 +36,11 @@ func TestSignUp(t *testing.T) { mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) - underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) + underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) _, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!") - assert.Equal(t, ErrInvalidEmail, err) + assert.Equal(t, service.ErrInvalidEmail, err) }) t.Run("should check for password complexity", func(t *testing.T) { t.Parallel() @@ -38,7 +50,7 @@ func TestSignUp(t *testing.T) { mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) - underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) + underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) weakPasswords := []string{ "123!ab", // too short @@ -49,7 +61,7 @@ func TestSignUp(t *testing.T) { for _, password := range weakPasswords { _, err := underTest.SignUp("some@valid.email", password) - assert.Equal(t, ErrInvalidPassword, err) + assert.Equal(t, service.ErrInvalidPassword, err) } }) t.Run("should signup correctly", func(t *testing.T) { @@ -66,17 +78,17 @@ func TestSignUp(t *testing.T) { salt := []byte("salt") createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) - expected := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime) + expected := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime) mockRandom.EXPECT().UUID().Return(userId, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil) mockClock.EXPECT().Now().Return(createTime) mockAuthDb.EXPECT().InsertUser(expected).Return(nil) - underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) + underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) actual, err := underTest.SignUp(email, password) - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, expected, actual) }) @@ -93,7 +105,7 @@ func TestSignUp(t *testing.T) { createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) password := "SomeStrongPassword123!" salt := []byte("salt") - user := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime) + user := types.NewUser(userId, email, false, nil, false, service.GetHashPassword(password, salt), salt, createTime) mockRandom.EXPECT().UUID().Return(user.Id, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil) @@ -101,20 +113,25 @@ func TestSignUp(t *testing.T) { mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists) - underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) + underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) _, err := underTest.SignUp(user.Email, password) - assert.Equal(t, ErrAccountExists, err) + assert.Equal(t, service.ErrAccountExists, err) }) } func TestSendVerificationMail(t *testing.T) { - t.Parallel() t.Run("should use stored token and send mail", func(t *testing.T) { t.Parallel() - token := types.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", types.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) + token := types.NewToken( + uuid.New(), + "sessionId", + "someRandomTokenToUse", + types.TokenTypeEmailVerify, + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) tokens := []*types.Token{token} email := "some@email.de" @@ -131,7 +148,7 @@ func TestSendVerificationMail(t *testing.T) { return strings.Contains(message, token.Token) })).Return() - underTest := NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) + underTest := service.NewAuth(mockAuthDb, mockRandom, mockClock, mockMail, &settings) underTest.SendVerificationMail(userId, email) }) diff --git a/service/default.go b/service/default.go index 56ccd3d..defd444 100644 --- a/service/default.go +++ b/service/default.go @@ -5,16 +5,21 @@ import ( "regexp" ) +const ( + DECIMALS_MULTIPLIER = 100 +) + var ( safeInputRegex = regexp.MustCompile(`^[a-zA-Z0-9ÄÖÜäöüß,&'" -]+$`) ) func validateString(value string, fieldName string) error { - if value == "" { + switch { + case value == "": return fmt.Errorf("field \"%s\" needs to be set: %w", fieldName, ErrBadRequest) - } else if !safeInputRegex.MatchString(value) { + case !safeInputRegex.MatchString(value): return fmt.Errorf("use only letters, dashes and spaces for \"%s\": %w", fieldName, ErrBadRequest) - } else { + default: return nil } } diff --git a/service/mail.go b/service/mail.go index 008b636..4efa662 100644 --- a/service/mail.go +++ b/service/mail.go @@ -34,7 +34,19 @@ func (m MailImpl) internalSendMail(to string, subject string, message string) { auth := smtp.PlainAuth("", s.User, s.Pass, s.Host) - msg := fmt.Sprintf("From: %v <%v>\nTo: %v\nSubject: %v\nMIME-version: 1.0;\nContent-Type: text/html; charset=\"UTF-8\";\n\n%v", s.FromName, s.FromMail, to, subject, message) + msg := fmt.Sprintf( + `From: %v <%v> + To: %v + Subject: %v + MIME-version: 1.0; + Content-Type: text/html; charset="UTF-8"; + + %v`, + s.FromName, + s.FromMail, + to, + subject, + message) log.Info("Sending mail to %v", to) err := smtp.SendMail(s.Host+":"+s.Port, auth, s.FromMail, []string{to}, []byte(msg)) diff --git a/service/money.go b/service/money.go deleted file mode 100644 index 0f31e83..0000000 --- a/service/money.go +++ /dev/null @@ -1,8 +0,0 @@ -package service - -type MoneyImpl struct { -} - -func NewMoneyImpl() *MoneyImpl { - return &MoneyImpl{} -} diff --git a/service/money_test.go b/service/money_test.go deleted file mode 100644 index 3578081..0000000 --- a/service/money_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package service - -import ( - "testing" -) - -func TestMoneyCalculation(t *testing.T) { - t.Parallel() - t.Run("should calculate correct oink balance", func(t *testing.T) { - // t.Parallel() - // - // underTest := NewMoneyImpl() - // - // // GIVEN - // timestamp := time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC) - // - // userId := uuid.New() - // - // account := types.Account{ - // Id: uuid.New(), - // UserId: userId, - // - // Type: "Bank", - // Name: "Bank", - // - // CurrentBalance: 0, - // LastTransaction: time.Time{}, - // OinkBalance: 0, - // } - // - // // The PiggyBank is a fictional account. The money it "holds" is actually in the Account - // piggyBank := types.PiggyBank{ - // Id: uuid.New(), - // UserId: userId, - // - // AccountId: account.Id, - // Name: "Car", - // - // CurrentBalance: 0, - // } - // - // savingsPlan := types.SavingsPlan{ - // Id: uuid.New(), - // UserId: userId, - // PiggyBankId: piggyBank.Id, - // - // MonthlySaving: 10, - // - // ValidFrom: timestamp, - // } - // - // transaction1 := types.Transaction{ - // Id: uuid.New(), - // UserId: userId, - // - // AccountId: account.Id, - // - // Value: 20, - // Timestamp: timestamp, - // } - // - // transaction2 := types.Transaction{ - // Id: uuid.New(), - // UserId: userId, - // - // AccountId: account.Id, - // PiggyBankId: &piggyBank.Id, - // - // Value: -1, - // Timestamp: timestamp.Add(1 * time.Hour), - // } - // - // // WHEN - // actual, err := underTest.CalculateAllBalancesInTime(account, piggyBank, savingsPlan, []types.Transaction{transaction1, transaction2}) - // - // // THEN - // assert.Nil(t, err) - // assert.ElementsMatch(t, expected, actual) - }) -} diff --git a/service/transaction.go b/service/transaction.go index d2c2978..15a00cf 100644 --- a/service/transaction.go +++ b/service/transaction.go @@ -1,6 +1,7 @@ package service import ( + "errors" "fmt" "strconv" "time" @@ -73,8 +74,10 @@ func (s TransactionImpl) Add(user *types.User, transactionInput types.Transactio } r, err := tx.NamedExec(` - INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp, party, description, error, created_at, created_by) - VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp, :party, :description, :error, :created_at, :created_by)`, transaction) + INSERT INTO "transaction" (id, user_id, account_id, treasure_chest_id, value, timestamp, + party, description, error, created_at, created_by) + VALUES (:id, :user_id, :account_id, :treasure_chest_id, :value, :timestamp, + :party, :description, :error, :created_at, :created_by)`, transaction) err = db.TransformAndLogDbError("transaction Insert", r, err) if err != nil { return nil, err @@ -135,7 +138,7 @@ func (s TransactionImpl) Update(user *types.User, input types.TransactionInput) err = tx.Get(transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("transaction Update", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("transaction %v not found: %w", input.Id, ErrBadRequest) } return nil, types.ErrInternal @@ -232,7 +235,7 @@ func (s TransactionImpl) Get(user *types.User, id string) (*types.Transaction, e err = s.db.Get(&transaction, `SELECT * FROM "transaction" WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("transaction Get", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("transaction %v not found: %w", id, ErrBadRequest) } return nil, types.ErrInternal @@ -259,7 +262,10 @@ func (s TransactionImpl) GetAll(user *types.User, filter types.TransactionItemsF OR (? = "false" AND error IS NULL) ) ORDER BY timestamp DESC`, - user.Id, filter.AccountId, filter.AccountId, filter.TreasureChestId, filter.TreasureChestId, filter.Error, filter.Error, filter.Error) + user.Id, + filter.AccountId, filter.AccountId, + filter.TreasureChestId, filter.TreasureChestId, + filter.Error, filter.Error, filter.Error) err = db.TransformAndLogDbError("transaction GetAll", nil, err) if err != nil { return nil, err @@ -302,7 +308,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error { WHERE id = ? AND user_id = ?`, transaction.Value, transaction.AccountId, user.Id) err = db.TransformAndLogDbError("transaction Delete", r, err) - if err != nil && err != db.ErrNotFound { + if err != nil && !errors.Is(err, db.ErrNotFound) { return err } } @@ -314,7 +320,7 @@ func (s TransactionImpl) Delete(user *types.User, id string) error { WHERE id = ? AND user_id = ?`, transaction.Value, transaction.TreasureChestId, user.Id) err = db.TransformAndLogDbError("transaction Delete", r, err) - if err != nil && err != db.ErrNotFound { + if err != nil && !errors.Is(err, db.ErrNotFound) { return err } } @@ -354,7 +360,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { SET current_balance = 0 WHERE user_id = ?`, user.Id) err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err) - if err != nil && err != db.ErrNotFound { + if err != nil && !errors.Is(err, db.ErrNotFound) { return err } @@ -363,7 +369,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { SET current_balance = 0 WHERE user_id = ?`, user.Id) err = db.TransformAndLogDbError("transaction RecalculateBalances", r, err) - if err != nil && err != db.ErrNotFound { + if err != nil && !errors.Is(err, db.ErrNotFound) { return err } @@ -372,7 +378,7 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { FROM "transaction" WHERE user_id = ?`, user.Id) err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err) - if err != nil && err != db.ErrNotFound { + if err != nil && !errors.Is(err, db.ErrNotFound) { return err } defer func() { @@ -382,15 +388,15 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { } }() - transaction := &types.Transaction{} + var transaction types.Transaction for rows.Next() { - err = rows.StructScan(transaction) + err = rows.StructScan(&transaction) err = db.TransformAndLogDbError("transaction RecalculateBalances", nil, err) if err != nil { return err } - s.updateErrors(transaction) + s.updateErrors(&transaction) r, err = tx.Exec(` UPDATE "transaction" SET error = ? @@ -424,7 +430,6 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { if err != nil { return err } - } } @@ -438,7 +443,6 @@ func (s TransactionImpl) RecalculateBalances(user *types.User) error { } func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransaction *types.Transaction, userId uuid.UUID, input types.TransactionInput) (*types.Transaction, error) { - var ( id uuid.UUID accountUuid *uuid.UUID @@ -484,7 +488,6 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio log.Error("transaction validate: %v", err) return nil, fmt.Errorf("account not found: %w", ErrBadRequest) } - } if input.TreasureChestId != "" { @@ -498,7 +501,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = db.TransformAndLogDbError("transaction validate", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) } return nil, err @@ -513,7 +516,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio log.Error("transaction validate: %v", err) return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest) } - valueInt := int64(valueFloat * 100) + valueInt := int64(valueFloat * DECIMALS_MULTIPLIER) timestamp, err := time.Parse("2006-01-02", input.Timestamp) if err != nil { @@ -544,6 +547,7 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio Timestamp: timestamp, Party: input.Party, Description: input.Description, + Error: nil, CreatedAt: createdAt, CreatedBy: createdBy, @@ -557,25 +561,26 @@ func (s TransactionImpl) validateAndEnrichTransaction(tx *sqlx.Tx, oldTransactio } func (s TransactionImpl) updateErrors(transaction *types.Transaction) { - error := "" + errorStr := "" - if transaction.Value < 0 { + switch { + case transaction.Value < 0: if transaction.TreasureChestId == nil { - error = "no treasure chest specified" + errorStr = "no treasure chest specified" } - } else if transaction.Value > 0 { + case transaction.Value > 0: if transaction.AccountId == nil && transaction.TreasureChestId == nil { - error = "either an account or a treasure chest needs to be specified" + errorStr = "either an account or a treasure chest needs to be specified" } else if transaction.AccountId != nil && transaction.TreasureChestId != nil { - error = "positive amounts can only be applied to either an account or a treasure chest" + errorStr = "positive amounts can only be applied to either an account or a treasure chest" } - } else { - error = "\"value\" needs to be specified" + default: + errorStr = "\"value\" needs to be specified" } - if error == "" { + if errorStr == "" { transaction.Error = nil } else { - transaction.Error = &error + transaction.Error = &errorStr } } diff --git a/service/transaction_recurring.go b/service/transaction_recurring.go index eee9939..0f0fd76 100644 --- a/service/transaction_recurring.go +++ b/service/transaction_recurring.go @@ -1,6 +1,7 @@ package service import ( + "errors" "fmt" "strconv" "time" @@ -18,8 +19,8 @@ import ( var ( transactionRecurringMetric = promauto.NewCounterVec( prometheus.CounterOpts{ - Name: "spendsparrow_transactionRecurring_recurring_total", - Help: "The total of transactionRecurring recurring operations", + Name: "spendsparrow_transaction_recurring_total", + Help: "The total of transactionRecurring operations", }, []string{"operation"}, ) @@ -28,7 +29,6 @@ var ( type TransactionRecurring interface { Add(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) Update(user *types.User, transactionRecurring types.TransactionRecurringInput) (*types.TransactionRecurring, error) - Get(user *types.User, id string) (*types.TransactionRecurring, error) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) GetAllByTreasureChest(user *types.User, treasureChestId string) ([]*types.TransactionRecurring, error) Delete(user *types.User, id string) error @@ -50,7 +50,9 @@ func NewTransactionRecurring(db *sqlx.DB, random Random, clock Clock, settings * } } -func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) { +func (s TransactionRecurringImpl) Add( + user *types.User, + transactionRecurringInput types.TransactionRecurringInput) (*types.TransactionRecurring, error) { transactionRecurringMetric.WithLabelValues("add").Inc() if user == nil { @@ -72,8 +74,11 @@ func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInpu } r, err := tx.NamedExec(` - INSERT INTO "transaction_recurring" (id, user_id, interval_months, active, party, description, account_id, treasure_chest_id, value, created_at, created_by) - VALUES (:id, :user_id, :interval_months, :active, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`, transactionRecurring) + INSERT INTO "transaction_recurring" (id, user_id, interval_months, + active, party, description, account_id, treasure_chest_id, value, created_at, created_by) + VALUES (:id, :user_id, :interval_months, + :active, :party, :description, :account_id, :treasure_chest_id, :value, :created_at, :created_by)`, + transactionRecurring) err = db.TransformAndLogDbError("transactionRecurring Insert", r, err) if err != nil { return nil, err @@ -88,7 +93,9 @@ func (s TransactionRecurringImpl) Add(user *types.User, transactionRecurringInpu return transactionRecurring, nil } -func (s TransactionRecurringImpl) Update(user *types.User, input types.TransactionRecurringInput) (*types.TransactionRecurring, error) { +func (s TransactionRecurringImpl) Update( + user *types.User, + input types.TransactionRecurringInput) (*types.TransactionRecurring, error) { transactionRecurringMetric.WithLabelValues("update").Inc() if user == nil { return nil, ErrUnauthorized @@ -112,7 +119,7 @@ func (s TransactionRecurringImpl) Update(user *types.User, input types.Transacti err = tx.Get(transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("transactionRecurring Update", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("transactionRecurring %v not found: %w", input.Id, ErrBadRequest) } return nil, types.ErrInternal @@ -151,31 +158,6 @@ func (s TransactionRecurringImpl) Update(user *types.User, input types.Transacti return transactionRecurring, nil } -func (s TransactionRecurringImpl) Get(user *types.User, id string) (*types.TransactionRecurring, error) { - transactionRecurringMetric.WithLabelValues("get").Inc() - - if user == nil { - return nil, ErrUnauthorized - } - uuid, err := uuid.Parse(id) - if err != nil { - log.Error("transactionRecurring get: %v", err) - return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) - } - - var transactionRecurring types.TransactionRecurring - err = s.db.Get(&transactionRecurring, `SELECT * FROM transaction_recurring WHERE user_id = ? AND id = ?`, user.Id, uuid) - err = db.TransformAndLogDbError("transactionRecurring Get", nil, err) - if err != nil { - if err == db.ErrNotFound { - return nil, fmt.Errorf("transactionRecurring %v not found: %w", id, ErrBadRequest) - } - return nil, types.ErrInternal - } - - return &transactionRecurring, nil -} - func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId string) ([]*types.TransactionRecurring, error) { transactionRecurringMetric.WithLabelValues("get_all_by_account").Inc() if user == nil { @@ -201,7 +183,7 @@ func (s TransactionRecurringImpl) GetAllByAccount(user *types.User, accountId st err = tx.Get(&rowCount, `SELECT COUNT(*) FROM account WHERE id = ? AND user_id = ?`, accountUuid, user.Id) err = db.TransformAndLogDbError("transactionRecurring GetAllByAccount", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("account %v not found: %w", accountId, ErrBadRequest) } return nil, types.ErrInternal @@ -254,7 +236,7 @@ func (s TransactionRecurringImpl) GetAllByTreasureChest(user *types.User, treasu err = tx.Get(&rowCount, `SELECT COUNT(*) FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestId, user.Id) err = db.TransformAndLogDbError("transactionRecurring GetAllByTreasureChest", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("treasurechest %v not found: %w", treasureChestId, ErrBadRequest) } return nil, types.ErrInternal @@ -329,7 +311,6 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( oldTransactionRecurring *types.TransactionRecurring, userId uuid.UUID, input types.TransactionRecurringInput) (*types.TransactionRecurring, error) { - var ( id uuid.UUID accountUuid *uuid.UUID @@ -393,7 +374,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( err = tx.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE id = ? AND user_id = ?`, treasureChestUuid, userId) err = db.TransformAndLogDbError("transactionRecurring validate", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("treasure chest not found: %w", ErrBadRequest) } return nil, err @@ -418,7 +399,7 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( log.Error("transactionRecurring validate: %v", err) return nil, fmt.Errorf("could not parse value: %w", ErrBadRequest) } - valueInt := int64(valueFloat * 100) + valueInt := int64(valueFloat * DECIMALS_MULTIPLIER) if input.Party != "" { err = validateString(input.Party, "party") @@ -444,12 +425,12 @@ func (s TransactionRecurringImpl) validateAndEnrichTransactionRecurring( active := input.Active == "on" transactionRecurring := types.TransactionRecurring{ - Id: id, UserId: userId, IntervalMonths: intervalMonths, Active: active, + LastExecution: nil, Party: input.Party, Description: input.Description, diff --git a/service/treasure_chest.go b/service/treasure_chest.go index ab865f7..4ad4db0 100644 --- a/service/treasure_chest.go +++ b/service/treasure_chest.go @@ -1,6 +1,7 @@ package service import ( + "errors" "fmt" "slices" @@ -131,7 +132,7 @@ func (s TreasureChestImpl) Update(user *types.User, idStr, parentId, name string err = tx.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, id) err = db.TransformAndLogDbError("treasureChest Update", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("treasureChest %v not found: %w", idStr, err) } return nil, types.ErrInternal @@ -198,17 +199,17 @@ func (s TreasureChestImpl) Get(user *types.User, id string) (*types.TreasureChes return nil, fmt.Errorf("could not parse Id: %w", ErrBadRequest) } - treasureChest := &types.TreasureChest{} - err = s.db.Get(treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) + var treasureChest types.TreasureChest + err = s.db.Get(&treasureChest, `SELECT * FROM treasure_chest WHERE user_id = ? AND id = ?`, user.Id, uuid) err = db.TransformAndLogDbError("treasureChest Get", nil, err) if err != nil { - if err == db.ErrNotFound { + if errors.Is(err, db.ErrNotFound) { return nil, fmt.Errorf("treasureChest %v not found: %w", id, err) } return nil, types.ErrInternal } - return treasureChest, nil + return &treasureChest, nil } func (s TreasureChestImpl) GetAll(user *types.User) ([]*types.TreasureChest, error) { @@ -259,7 +260,9 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { } transactionsCount := 0 - err = tx.Get(&transactionsCount, `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`, user.Id, id) + err = tx.Get(&transactionsCount, + `SELECT COUNT(*) FROM "transaction" WHERE user_id = ? AND treasure_chest_id = ?`, + user.Id, id) err = db.TransformAndLogDbError("treasureChest Delete", nil, err) if err != nil { return err @@ -284,12 +287,11 @@ func (s TreasureChestImpl) Delete(user *types.User, idStr string) error { } func sortTree(nodes []*types.TreasureChest) []*types.TreasureChest { - var ( - roots []*types.TreasureChest - result []*types.TreasureChest + roots []*types.TreasureChest ) children := make(map[uuid.UUID][]*types.TreasureChest) + result := make([]*types.TreasureChest, 0) for _, node := range nodes { if node.ParentId == nil { diff --git a/template/account/account.templ b/template/account/account.templ index d57981f..8ad3e07 100644 --- a/template/account/account.templ +++ b/template/account/account.templ @@ -111,6 +111,7 @@ templ AccountItem(account *types.Account) { hx-target="closest #account" hx-swap="outerHTML" class="button button-neglect px-1 flex items-center gap-2" + hx-confirm="Are you sure you want to delete this account?" > @svg.Delete() diff --git a/template/layout.templ b/template/layout.templ index b29c9a7..bd3321c 100644 --- a/template/layout.templ +++ b/template/layout.templ @@ -28,7 +28,7 @@ templ Layout(slot templ.Component, user templ.Component, loggedIn bool, path str - + // Header