package main import ( "context" "database/sql" "fmt" "net/http" "net/url" "strings" "sync/atomic" "testing" "time" "me-fit/log" "me-fit/service" "me-fit/types" "github.com/google/uuid" "github.com/stretchr/testify/assert" "golang.org/x/net/html" ) var ( httpClient = http.Client{ // Disable redirect following CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } port atomic.Int32 ) func TestSecurity(t *testing.T) { t.Parallel() } func TestAuth(t *testing.T) { t.Parallel() t.Run("should signin and return session cookie", func(t *testing.T) { t.Parallel() db, basePath, ctx := setupIntegrationTest(t) pass := service.GetHashPassword("password", []byte("salt")) _, err := db.Exec(` INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) if err != nil { t.Fatalf("Error inserting user: %v", err) } req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) assert.Nil(t, err) resp, err := httpClient.Do(req) assert.Nil(t, err) html, err := html.Parse(resp.Body) assert.Nil(t, err) csrfToken := findCsrfToken(html) assert.NotEqual(t, "", csrfToken) anonymousSession := findCookie(resp, "id") assert.NotNil(t, anonymousSession) formData := url.Values{ "email": {"mail@mail.de"}, "password": {"password"}, "csrf-token": {csrfToken}, } req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) assert.Nil(t, err) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Cookie", anonymousSession.Name+"="+anonymousSession.Value) resp, err = httpClient.Do(req) assert.Nil(t, err) assert.Equal(t, http.StatusSeeOther, resp.StatusCode) cookie := findCookie(resp, "id") if cookie == nil { t.Fatalf("No session cookie found") } else if cookie.SameSite != http.SameSiteStrictMode || cookie.HttpOnly != true || cookie.Secure != true { t.Fatalf("Cookie is not secure") } }) } func findCookie(resp *http.Response, name string) *http.Cookie { for _, cookie := range resp.Cookies() { if cookie.Name == name { return cookie } } return nil } func setupIntegrationTest(t *testing.T) (db *sql.DB, basePath string, ctx context.Context) { ctx, done := context.WithCancel(context.Background()) t.Cleanup(done) db, err := sql.Open("sqlite3", ":memory:") if err != nil { t.Fatalf("Could not open Database data.db: %v", err) } t.Cleanup(func() { db.Close() }) testPort := port.Add(1) testPort += 1024 go run(ctx, db, getEnv(testPort)) basePath = "http://localhost:" + fmt.Sprint(testPort) err = waitForReady(ctx, 5*time.Second, basePath) if err != nil { t.Fatalf("Failed to start server: %v", err) } return db, basePath, ctx } func getEnv(port int32) func(string) string { return func(key string) string { if key == "PORT" { return fmt.Sprint(port) } else if key == "SMTP_ENABLED" { return "false" } else if key == "PROMETHEUS_ENABLED" { return "false" } else if key == "BASE_URL" { return "http://localhost:" + fmt.Sprint(port) } else if key == "ENVIRONMENT" { return "test" } else { return "" } } } // waitForReady calls the specified endpoint until it gets a 200 // response or until the context is cancelled or the timeout is // reached. func waitForReady( ctx context.Context, timeout time.Duration, endpoint string, ) error { client := http.Client{} startTime := time.Now() for { req, err := http.NewRequestWithContext( ctx, http.MethodGet, endpoint, nil, ) if err != nil { log.Error("failed to create request: %v", err) return err } resp, err := client.Do(req) if err != nil { log.Info("Error making request: %v", err) continue } if resp.StatusCode == http.StatusOK { log.Info("Endpoint is ready!") resp.Body.Close() return nil } resp.Body.Close() select { case <-ctx.Done(): return ctx.Err() default: if time.Since(startTime) >= timeout { log.Error("timeout reached while waiting for endpoint") return types.ErrInternal } // wait a little while between checks time.Sleep(250 * time.Millisecond) } } } func findCsrfToken(data *html.Node) string { attr := getTokenAttribute(data) if attr != nil { return attr.Val } if data.FirstChild != nil { if token := findCsrfToken(data.FirstChild); token != "" { return token } } if data.NextSibling != nil { if token := findCsrfToken(data.NextSibling); token != "" { return token } } return "" } func getTokenAttribute(data *html.Node) *html.Attribute { returnValue := false for _, attr := range data.Attr { if attr.Key == "name" && attr.Val == "csrf-token" { returnValue = true } } if !returnValue { return nil } for _, attr := range data.Attr { if attr.Key == "value" { return &attr } } return nil }