diff --git a/db/auth.go b/db/auth.go index b755e67..0098661 100644 --- a/db/auth.go +++ b/db/auth.go @@ -397,7 +397,6 @@ func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error { func (db AuthSqlite) DeleteSession(sessionId string) error { if sessionId != "" { - _, err := db.db.Exec("DELETE FROM session WHERE session_id = ?", sessionId) if err != nil { log.Error("Could not delete session: %v", err) diff --git a/handler/auth.go b/handler/auth.go index 43e8a36..49a60d2 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -77,11 +77,13 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc { func (handler AuthImpl) handleSignIn() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) { - var email = r.FormValue("email") - var password = r.FormValue("password") - session, user, err := handler.service.SignIn(email, password) + user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) { + session := middleware.GetSession(r) + email := r.FormValue("email") + password := r.FormValue("password") + + session, user, err := handler.service.SignIn(session, email, password) if err != nil { return nil, err } diff --git a/main_test.go b/main_test.go index a15efdb..79d0930 100644 --- a/main_test.go +++ b/main_test.go @@ -294,6 +294,166 @@ func TestIntegrationAuth(t *testing.T) { assert.NotEqual(t, anonymousSession.Value, cookie.Value, "Session ID did not change") }) + t.Run("should return in ~250 ms in all cases", func(t *testing.T) { + t.Parallel() + + db, basePath, ctx := setupIntegrationTest(t) + + pass := service.GetHashPassword("password", []byte("salt")) + _, err := db.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) + assert.Nil(t, err) + + // Everythings correct + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) + assert.Nil(t, err) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + body, err := html.Parse(resp.Body) + assert.Nil(t, err) + anonymousCsrfToken := findCsrfToken(body) + assert.NotEqual(t, "", anonymousCsrfToken) + anonymousSession := findCookie(resp, "id") + assert.NotNil(t, anonymousSession) + + formData := url.Values{ + "email": {"mail@mail.de"}, + "password": {"password"}, + "csrf-token": {anonymousCsrfToken}, + } + + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+anonymousSession.Value) + req.Header.Set("HX-Request", "true") + + timeStart := time.Now() + resp, err = httpClient.Do(req) + timeEnd := time.Now() + assert.Nil(t, err) + if timeEnd.Sub(timeStart) > 253*time.Millisecond || timeEnd.Sub(timeStart) <= 250*time.Millisecond { + t.Fail() + t.Logf("Time did not match: %v", timeEnd.Sub(timeStart)) + } + assert.Equal(t, http.StatusOK, resp.StatusCode) + + //Wrong password + req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) + assert.Nil(t, err) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + body, err = html.Parse(resp.Body) + assert.Nil(t, err) + anonymousCsrfToken = findCsrfToken(body) + assert.NotEqual(t, "", anonymousCsrfToken) + anonymousSession = findCookie(resp, "id") + assert.NotNil(t, anonymousSession) + + formData = url.Values{ + "email": {"mail@mail.de"}, + "password": {"wrong-password"}, + "csrf-token": {anonymousCsrfToken}, + } + + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+anonymousSession.Value) + req.Header.Set("HX-Request", "true") + + timeStart = time.Now() + resp, err = httpClient.Do(req) + timeEnd = time.Now() + assert.Nil(t, err) + if timeEnd.Sub(timeStart) > 253*time.Millisecond || timeEnd.Sub(timeStart) <= 250*time.Millisecond { + t.Fail() + t.Logf("Time did not match: %v", timeEnd.Sub(timeStart)) + } + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + + //Wrong username + req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) + assert.Nil(t, err) + resp, err = httpClient.Do(req) + assert.Nil(t, err) + body, err = html.Parse(resp.Body) + assert.Nil(t, err) + anonymousCsrfToken = findCsrfToken(body) + assert.NotEqual(t, "", anonymousCsrfToken) + anonymousSession = findCookie(resp, "id") + assert.NotNil(t, anonymousSession) + formData = url.Values{ + + "email": {"invalid-mail@mail.de"}, + "password": {"password"}, + "csrf-token": {anonymousCsrfToken}, + } + + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+anonymousSession.Value) + req.Header.Set("HX-Request", "true") + + timeStart = time.Now() + resp, err = httpClient.Do(req) + timeEnd = time.Now() + assert.Nil(t, err) + if timeEnd.Sub(timeStart) > 253*time.Millisecond || timeEnd.Sub(timeStart) <= 250*time.Millisecond { + t.Fail() + t.Logf("Time did not match: %v", timeEnd.Sub(timeStart)) + } + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) + t.Run("should create new session and invalidate old one (session fixation prevention)", func(t *testing.T) { + t.Parallel() + + db, basePath, ctx := setupIntegrationTest(t) + + pass := service.GetHashPassword("password", []byte("salt")) + _, err := db.Exec(` + INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) + assert.Nil(t, err) + + req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil) + assert.Nil(t, err) + resp, err := httpClient.Do(req) + assert.Nil(t, err) + + html, err := html.Parse(resp.Body) + assert.Nil(t, err) + anonymousCsrfToken := findCsrfToken(html) + assert.NotEqual(t, "", anonymousCsrfToken) + anonymousSession := findCookie(resp, "id") + assert.NotNil(t, anonymousSession) + + formData := url.Values{ + "email": {"mail@mail.de"}, + "password": {"password"}, + "csrf-token": {anonymousCsrfToken}, + } + + req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", "id="+anonymousSession.Value) + req.Header.Set("HX-Request", "true") + + resp, err = httpClient.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var rows int + err = db.QueryRow("SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows) + assert.Nil(t, err) + assert.Equal(t, 0, rows) + err = db.QueryRow("SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows) + assert.Nil(t, err) + assert.Equal(t, 0, rows) + }) }) t.Run("SignOut", func(t *testing.T) { t.Run("should fail if csrf token is not valid", func(t *testing.T) { diff --git a/service/auth.go b/service/auth.go index f168e11..b5e6b51 100644 --- a/service/auth.go +++ b/service/auth.go @@ -31,7 +31,7 @@ type Auth interface { SendVerificationMail(userId uuid.UUID, email string) VerifyUserEmail(token string) error - SignIn(email string, password string) (*types.Session, *types.User, error) + SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) SignInSession(sessionId string) (*types.Session, *types.User, error) SignInAnonymous() (*types.Session, error) SignOut(sessionId string) error @@ -65,7 +65,7 @@ func NewAuthImpl(db db.Auth, random Random, clock Clock, mail Mail, serverSettin } } -func (service AuthImpl) SignIn(email string, password string) (*types.Session, *types.User, error) { +func (service AuthImpl) SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) { user, err := service.db.GetUserByEmail(email) if err != nil { if errors.Is(err, db.ErrNotFound) { @@ -81,7 +81,12 @@ func (service AuthImpl) SignIn(email string, password string) (*types.Session, * return nil, nil, ErrInvalidCredentials } - session, err := service.createSession(user.Id) + err = service.cleanUpSessionWithTokens(session) + if err != nil { + return nil, nil, types.ErrInternal + } + + session, err = service.createSession(user.Id) if err != nil { return nil, nil, types.ErrInternal } @@ -89,6 +94,30 @@ func (service AuthImpl) SignIn(email string, password string) (*types.Session, * return session, user, nil } +func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error { + if session == nil { + return nil + } + + err := service.db.DeleteSession(session.Id) + if err != nil { + return types.ErrInternal + } + + tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf) + if err != nil { + return types.ErrInternal + } + for _, token := range tokens { + err = service.db.DeleteToken(token.Token) + if err != nil { + return types.ErrInternal + } + } + + return nil +} + func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) { if sessionId == "" { return nil, nil, ErrSessionIdInvalid diff --git a/service/auth_test.go b/service/auth_test.go index e3a5464..f12789d 100644 --- a/service/auth_test.go +++ b/service/auth_test.go @@ -47,7 +47,7 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - actualSession, actualUser, err := underTest.SignIn(user.Email, "password") + actualSession, actualUser, err := underTest.SignIn(nil, user.Email, "password") assert.Nil(t, err) assert.Equal(t, session, actualSession) @@ -78,7 +78,7 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - _, _, err := underTest.SignIn("test@test.de", "wrong password") + _, _, err := underTest.SignIn(nil, "test@test.de", "wrong password") assert.Equal(t, ErrInvalidCredentials, err) }) @@ -93,7 +93,7 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - _, _, err := underTest.SignIn("test", "test") + _, _, err := underTest.SignIn(nil, "test", "test") assert.Equal(t, ErrInvalidCredentials, err) }) t.Run("should forward ErrInternal on any other error", func(t *testing.T) { @@ -107,7 +107,7 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - _, _, err := underTest.SignIn("test", "test") + _, _, err := underTest.SignIn(nil, "test", "test") assert.Equal(t, types.ErrInternal, err) })