chore(auth): #331 add and fix session tests
This commit was merged in pull request #342.
This commit is contained in:
@@ -343,6 +343,7 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
||||
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
|
||||
|
||||
if err != nil {
|
||||
log.Warn("Session not found: %v", err)
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
|
||||
@@ -16,20 +16,29 @@ var UserKey ContextKey = "user"
|
||||
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
sessionId := getSessionID(r)
|
||||
session, user, _ := service.SignInSession(sessionId)
|
||||
|
||||
if session != nil {
|
||||
var err error
|
||||
// Always sign in anonymous
|
||||
// This way, we can always generate csrf tokens
|
||||
if session == nil {
|
||||
session, err = service.SignInAnonymous()
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
cookie := CreateSessionCookie(session.Id)
|
||||
http.SetCookie(w, &cookie)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, UserKey, user)
|
||||
ctx = context.WithValue(ctx, SessionKey, session)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,15 +62,6 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
||||
}
|
||||
}
|
||||
|
||||
// Always sign in anonymous
|
||||
// This way, there is no way to forget creating a csrf token
|
||||
if session == nil {
|
||||
session, _ = auth.SignInAnonymous()
|
||||
|
||||
cookie := CreateSessionCookie(session.Id)
|
||||
http.SetCookie(w, &cookie)
|
||||
}
|
||||
|
||||
responseWriter := newCsrfResponseWriter(w, auth, session)
|
||||
next.ServeHTTP(responseWriter, r)
|
||||
})
|
||||
|
||||
@@ -22,7 +22,7 @@ func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Han
|
||||
"form-action 'self'; "+
|
||||
"frame-ancestors 'none'; ",
|
||||
)
|
||||
w.Header().Set("Cross-Origin-Resource-Policy", "same-origin")
|
||||
w.Header().Set("Cross-Origin-Resource-Policy", "same-site") // same-site, as same origin prohibits umami
|
||||
w.Header().Set("Cross-Origin-Opener-Policy", "same-origin")
|
||||
w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp")
|
||||
w.Header().Set("Permissions-Policy", "geolocation=(), camera=(), microphone=(), interest-cohort=()")
|
||||
|
||||
77
main_test.go
77
main_test.go
@@ -90,7 +90,7 @@ func TestIntegrationSecurityHeader(t *testing.T) {
|
||||
"frame-ancestors 'none';", value)
|
||||
|
||||
value = resp.Header.Get("Cross-Origin-Resource-Policy")
|
||||
assert.Equal(t, "same-origin", value)
|
||||
assert.Equal(t, "same-site", value)
|
||||
|
||||
value = resp.Header.Get("Cross-Origin-Opener-Policy")
|
||||
assert.Equal(t, "same-origin", value)
|
||||
@@ -300,6 +300,81 @@ func TestIntegrationAuth(t *testing.T) {
|
||||
assert.False(t, sessions.Next())
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Session", func(t *testing.T) {
|
||||
t.Run("should create new anonymous session if current session gets outdated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d, basePath, ctx := setupIntegrationTest(t)
|
||||
|
||||
userId := uuid.New()
|
||||
sessionId := "session-id"
|
||||
|
||||
_, err := d.Exec(`
|
||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||
assert.Nil(t, err)
|
||||
_, err = d.Exec(`
|
||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
|
||||
assert.Nil(t, err)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", basePath, nil)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Cookie", "id="+sessionId)
|
||||
resp, err := httpClient.Do(req)
|
||||
assert.Nil(t, err)
|
||||
|
||||
newSession := findCookie(resp, "id")
|
||||
assert.NotNil(t, newSession)
|
||||
assert.NotEqual(t, sessionId, newSession.Value)
|
||||
|
||||
var rows int
|
||||
err = d.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, rows)
|
||||
})
|
||||
t.Run("should create anonymous session", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, basePath, ctx := setupIntegrationTest(t)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", basePath, nil)
|
||||
assert.Nil(t, err)
|
||||
resp, err := httpClient.Do(req)
|
||||
assert.Nil(t, err)
|
||||
|
||||
newSession := findCookie(resp, "id")
|
||||
assert.NotNil(t, newSession)
|
||||
assert.NotEqual(t, "", newSession.Value)
|
||||
})
|
||||
t.Run("should not have access to user information with outdated session", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
d, basePath, ctx := setupIntegrationTest(t)
|
||||
|
||||
userId := uuid.New()
|
||||
sessionId := "session-id"
|
||||
|
||||
_, err := d.Exec(`
|
||||
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||
assert.Nil(t, err)
|
||||
_, err = d.Exec(`
|
||||
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
|
||||
assert.Nil(t, err)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/workout", nil)
|
||||
assert.Nil(t, err)
|
||||
req.Header.Set("Cookie", "id="+sessionId)
|
||||
resp, err := httpClient.Do(req)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusSeeOther, resp.StatusCode)
|
||||
assert.Equal(t, "/auth/signin", resp.Header.Get("Location"))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func findCookie(resp *http.Response, name string) *http.Cookie {
|
||||
|
||||
@@ -99,6 +99,7 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
|
||||
return nil, nil, types.ErrInternal
|
||||
}
|
||||
if session.ExpiresAt.Before(service.clock.Now()) {
|
||||
_ = service.db.DeleteSession(sessionId)
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -120,6 +121,8 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
log.Info("Anonymous session created: %v", session.Id)
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
@@ -130,7 +133,6 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error)
|
||||
}
|
||||
|
||||
err = service.db.DeleteOldSessions(userId)
|
||||
|
||||
if err != nil {
|
||||
return nil, types.ErrInternal
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ func (r *RandomImpl) Bytes(size int) ([]byte, error) {
|
||||
func (r *RandomImpl) String(size int) (string, error) {
|
||||
bytes, err := r.Bytes(size)
|
||||
if err != nil {
|
||||
log.Error("Error generating random string: %v", err)
|
||||
return "", types.ErrInternal
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user