chore(auth): #331 add and fix session tests
This commit was merged in pull request #342.
This commit is contained in:
@@ -343,6 +343,7 @@ func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
|
|||||||
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
|
WHERE session_id = ?`, sessionId).Scan(&userId, &createdAt, &expiresAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Warn("Session not found: %v", err)
|
||||||
return nil, ErrNotFound
|
return nil, ErrNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,20 +16,29 @@ var UserKey ContextKey = "user"
|
|||||||
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
func Authenticate(service service.Auth) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
|
||||||
|
|
||||||
sessionId := getSessionID(r)
|
sessionId := getSessionID(r)
|
||||||
session, user, _ := service.SignInSession(sessionId)
|
session, user, _ := service.SignInSession(sessionId)
|
||||||
|
|
||||||
if session != nil {
|
var err error
|
||||||
|
// Always sign in anonymous
|
||||||
|
// This way, we can always generate csrf tokens
|
||||||
|
if session == nil {
|
||||||
|
session, err = service.SignInAnonymous()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie := CreateSessionCookie(session.Id)
|
||||||
|
http.SetCookie(w, &cookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
ctx = context.WithValue(ctx, UserKey, user)
|
ctx = context.WithValue(ctx, UserKey, user)
|
||||||
ctx = context.WithValue(ctx, SessionKey, session)
|
ctx = context.WithValue(ctx, SessionKey, session)
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
} else {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,15 +62,6 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Always sign in anonymous
|
|
||||||
// This way, there is no way to forget creating a csrf token
|
|
||||||
if session == nil {
|
|
||||||
session, _ = auth.SignInAnonymous()
|
|
||||||
|
|
||||||
cookie := CreateSessionCookie(session.Id)
|
|
||||||
http.SetCookie(w, &cookie)
|
|
||||||
}
|
|
||||||
|
|
||||||
responseWriter := newCsrfResponseWriter(w, auth, session)
|
responseWriter := newCsrfResponseWriter(w, auth, session)
|
||||||
next.ServeHTTP(responseWriter, r)
|
next.ServeHTTP(responseWriter, r)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func SecurityHeaders(serverSettings *types.Settings) func(http.Handler) http.Han
|
|||||||
"form-action 'self'; "+
|
"form-action 'self'; "+
|
||||||
"frame-ancestors 'none'; ",
|
"frame-ancestors 'none'; ",
|
||||||
)
|
)
|
||||||
w.Header().Set("Cross-Origin-Resource-Policy", "same-origin")
|
w.Header().Set("Cross-Origin-Resource-Policy", "same-site") // same-site, as same origin prohibits umami
|
||||||
w.Header().Set("Cross-Origin-Opener-Policy", "same-origin")
|
w.Header().Set("Cross-Origin-Opener-Policy", "same-origin")
|
||||||
w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp")
|
w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp")
|
||||||
w.Header().Set("Permissions-Policy", "geolocation=(), camera=(), microphone=(), interest-cohort=()")
|
w.Header().Set("Permissions-Policy", "geolocation=(), camera=(), microphone=(), interest-cohort=()")
|
||||||
|
|||||||
77
main_test.go
77
main_test.go
@@ -90,7 +90,7 @@ func TestIntegrationSecurityHeader(t *testing.T) {
|
|||||||
"frame-ancestors 'none';", value)
|
"frame-ancestors 'none';", value)
|
||||||
|
|
||||||
value = resp.Header.Get("Cross-Origin-Resource-Policy")
|
value = resp.Header.Get("Cross-Origin-Resource-Policy")
|
||||||
assert.Equal(t, "same-origin", value)
|
assert.Equal(t, "same-site", value)
|
||||||
|
|
||||||
value = resp.Header.Get("Cross-Origin-Opener-Policy")
|
value = resp.Header.Get("Cross-Origin-Opener-Policy")
|
||||||
assert.Equal(t, "same-origin", value)
|
assert.Equal(t, "same-origin", value)
|
||||||
@@ -300,6 +300,81 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
assert.False(t, sessions.Next())
|
assert.False(t, sessions.Next())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Session", func(t *testing.T) {
|
||||||
|
t.Run("should create new anonymous session if current session gets outdated", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
d, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
userId := uuid.New()
|
||||||
|
sessionId := "session-id"
|
||||||
|
|
||||||
|
_, err := d.Exec(`
|
||||||
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
_, err = d.Exec(`
|
||||||
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
|
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", basePath, nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
newSession := findCookie(resp, "id")
|
||||||
|
assert.NotNil(t, newSession)
|
||||||
|
assert.NotEqual(t, sessionId, newSession.Value)
|
||||||
|
|
||||||
|
var rows int
|
||||||
|
err = d.QueryRow("SELECT COUNT(*) FROM session WHERE user_id = ?", userId).Scan(&rows)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, 0, rows)
|
||||||
|
})
|
||||||
|
t.Run("should create anonymous session", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
_, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", basePath, nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
newSession := findCookie(resp, "id")
|
||||||
|
assert.NotNil(t, newSession)
|
||||||
|
assert.NotEqual(t, "", newSession.Value)
|
||||||
|
})
|
||||||
|
t.Run("should not have access to user information with outdated session", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
d, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
userId := uuid.New()
|
||||||
|
sessionId := "session-id"
|
||||||
|
|
||||||
|
_, err := d.Exec(`
|
||||||
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, userId, []byte("pass"), []byte("salt"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
_, err = d.Exec(`
|
||||||
|
INSERT INTO session (session_id, user_id, created_at, expires_at)
|
||||||
|
VALUES (?, ?, datetime("now", "-8 hour"), datetime("now", "-1 minute"))`, sessionId, userId)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/workout", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Cookie", "id="+sessionId)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusSeeOther, resp.StatusCode)
|
||||||
|
assert.Equal(t, "/auth/signin", resp.Header.Get("Location"))
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func findCookie(resp *http.Response, name string) *http.Cookie {
|
func findCookie(resp *http.Response, name string) *http.Cookie {
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.
|
|||||||
return nil, nil, types.ErrInternal
|
return nil, nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
if session.ExpiresAt.Before(service.clock.Now()) {
|
if session.ExpiresAt.Before(service.clock.Now()) {
|
||||||
|
_ = service.db.DeleteSession(sessionId)
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,6 +121,8 @@ func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
|
|||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info("Anonymous session created: %v", session.Id)
|
||||||
|
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,7 +133,6 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = service.db.DeleteOldSessions(userId)
|
err = service.db.DeleteOldSessions(userId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, types.ErrInternal
|
return nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ func (r *RandomImpl) Bytes(size int) ([]byte, error) {
|
|||||||
func (r *RandomImpl) String(size int) (string, error) {
|
func (r *RandomImpl) String(size int) (string, error) {
|
||||||
bytes, err := r.Bytes(size)
|
bytes, err := r.Bytes(size)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Error("Error generating random string: %v", err)
|
||||||
return "", types.ErrInternal
|
return "", types.ErrInternal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user