diff --git a/db/auth.go b/db/auth.go index b2bb846..b755e67 100644 --- a/db/auth.go +++ b/db/auth.go @@ -343,6 +343,7 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt) if err != nil { + log.Warn("Session not found: %v", err) return nil, ErrNotFound } diff --git a/handler/middleware/authenticate.go b/handler/middleware/authenticate.go index 9071215..4b8426d 100644 --- a/handler/middleware/authenticate.go +++ b/handler/middleware/authenticate.go @@ -21,15 +21,23 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler { sessionId := getSessionID(r) session, user, _ := service.SignInSession(sessionId) - if session != nil { + // Always sign in anonymous + // This way, we can always generate csrf tokens + if session == nil { + session, err := service.SignInAnonymous() + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } - ctx = context.WithValue(ctx, UserKey, user) - ctx = context.WithValue(ctx, SessionKey, session) - - next.ServeHTTP(w, r.WithContext(ctx)) - } else { - next.ServeHTTP(w, r) + cookie := CreateSessionCookie(session.Id) + http.SetCookie(w, &cookie) } + + ctx = context.WithValue(ctx, UserKey, user) + ctx = context.WithValue(ctx, SessionKey, session) + + next.ServeHTTP(w, r.WithContext(ctx)) }) } } diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 2343cf4..2f35292 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -62,15 +62,6 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler } } - // Always sign in anonymous - // This way, there is no way to forget creating a csrf token - if session == nil { - session, _ = auth.SignInAnonymous() - - cookie := CreateSessionCookie(session.Id) - http.SetCookie(w, &cookie) - } - responseWriter := newCsrfResponseWriter(w, auth, session) next.ServeHTTP(responseWriter, r) }) diff --git a/handler/middleware/security_headers.go b/handler/middleware/security_headers.go index 474b402..959cf74 100644 --- a/handler/middleware/security_headers.go +++ b/handler/middleware/security_headers.go @@ -22,7 +22,7 @@ func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Han "form-action 'self'; "+ "frame-ancestors 'none'; ", ) - w.Header().Set("Cross-Origin-Resource-Policy", "same-origin") + w.Header().Set("Cross-Origin-Resource-Policy", "same-site") // same-site, as same origin prohibits umami w.Header().Set("Cross-Origin-Opener-Policy", "same-origin") w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp") w.Header().Set("Permissions-Policy", "geolocation=(), camera=(), microphone=(), interest-cohort=()") diff --git a/main_test.go b/main_test.go index 5bd15ea..04c8032 100644 --- a/main_test.go +++ b/main_test.go @@ -90,7 +90,7 @@ func TestIntegrationSecurityHeader(t *testing.T) { "frame-ancestors 'none';", value) value = resp.Header.Get("Cross-Origin-Resource-Policy") - assert.Equal(t, "same-origin", value) + assert.Equal(t, "same-site", value) value = resp.Header.Get("Cross-Origin-Opener-Policy") assert.Equal(t, "same-origin", value) @@ -300,6 +300,41 @@ func TestIntegrationAuth(t *testing.T) { assert.False(t, sessions.Next()) }) }) + + t.Run("Session", func(t *testing.T) { + t.Run("should create new anonymous session if current session gets outdated", func(t *testing.T) { + t.Parallel() + + d, basePath, ctx := setupIntegrationTest(t) + + userId := uuid.New() + sessionId := "session-id" + + _, err := d.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt")) + assert.Nil(t, err) + _, err = d.Exec(` + INSERT INTO session (session_id, user_id, created_at, expires_at) + VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId) + assert.Nil(t, err) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/", nil) + assert.Nil(t, err) + req.Header.Set("Cookie", "id="+sessionId) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + + newSession := findCookie(resp, "id") + assert.NotNil(t, newSession) + assert.NotEqual(t, sessionId, newSession.Value) + + var rows int + err = d.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows) + assert.Nil(t, err) + assert.Equal(t, 0, rows) + }) + }) } func findCookie(resp *http.Response, name string) *http.Cookie { diff --git a/service/auth.go b/service/auth.go index b907a81..32692fa 100644 --- a/service/auth.go +++ b/service/auth.go @@ -99,6 +99,7 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types. return nil, nil, types.ErrInternal } if session.ExpiresAt.Before(service.clock.Now()) { + _ = service.db.DeleteSession(sessionId) return nil, nil, nil } @@ -130,7 +131,6 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) } err = service.db.DeleteOldSessions(userId) - if err != nil { return nil, types.ErrInternal } diff --git a/service/random_generator.go b/service/random_generator.go index c0f2155..d004a76 100644 --- a/service/random_generator.go +++ b/service/random_generator.go @@ -37,6 +37,7 @@ func (r *RandomImpl) Bytes(size int) ([]byte, error) { func (r *RandomImpl) String(size int) (string, error) { bytes, err := r.Bytes(size) if err != nil { + log.Error("Error generating random string: %v", err) return "", types.ErrInternal }