feat(account): #49 refactor error handling #61
@@ -1,6 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"spend-sparrow/log"
|
"spend-sparrow/log"
|
||||||
"spend-sparrow/types"
|
"spend-sparrow/types"
|
||||||
|
|
||||||
@@ -31,7 +32,7 @@ func (db AccountSqlite) Insert(userId uuid.UUID, account *types.Account) error {
|
|||||||
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
|
INSERT INTO account (id, user_id, name, current_balance, oink_balance, created_at, created_by)
|
||||||
VALUES (?,?,?,?,?,?,?)`, account.Id, userId, account.Name, 0, 0, account.CreatedAt, account.CreatedBy)
|
VALUES (?,?,?,?,?,?,?)`, account.Id, userId, account.Name, 0, 0, account.CreatedAt, account.CreatedBy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error inserting account: %v", err)
|
log.Error("account Insert: %v", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,7 +41,6 @@ func (db AccountSqlite) Insert(userId uuid.UUID, account *types.Account) error {
|
|||||||
|
|
||||||
func (db AccountSqlite) Update(userId uuid.UUID, account *types.Account) error {
|
func (db AccountSqlite) Update(userId uuid.UUID, account *types.Account) error {
|
||||||
|
|
||||||
log.Info("Updating account: %v", account)
|
|
||||||
r, err := db.db.Exec(`
|
r, err := db.db.Exec(`
|
||||||
UPDATE account
|
UPDATE account
|
||||||
SET
|
SET
|
||||||
@@ -53,17 +53,17 @@ func (db AccountSqlite) Update(userId uuid.UUID, account *types.Account) error {
|
|||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
AND user_id = ?`, account.Name, account.CurrentBalance, account.LastTransaction, account.OinkBalance, account.UpdatedAt, account.UpdatedBy, account.Id, userId)
|
AND user_id = ?`, account.Name, account.CurrentBalance, account.LastTransaction, account.OinkBalance, account.UpdatedAt, account.UpdatedBy, account.Id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error updating account: %v", err)
|
log.Error("account Update: %v", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
rows, err := r.RowsAffected()
|
rows, err := r.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error deleting account, getting rows affected: %v", err)
|
log.Error("account Update: %v", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows == 0 {
|
if rows == 0 {
|
||||||
log.Error("Error deleting account, rows affected: %v", rows)
|
log.Info("account Update: not found")
|
||||||
return ErrNotFound
|
return ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,7 +82,7 @@ func (db AccountSqlite) GetAll(userId uuid.UUID) ([]*types.Account, error) {
|
|||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
ORDER BY name`, userId)
|
ORDER BY name`, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Could not getAll accounts: %v", err)
|
log.Error("account GetAll: %v", err)
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,7 +101,10 @@ func (db AccountSqlite) Get(userId uuid.UUID, id uuid.UUID) (*types.Account, err
|
|||||||
WHERE user_id = ?
|
WHERE user_id = ?
|
||||||
AND id = ?`, userId, id)
|
AND id = ?`, userId, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Could not get accounts: %v", err)
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
log.Error("account Get: %v", err)
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,18 +115,18 @@ func (db AccountSqlite) Delete(userId uuid.UUID, id uuid.UUID) error {
|
|||||||
|
|
||||||
res, err := db.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", id, userId)
|
res, err := db.db.Exec("DELETE FROM account WHERE id = ? and user_id = ?", id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error deleting account: %v", err)
|
log.Error("account Delete: %v", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := res.RowsAffected()
|
rows, err := res.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error deleting account, getting rows affected: %v", err)
|
log.Error("account Delete: %v", err)
|
||||||
return types.ErrInternal
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows == 0 {
|
if rows == 0 {
|
||||||
log.Error("Error deleting account, rows affected: %v", rows)
|
log.Info("account Delete: not found")
|
||||||
return ErrNotFound
|
return ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -338,7 +338,7 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
|||||||
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
|
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("Session not found: %v", err)
|
log.Warn("Session \"%s\" not found: %v", sessionId, err)
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
10
db/error.go
Normal file
10
db/error.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotFound = errors.New("the value does not exist")
|
||||||
|
ErrAlreadyExists = errors.New("row already exists")
|
||||||
|
)
|
||||||
@@ -12,11 +12,6 @@ import (
|
|||||||
"github.com/jmoiron/sqlx"
|
"github.com/jmoiron/sqlx"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
ErrNotFound = errors.New("the value does not exist")
|
|
||||||
ErrAlreadyExists = errors.New("row already exists")
|
|
||||||
)
|
|
||||||
|
|
||||||
func RunMigrations(db *sqlx.DB, pathPrefix string) error {
|
func RunMigrations(db *sqlx.DB, pathPrefix string) error {
|
||||||
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1,15 +1,14 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"spend-sparrow/handler/middleware"
|
"spend-sparrow/handler/middleware"
|
||||||
"spend-sparrow/log"
|
|
||||||
"spend-sparrow/service"
|
"spend-sparrow/service"
|
||||||
t "spend-sparrow/template/account"
|
t "spend-sparrow/template/account"
|
||||||
"spend-sparrow/types"
|
"spend-sparrow/types"
|
||||||
"spend-sparrow/utils"
|
"spend-sparrow/utils"
|
||||||
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/a-h/templ"
|
"github.com/a-h/templ"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@@ -49,7 +48,7 @@ func (h AccountImpl) handleAccountPage() http.HandlerFunc {
|
|||||||
|
|
||||||
accounts, err := h.s.GetAll(user)
|
accounts, err := h.s.GetAll(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,20 +68,19 @@ func (h AccountImpl) handleAccountItemComp() http.HandlerFunc {
|
|||||||
idStr := r.PathValue("id")
|
idStr := r.PathValue("id")
|
||||||
if idStr == "new" {
|
if idStr == "new" {
|
||||||
comp := t.EditAccount(nil)
|
comp := t.EditAccount(nil)
|
||||||
log.Info("Component: %v", comp)
|
|
||||||
h.r.Render(r, w, comp)
|
h.r.Render(r, w, comp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := uuid.Parse(idStr)
|
id, err := uuid.Parse(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest)
|
handleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := h.s.Get(user, id)
|
account, err := h.s.Get(user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,18 +111,18 @@ func (h AccountImpl) handleUpdateAccount() http.HandlerFunc {
|
|||||||
if idStr == "new" {
|
if idStr == "new" {
|
||||||
account, err = h.s.Add(user, name)
|
account, err = h.s.Add(user, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
id, err := uuid.Parse(idStr)
|
id, err := uuid.Parse(idStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest)
|
handleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
account, err = h.s.Update(user, id, name)
|
account, err = h.s.Update(user, id, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -144,13 +142,13 @@ func (h AccountImpl) handleDeleteAccount() http.HandlerFunc {
|
|||||||
|
|
||||||
id, err := uuid.Parse(r.PathValue("id"))
|
id, err := uuid.Parse(r.PathValue("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "Could not parse Id", http.StatusBadRequest)
|
handleError(w, r, fmt.Errorf("could not parse Id: %w", service.ErrBadRequest))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.s.Delete(user, id)
|
err = h.s.Delete(user, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", err.Error(), http.StatusInternalServerError)
|
handleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
30
handler/error.go
Normal file
30
handler/error.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"spend-sparrow/service"
|
||||||
|
"spend-sparrow/utils"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleError(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
if errors.Is(err, service.ErrUnauthorized) {
|
||||||
|
utils.TriggerToastWithStatus(w, r, "error", "You are not autorized to perform this operation.", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
} else if errors.Is(err, service.ErrBadRequest) {
|
||||||
|
utils.TriggerToastWithStatus(w, r, "error", extractErrorMessage(err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
utils.TriggerToastWithStatus(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractErrorMessage(err error) string {
|
||||||
|
errMsg := err.Error()
|
||||||
|
if errMsg == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.SplitN(errMsg, ":", 2)[0]
|
||||||
|
}
|
||||||
@@ -37,10 +37,6 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) {
|
|||||||
return rr.ResponseWriter.Write([]byte(dataStr))
|
return rr.ResponseWriter.Write([]byte(dataStr))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rr *csrfResponseWriter) WriteHeader(statusCode int) {
|
|
||||||
rr.ResponseWriter.WriteHeader(statusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
|
func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -57,7 +53,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
|||||||
csrfToken = r.Header.Get("csrf-token")
|
csrfToken = r.Header.Get("csrf-token")
|
||||||
}
|
}
|
||||||
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
|
if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) {
|
||||||
log.Info("CSRF-Token not correct")
|
log.Info("CSRF-Token \"%s\" not correct", csrfToken)
|
||||||
if r.Header.Get("HX-Request") == "true" {
|
if r.Header.Get("HX-Request") == "true" {
|
||||||
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
|
utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -27,11 +27,13 @@ func Gzip(next http.Handler) http.Handler {
|
|||||||
|
|
||||||
w.Header().Set("Content-Encoding", "gzip")
|
w.Header().Set("Content-Encoding", "gzip")
|
||||||
gz := gzip.NewWriter(w)
|
gz := gzip.NewWriter(w)
|
||||||
gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w}
|
wrapper := gzipResponseWriter{Writer: gz, ResponseWriter: w}
|
||||||
next.ServeHTTP(gzr, r)
|
|
||||||
|
next.ServeHTTP(wrapper, r)
|
||||||
|
|
||||||
err := gz.Close()
|
err := gz.Close()
|
||||||
if err != nil {
|
if err != nil && err != http.ErrBodyNotAllowed {
|
||||||
|
// if err != nil {
|
||||||
log.Error("Gzip: could not close Writer: %v", err)
|
log.Error("Gzip: could not close Writer: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ type WrappedWriter struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WrappedWriter) WriteHeader(code int) {
|
func (w *WrappedWriter) WriteHeader(code int) {
|
||||||
w.ResponseWriter.WriteHeader(code)
|
|
||||||
w.StatusCode = code
|
w.StatusCode = code
|
||||||
|
w.ResponseWriter.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Log(next http.Handler) http.Handler {
|
func Log(next http.Handler) http.Handler {
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/http"
|
||||||
"spend-sparrow/handler/middleware"
|
"spend-sparrow/handler/middleware"
|
||||||
"spend-sparrow/service"
|
"spend-sparrow/service"
|
||||||
"spend-sparrow/template"
|
"spend-sparrow/template"
|
||||||
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/a-h/templ"
|
"github.com/a-h/templ"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
3
main.go
3
main.go
@@ -134,8 +134,7 @@ func createHandler(d *sqlx.DB, serverSettings *types.Settings) http.Handler {
|
|||||||
middleware.CacheControl,
|
middleware.CacheControl,
|
||||||
middleware.CrossSiteRequestForgery(authService),
|
middleware.CrossSiteRequestForgery(authService),
|
||||||
middleware.Authenticate(authService),
|
middleware.Authenticate(authService),
|
||||||
middleware.Log,
|
|
||||||
// Gzip last, as it compresses the body
|
|
||||||
middleware.Gzip,
|
middleware.Gzip,
|
||||||
|
middleware.Log,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
232
main_test.go
232
main_test.go
@@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -1615,6 +1616,228 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIntegrationAccount(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("SignIn", func(t *testing.T) {
|
||||||
|
t.Run(`should throw unauthorized if try to getAll, get, edit, insert or delete`, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
csrfToken, sessionId := createAnonymousSession(t, ctx, basePath)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/account", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusSeeOther, resp.StatusCode)
|
||||||
|
assert.Equal(t, "/auth/signin", resp.Header.Get("Location"))
|
||||||
|
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/some-id", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusSeeOther, resp.StatusCode)
|
||||||
|
assert.Equal(t, "/auth/signin", resp.Header.Get("Location"))
|
||||||
|
|
||||||
|
formData := url.Values{
|
||||||
|
"name": {"name"},
|
||||||
|
"csrf-token": {csrfToken},
|
||||||
|
}
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/account/some-id", strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusSeeOther, resp.StatusCode)
|
||||||
|
assert.Equal(t, "/auth/signin", resp.Header.Get("Location"))
|
||||||
|
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "DELETE", basePath+"/account/some-id", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("csrf-token", csrfToken)
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusSeeOther, resp.StatusCode)
|
||||||
|
assert.Equal(t, "/auth/signin", resp.Header.Get("Location"))
|
||||||
|
})
|
||||||
|
t.Run(`should be able to insert, get, delete and update`, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
csrfToken, sessionId := createValidUserSession(t, db, ctx, basePath, "")
|
||||||
|
|
||||||
|
// Insert
|
||||||
|
expectedName := "My great Account"
|
||||||
|
formData := url.Values{
|
||||||
|
"name": {expectedName},
|
||||||
|
"csrf-token": {csrfToken},
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
assert.Contains(t, readBody(t, resp.Body), expectedName)
|
||||||
|
|
||||||
|
var id uuid.UUID
|
||||||
|
err = db.Get(&id, "SELECT id FROM account")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
// Update
|
||||||
|
expectedNewName := "My new Account"
|
||||||
|
formData = url.Values{
|
||||||
|
"name": {expectedNewName},
|
||||||
|
"csrf-token": {csrfToken},
|
||||||
|
}
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/account/"+id.String(), strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
assert.Contains(t, readBody(t, resp.Body), expectedNewName)
|
||||||
|
|
||||||
|
// Get
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/"+id.String(), nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
assert.Contains(t, readBody(t, resp.Body), expectedNewName)
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "DELETE", basePath+"/account/"+id.String(), strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
req.Header.Set("csrf-token", csrfToken)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
// Get (not found)
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/"+id.String(), nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
assert.NotContains(t, readBody(t, resp.Body), expectedNewName)
|
||||||
|
})
|
||||||
|
t.Run(`should not be able to see other users content`, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
csrfToken1, sessionId1 := createValidUserSession(t, db, ctx, basePath, "1")
|
||||||
|
_, sessionId2 := createValidUserSession(t, db, ctx, basePath, "2")
|
||||||
|
|
||||||
|
expectedName1 := "Account 1"
|
||||||
|
|
||||||
|
formData := url.Values{
|
||||||
|
"name": {expectedName1},
|
||||||
|
"csrf-token": {csrfToken1},
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId1)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId2)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
assert.NotContains(t, expectedName1, readBody(t, resp.Body))
|
||||||
|
})
|
||||||
|
t.Run(`should prohibit special characters in name`, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
csrfToken, sessionId := createValidUserSession(t, db, ctx, basePath, "")
|
||||||
|
|
||||||
|
data := map[string]int{
|
||||||
|
"<": 400,
|
||||||
|
">": 400,
|
||||||
|
"/": 400,
|
||||||
|
"\\": 400,
|
||||||
|
"?": 400,
|
||||||
|
":": 400,
|
||||||
|
"*": 400,
|
||||||
|
"|": 400,
|
||||||
|
"\"": 400,
|
||||||
|
"Account": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, status := range data {
|
||||||
|
|
||||||
|
formData := url.Values{
|
||||||
|
"name": {name},
|
||||||
|
"csrf-token": {csrfToken},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, status, resp.StatusCode, "for name: "+name)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func createValidUserSession(t *testing.T, db *sqlx.DB, ctx context.Context, basePath string, add string) (string, string) {
|
||||||
|
userId := uuid.New()
|
||||||
|
sessionId := "session-id" + add
|
||||||
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
csrfToken := "my-verifying-token" + add
|
||||||
|
email := add + "mail@mail.de"
|
||||||
|
|
||||||
|
_, err := db.Exec(`
|
||||||
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
|
VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
_, err = db.Exec(`
|
||||||
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
|
VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
_, err = db.Exec(`
|
||||||
|
INSERT INTO token (token, user_id, session_id, type, created_at, expires_at)
|
||||||
|
VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
return csrfToken, sessionId
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAnonymousSession(t *testing.T, ctx context.Context, basePath string) (string, string) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
html, err := html.Parse(resp.Body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
return findCsrfToken(html), findCookie(resp, "id").Value
|
||||||
|
}
|
||||||
|
|
||||||
func findCookie(resp *http.Response, name string) *http.Cookie {
|
func findCookie(resp *http.Response, name string) *http.Cookie {
|
||||||
for _, cookie := range resp.Cookies() {
|
for _, cookie := range resp.Cookies() {
|
||||||
if cookie.Name == name {
|
if cookie.Name == name {
|
||||||
@@ -1751,3 +1974,12 @@ func getTokenAttribute(data *html.Node) *html.Attribute {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readBody(t *testing.T, body io.ReadCloser) string {
|
||||||
|
data, err := io.ReadAll(body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
err = body.Close()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
"spend-sparrow/db"
|
"spend-sparrow/db"
|
||||||
@@ -39,17 +39,17 @@ func NewAccountImpl(db db.Account, random Random, clock Clock, settings *types.S
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AccountImpl) Add(user *types.User, name string) (*types.Account, error) {
|
func (s AccountImpl) Add(user *types.User, name string) (*types.Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, types.ErrInternal
|
return nil, ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
newId, err := service.random.UUID()
|
newId, err := s.random.UUID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.validateAccount(name)
|
err = s.validateAccount(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -64,61 +64,48 @@ func (service AccountImpl) Add(user *types.User, name string) (*types.Account, e
|
|||||||
LastTransaction: nil,
|
LastTransaction: nil,
|
||||||
OinkBalance: 0,
|
OinkBalance: 0,
|
||||||
|
|
||||||
CreatedAt: service.clock.Now(),
|
CreatedAt: s.clock.Now(),
|
||||||
CreatedBy: user.Id,
|
CreatedBy: user.Id,
|
||||||
UpdatedAt: nil,
|
UpdatedAt: nil,
|
||||||
UpdatedBy: nil,
|
UpdatedBy: nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.Insert(user.Id, account)
|
err = s.db.Insert(user.Id, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
savedAccount, err := service.db.Get(user.Id, newId)
|
savedAccount, err := s.db.Get(user.Id, newId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
log.Error("account %v not found after insert: %v", newId, err)
|
||||||
log.Error("Account not found after insert: %v", err)
|
|
||||||
}
|
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
return savedAccount, nil
|
return savedAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AccountImpl) Update(user *types.User, id uuid.UUID, name string) (*types.Account, error) {
|
func (s AccountImpl) Update(user *types.User, id uuid.UUID, name string) (*types.Account, error) {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
|
return nil, ErrUnauthorized
|
||||||
|
}
|
||||||
|
err := s.validateAccount(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.db.Get(user.Id, id)
|
||||||
|
if err != nil {
|
||||||
|
if err == db.ErrNotFound {
|
||||||
|
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
|
||||||
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
err := service.validateAccount(name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := service.db.Get(user.Id, id)
|
timestamp := s.clock.Now()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
timestamp := service.clock.Now()
|
|
||||||
account.Name = name
|
account.Name = name
|
||||||
account.UpdatedAt = ×tamp
|
account.UpdatedAt = ×tamp
|
||||||
account.UpdatedBy = &user.Id
|
account.UpdatedBy = &user.Id
|
||||||
|
|
||||||
err = service.db.Update(user.Id, account)
|
err = s.db.Update(user.Id, account)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (service AccountImpl) Get(user *types.User, id uuid.UUID) (*types.Account, error) {
|
|
||||||
|
|
||||||
if user == nil {
|
|
||||||
return nil, types.ErrInternal
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := service.db.Get(user.Id, id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -126,13 +113,30 @@ func (service AccountImpl) Get(user *types.User, id uuid.UUID) (*types.Account,
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
func (s AccountImpl) Get(user *types.User, id uuid.UUID) (*types.Account, error) {
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
|
return nil, ErrUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := s.db.Get(user.Id, id)
|
||||||
|
if err != nil {
|
||||||
|
if err == db.ErrNotFound {
|
||||||
|
return nil, fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
|
||||||
|
}
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts, err := service.db.GetAll(user.Id)
|
return account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
return nil, ErrUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts, err := s.db.GetAll(user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -140,33 +144,36 @@ func (service AccountImpl) GetAll(user *types.User) ([]*types.Account, error) {
|
|||||||
return accounts, nil
|
return accounts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AccountImpl) Delete(user *types.User, id uuid.UUID) error {
|
func (s AccountImpl) Delete(user *types.User, id uuid.UUID) error {
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return types.ErrInternal
|
return ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := service.db.Get(user.Id, id)
|
account, err := s.db.Get(user.Id, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
if err == db.ErrNotFound {
|
||||||
|
return fmt.Errorf("account %v not found: %w", id, ErrBadRequest)
|
||||||
|
}
|
||||||
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.UserId != user.Id {
|
if account.UserId != user.Id {
|
||||||
return types.ErrUnauthorized
|
return types.ErrUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.Delete(user.Id, account.Id)
|
err = s.db.Delete(user.Id, account.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AccountImpl) validateAccount(name string) error {
|
func (s AccountImpl) validateAccount(name string) error {
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return errors.New("please enter a value for the \"name\" field")
|
return fmt.Errorf("field \"name\" needs to be set: %w", ErrBadRequest)
|
||||||
} else if !safeInputRegex.MatchString(name) {
|
} else if !safeInputRegex.MatchString(name) {
|
||||||
return errors.New("please use only letters, dashes or numbers for \"name\"")
|
return fmt.Errorf("use only letters, dashes and spaces for \"name\": %w", ErrBadRequest)
|
||||||
} else {
|
} else {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
8
service/error.go
Normal file
8
service/error.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBadRequest = errors.New("bad request")
|
||||||
|
ErrUnauthorized = errors.New("unauthorized")
|
||||||
|
)
|
||||||
@@ -45,5 +45,11 @@ func (r *RandomImpl) String(size int) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *RandomImpl) UUID() (uuid.UUID, error) {
|
func (r *RandomImpl) UUID() (uuid.UUID, error) {
|
||||||
return uuid.NewRandom()
|
id, err := uuid.NewRandom()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error generating random UUID: %v", err)
|
||||||
|
return uuid.Nil, types.ErrInternal
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,36 +18,6 @@ templ Account(accounts []*types.Account) {
|
|||||||
<div id="account-items" class="my-6 flex flex-col items-center">
|
<div id="account-items" class="my-6 flex flex-col items-center">
|
||||||
for _, account := range accounts {
|
for _, account := range accounts {
|
||||||
@AccountItem(account)
|
@AccountItem(account)
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
@AccountItem(account)
|
|
||||||
}
|
}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
Reference in New Issue
Block a user