package test_test import ( "context" "encoding/json" "io" "net/http" "net/url" "spend-sparrow/internal" "spend-sparrow/internal/service" "spend-sparrow/internal/types" "strconv" "strings" "sync/atomic" "testing" "time" "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/html" ) var ( httpClient = http.Client{ // Disable redirect following CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } port atomic.Int64 ) func setupIntegrationTest(t *testing.T) (*sqlx.DB, string, context.Context) { t.Helper() ctx, done := context.WithCancel(context.Background()) t.Cleanup(done) db, err := sqlx.Open("sqlite3", ":memory:") if err != nil { t.Fatalf("Could not open Database data.db: %v", err) } t.Cleanup(func() { err := db.Close() if err != nil { panic(err) } }) testPort := port.Add(1) testPort += 1024 go func() { _ = internal.Run(ctx, db, "../", getEnv(testPort)) }() basePath := "http://localhost:" + strconv.Itoa(int(testPort)) err = waitForReady(t, ctx, 5*time.Second, basePath) require.NoError(t, err) return db, basePath, ctx } func getEnv(port int64) func(string) string { return func(key string) string { switch key { case "PORT": return strconv.Itoa(int(port)) case "SMTP_ENABLED": return "false" case "OLTP_ENABLED": return "false" case "BASE_URL": return "http://localhost:" + strconv.Itoa(int(port)) case "ENVIRONMENT": return "test" default: return "" } } } // waitForReady calls the specified endpoint until it gets a 200 // response or until the context is cancelled or the timeout is // reached. func waitForReady( t *testing.T, ctx context.Context, timeout time.Duration, endpoint string, ) error { t.Helper() client := http.Client{} startTime := time.Now() for { req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) require.NoError(t, err) resp, err := client.Do(req) if err == nil && resp.StatusCode == http.StatusOK { return resp.Body.Close() } else if err == nil { err := resp.Body.Close() if err != nil { return err } } select { case <-ctx.Done(): return ctx.Err() default: if time.Since(startTime) >= timeout { t.Fatal("timeout reached while waiting for endpoint") return types.ErrInternal } // wait a little while between checks time.Sleep(250 * time.Millisecond) } } } func findCsrfToken(t *testing.T, data *html.Node) string { t.Helper() token := getTokenAttribute(t, data) if token != "" { return token } if data.FirstChild != nil { if token = findCsrfToken(t, data.FirstChild); token != "" { return token } } if data.NextSibling != nil { if token = findCsrfToken(t, data.NextSibling); token != "" { return token } } return "" } func getTokenAttribute(t *testing.T, data *html.Node) string { t.Helper() for _, attr := range data.Attr { if attr.Key == "hx-headers" { var data map[string]any err := json.Unmarshal([]byte(attr.Val), &data) require.NoError(t, err) result, ok := data["Csrf-Token"].(string) if !ok { return "" } return result } } return "" } func readBody(t *testing.T, body io.ReadCloser) string { t.Helper() data, err := io.ReadAll(body) require.NoError(t, err) return string(data) } func createValidUserSession(t *testing.T, db *sqlx.DB, add string) (uuid.UUID, string, string) { t.Helper() userId := uuid.New() sessionId := "session-id" + add pass := service.GetHashPassword("password", []byte("salt")) csrfToken := "my-verifying-token" + add email := add + "mail@mail.de" _, err := db.ExecContext(context.Background(), ` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, TRUE, FALSE, ?, ?, datetime())`, userId, email, pass, []byte("salt")) require.NoError(t, err) _, err = db.ExecContext(context.Background(), ` INSERT INTO session (session_id, user_id, created_at, expires_at) VALUES (?, ?, datetime(), datetime("now", "+1 day"))`, sessionId, userId) require.NoError(t, err) _, err = db.ExecContext(context.Background(), ` INSERT INTO token (token, user_id, session_id, type, created_at, expires_at) VALUES (?, ?, ?, ?, datetime(), datetime("now", "+1 day"))`, csrfToken, userId, sessionId, types.TokenTypeCsrf) require.NoError(t, err) return userId, csrfToken, sessionId } func createAnonymousSession(t *testing.T, ctx context.Context, basePath string) (string, string) { t.Helper() req, err := http.NewRequestWithContext(ctx, http.MethodGet, basePath+"/auth/signin", nil) require.NoError(t, err) resp, err := httpClient.Do(req) require.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) html, err := html.Parse(resp.Body) _ = resp.Body.Close() require.NoError(t, err) return findCsrfToken(t, html), findCookie(t, resp).Value } func findCookie(t *testing.T, resp *http.Response) *http.Cookie { t.Helper() for _, cookie := range resp.Cookies() { if cookie.Name == "id" { return cookie } } return nil } func doAuthenticatedRequest( t *testing.T, ctx context.Context, method string, path string, formData url.Values, csrfToken string, sessionId string, ) *http.Response { t.Helper() req, err := http.NewRequestWithContext(ctx, method, path, strings.NewReader(formData.Encode())) require.NoError(t, err) req.Header.Set("Csrf-Token", csrfToken) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Hx-Request", "true") req.AddCookie(&http.Cookie{Name: "id", Value: sessionId}) resp, err := httpClient.Do(req) require.NoError(t, err) return resp }