diff --git a/db/auth.go b/db/auth.go index c86ce86..0a36c20 100644 --- a/db/auth.go +++ b/db/auth.go @@ -338,7 +338,7 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt) if err != nil { - log.Warn("Session not found: %v", err) + log.Warn("Session \"%s\" not found: %v", sessionId, err) return nil, ErrNotFound } diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 0f68067..74710f2 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -53,7 +53,7 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler csrfToken = r.Header.Get("csrf-token") } if session == nil || csrfToken == "" || !auth.IsCsrfTokenValid(csrfToken, session.Id) { - log.Info("CSRF-Token not correct") + log.Info("CSRF-Token \"%s\" not correct", csrfToken) if r.Header.Get("HX-Request") == "true" { utils.TriggerToastWithStatus(w, r, "error", "CSRF-Token not correct", http.StatusBadRequest) } else { diff --git a/main_test.go b/main_test.go index 3a919f9..ad9dd64 100644 --- a/main_test.go +++ b/main_test.go @@ -3,6 +3,7 @@ package main import ( "context" "fmt" + "io" "net/http" "net/url" "strings" @@ -1615,6 +1616,228 @@ func TestIntegrationAuth(t *testing.T) { }) } +func TestIntegrationAccount(t *testing.T) { + t.Parallel() + + t.Run("SignIn", func(t *testing.T) { + t.Run(`should throw unauthorized if try to getAll, get, edit, insert or delete`, func(t *testing.T) { + t.Parallel() + + _, basePath, ctx := setupIntegrationTest(t) + + csrfToken, sessionId := createAnonymousSession(t, ctx, basePath) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/account", nil) + assert.Nil(t, err) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) + + req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/some-id", nil) + assert.Nil(t, err) + req.Header.Set("Cookie", "id="+sessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) + + formData := url.Values{ + "name": {"name"}, + "csrf-token": {csrfToken}, + } + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/account/some-id", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+sessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) + + req, err = http.NewRequestWithContext(ctx, "DELETE", basePath+"/account/some-id", nil) + assert.Nil(t, err) + req.Header.Set("csrf-token", csrfToken) + req.Header.Set("Cookie", "id="+sessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + assert.Equal(t, "/auth/signin", resp.Header.Get("Location")) + }) + t.Run(`should be able to insert, get, delete and update`, func(t *testing.T) { + t.Parallel() + + db, basePath, ctx := setupIntegrationTest(t) + + csrfToken, sessionId := createValidUserSession(t, db, ctx, basePath, "") + + // Insert + expectedName := "My great Account" + formData := url.Values{ + "name": {expectedName}, + "csrf-token": {csrfToken}, + } + req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+sessionId) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, readBody(t, resp.Body), expectedName) + + var id uuid.UUID + err = db.Get(&id, "SELECT id FROM account") + assert.Nil(t, err) + + // Update + expectedNewName := "My new Account" + formData = url.Values{ + "name": {expectedNewName}, + "csrf-token": {csrfToken}, + } + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/account/"+id.String(), strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+sessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, readBody(t, resp.Body), expectedNewName) + + // Get + req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/"+id.String(), nil) + assert.Nil(t, err) + req.Header.Set("Cookie", "id="+sessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, readBody(t, resp.Body), expectedNewName) + + // Delete + req, err = http.NewRequestWithContext(ctx, "DELETE", basePath+"/account/"+id.String(), strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+sessionId) + req.Header.Set("csrf-token", csrfToken) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Get (not found) + req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account/"+id.String(), nil) + assert.Nil(t, err) + req.Header.Set("Cookie", "id="+sessionId) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) + assert.NotContains(t, readBody(t, resp.Body), expectedNewName) + }) + t.Run(`should not be able to see other users content`, func(t *testing.T) { + t.Parallel() + + db, basePath, ctx := setupIntegrationTest(t) + + csrfToken1, sessionId1 := createValidUserSession(t, db, ctx, basePath, "1") + _, sessionId2 := createValidUserSession(t, db, ctx, basePath, "2") + + expectedName1 := "Account 1" + + formData := url.Values{ + "name": {expectedName1}, + "csrf-token": {csrfToken1}, + } + req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+sessionId1) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/account", nil) + assert.Nil(t, err) + req.Header.Set("Cookie", "id="+sessionId2) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.NotContains(t, expectedName1, readBody(t, resp.Body)) + }) + t.Run(`should prohibit special characters in name`, func(t *testing.T) { + t.Parallel() + db, basePath, ctx := setupIntegrationTest(t) + + csrfToken, sessionId := createValidUserSession(t, db, ctx, basePath, "") + + data := map[string]int{ + "<": 400, + ">": 400, + "/": 400, + "\\": 400, + "?": 400, + ":": 400, + "*": 400, + "|": 400, + "\"": 400, + "Account": 200, + } + + for name, status := range data { + + formData := url.Values{ + "name": {name}, + "csrf-token": {csrfToken}, + } + + req, err := http.NewRequestWithContext(ctx, "POST", basePath+"/account/new", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+sessionId) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, status, resp.StatusCode, "for name: "+name) + } + + }) + }) +} + +func createValidUserSession(t *testing.T, db *sqlx.DB, ctx context.Context, basePath string, add string) (string, string) { + userId := uuid.New() + sessionId := "session-id" + add + pass := service.GetHashPassword("password", []byte("salt")) + csrfToken := "my-verifying-token" + add + email := add + "mail@mail.de" + + _, err := db.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt")) + assert.Nil(t, err) + _, err = db.Exec(` + INSERT INTO session (session_id, user_id, created_at, expires_at) + VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) + assert.Nil(t, err) + + _, err = db.Exec(` + INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) + VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf) + assert.Nil(t, err) + + return csrfToken, sessionId +} + +func createAnonymousSession(t *testing.T, ctx context.Context, basePath string) (string, string) { + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) + assert.Nil(t, err) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + html, err := html.Parse(resp.Body) + assert.Nil(t, err) + + return findCsrfToken(html), findCookie(resp, "id").Value +} + func findCookie(resp *http.Response, name string) *http.Cookie { for _, cookie := range resp.Cookies() { if cookie.Name == name { @@ -1751,3 +1974,11 @@ func getTokenAttribute(data *html.Node) *html.Attribute { return nil } + +func readBody(t *testing.T, body io.ReadCloser) string { + defer body.Close() + data, err := io.ReadAll(body) + assert.Nil(t, err) + + return string(data) +}