chore(auth): #331 add tests for sign in
This commit is contained in:
@@ -397,7 +397,6 @@ func (db AuthSqlite) DeleteOldSessions(userId uuid.UUID) error {
|
|||||||
|
|
||||||
func (db AuthSqlite) DeleteSession(sessionId string) error {
|
func (db AuthSqlite) DeleteSession(sessionId string) error {
|
||||||
if sessionId != "" {
|
if sessionId != "" {
|
||||||
|
|
||||||
_, err := db.db.Exec("DELETE FROM session WHERE session_id = ?", sessionId)
|
_, err := db.db.Exec("DELETE FROM session WHERE session_id = ?", sessionId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Could not delete session: %v", err)
|
log.Error("Could not delete session: %v", err)
|
||||||
|
|||||||
@@ -77,11 +77,13 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
|
|||||||
|
|
||||||
func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
func (handler AuthImpl) handleSignIn() http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) {
|
|
||||||
var email = r.FormValue("email")
|
|
||||||
var password = r.FormValue("password")
|
|
||||||
|
|
||||||
session, user, err := handler.service.SignIn(email, password)
|
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) {
|
||||||
|
session := middleware.GetSession(r)
|
||||||
|
email := r.FormValue("email")
|
||||||
|
password := r.FormValue("password")
|
||||||
|
|
||||||
|
session, user, err := handler.service.SignIn(session, email, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
160
main_test.go
160
main_test.go
@@ -294,6 +294,166 @@ func TestIntegrationAuth(t *testing.T) {
|
|||||||
|
|
||||||
assert.NotEqual(t, anonymousSession.Value, cookie.Value, "Session ID did not change")
|
assert.NotEqual(t, anonymousSession.Value, cookie.Value, "Session ID did not change")
|
||||||
})
|
})
|
||||||
|
t.Run("should return in ~250 ms in all cases", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
_, err := db.Exec(`
|
||||||
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
// Everythings correct
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
body, err := html.Parse(resp.Body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
anonymousCsrfToken := findCsrfToken(body)
|
||||||
|
assert.NotEqual(t, "", anonymousCsrfToken)
|
||||||
|
anonymousSession := findCookie(resp, "id")
|
||||||
|
assert.NotNil(t, anonymousSession)
|
||||||
|
|
||||||
|
formData := url.Values{
|
||||||
|
"email": {"mail@mail.de"},
|
||||||
|
"password": {"password"},
|
||||||
|
"csrf-token": {anonymousCsrfToken},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+anonymousSession.Value)
|
||||||
|
req.Header.Set("HX-Request", "true")
|
||||||
|
|
||||||
|
timeStart := time.Now()
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
timeEnd := time.Now()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
if timeEnd.Sub(timeStart) > 253*time.Millisecond || timeEnd.Sub(timeStart) <= 250*time.Millisecond {
|
||||||
|
t.Fail()
|
||||||
|
t.Logf("Time did not match: %v", timeEnd.Sub(timeStart))
|
||||||
|
}
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
//Wrong password
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
body, err = html.Parse(resp.Body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
anonymousCsrfToken = findCsrfToken(body)
|
||||||
|
assert.NotEqual(t, "", anonymousCsrfToken)
|
||||||
|
anonymousSession = findCookie(resp, "id")
|
||||||
|
assert.NotNil(t, anonymousSession)
|
||||||
|
|
||||||
|
formData = url.Values{
|
||||||
|
"email": {"mail@mail.de"},
|
||||||
|
"password": {"wrong-password"},
|
||||||
|
"csrf-token": {anonymousCsrfToken},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+anonymousSession.Value)
|
||||||
|
req.Header.Set("HX-Request", "true")
|
||||||
|
|
||||||
|
timeStart = time.Now()
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
timeEnd = time.Now()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
if timeEnd.Sub(timeStart) > 253*time.Millisecond || timeEnd.Sub(timeStart) <= 250*time.Millisecond {
|
||||||
|
t.Fail()
|
||||||
|
t.Logf("Time did not match: %v", timeEnd.Sub(timeStart))
|
||||||
|
}
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
|
||||||
|
//Wrong username
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
body, err = html.Parse(resp.Body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
anonymousCsrfToken = findCsrfToken(body)
|
||||||
|
assert.NotEqual(t, "", anonymousCsrfToken)
|
||||||
|
anonymousSession = findCookie(resp, "id")
|
||||||
|
assert.NotNil(t, anonymousSession)
|
||||||
|
formData = url.Values{
|
||||||
|
|
||||||
|
"email": {"invalid-mail@mail.de"},
|
||||||
|
"password": {"password"},
|
||||||
|
"csrf-token": {anonymousCsrfToken},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+anonymousSession.Value)
|
||||||
|
req.Header.Set("HX-Request", "true")
|
||||||
|
|
||||||
|
timeStart = time.Now()
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
timeEnd = time.Now()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
if timeEnd.Sub(timeStart) > 253*time.Millisecond || timeEnd.Sub(timeStart) <= 250*time.Millisecond {
|
||||||
|
t.Fail()
|
||||||
|
t.Logf("Time did not match: %v", timeEnd.Sub(timeStart))
|
||||||
|
}
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
})
|
||||||
|
t.Run("should create new session and invalidate old one (session fixation prevention)", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, basePath, ctx := setupIntegrationTest(t)
|
||||||
|
|
||||||
|
pass := service.GetHashPassword("password", []byte("salt"))
|
||||||
|
_, err := db.Exec(`
|
||||||
|
INSERT INTO user (user_id, email, email_verified, is_admin, password, salt, created_at)
|
||||||
|
VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", basePath+"/auth/signin", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
html, err := html.Parse(resp.Body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
anonymousCsrfToken := findCsrfToken(html)
|
||||||
|
assert.NotEqual(t, "", anonymousCsrfToken)
|
||||||
|
anonymousSession := findCookie(resp, "id")
|
||||||
|
assert.NotNil(t, anonymousSession)
|
||||||
|
|
||||||
|
formData := url.Values{
|
||||||
|
"email": {"mail@mail.de"},
|
||||||
|
"password": {"password"},
|
||||||
|
"csrf-token": {anonymousCsrfToken},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "POST", basePath+"/api/auth/signin", strings.NewReader(formData.Encode()))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Cookie", "id="+anonymousSession.Value)
|
||||||
|
req.Header.Set("HX-Request", "true")
|
||||||
|
|
||||||
|
resp, err = httpClient.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var rows int
|
||||||
|
err = db.QueryRow("SELECT COUNT(*) FROM session WHERE session_id = ?", anonymousSession.Value).Scan(&rows)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, 0, rows)
|
||||||
|
err = db.QueryRow("SELECT COUNT(*) FROM token WHERE token = ?", anonymousCsrfToken).Scan(&rows)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, 0, rows)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
t.Run("SignOut", func(t *testing.T) {
|
t.Run("SignOut", func(t *testing.T) {
|
||||||
t.Run("should fail if csrf token is not valid", func(t *testing.T) {
|
t.Run("should fail if csrf token is not valid", func(t *testing.T) {
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ type Auth interface {
|
|||||||
SendVerificationMail(userId uuid.UUID, email string)
|
SendVerificationMail(userId uuid.UUID, email string)
|
||||||
VerifyUserEmail(token string) error
|
VerifyUserEmail(token string) error
|
||||||
|
|
||||||
SignIn(email string, password string) (*types.Session, *types.User, error)
|
SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error)
|
||||||
SignInSession(sessionId string) (*types.Session, *types.User, error)
|
SignInSession(sessionId string) (*types.Session, *types.User, error)
|
||||||
SignInAnonymous() (*types.Session, error)
|
SignInAnonymous() (*types.Session, error)
|
||||||
SignOut(sessionId string) error
|
SignOut(sessionId string) error
|
||||||
@@ -65,7 +65,7 @@ func NewAuthImpl(db db.Auth, random Random, clock Clock, mail Mail, serverSettin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignIn(email string, password string) (*types.Session, *types.User, error) {
|
func (service AuthImpl) SignIn(session *types.Session, email string, password string) (*types.Session, *types.User, error) {
|
||||||
user, err := service.db.GetUserByEmail(email)
|
user, err := service.db.GetUserByEmail(email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, db.ErrNotFound) {
|
if errors.Is(err, db.ErrNotFound) {
|
||||||
@@ -81,7 +81,12 @@ func (service AuthImpl) SignIn(email string, password string) (*types.Session, *
|
|||||||
return nil, nil, ErrInvalidCredentials
|
return nil, nil, ErrInvalidCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := service.createSession(user.Id)
|
err = service.cleanUpSessionWithTokens(session)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, types.ErrInternal
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err = service.createSession(user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, types.ErrInternal
|
return nil, nil, types.ErrInternal
|
||||||
}
|
}
|
||||||
@@ -89,6 +94,30 @@ func (service AuthImpl) SignIn(email string, password string) (*types.Session, *
|
|||||||
return session, user, nil
|
return session, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (service AuthImpl) cleanUpSessionWithTokens(session *types.Session) error {
|
||||||
|
if session == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := service.db.DeleteSession(session.Id)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
for _, token := range tokens {
|
||||||
|
err = service.db.DeleteToken(token.Token)
|
||||||
|
if err != nil {
|
||||||
|
return types.ErrInternal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
|
func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
|
||||||
if sessionId == "" {
|
if sessionId == "" {
|
||||||
return nil, nil, ErrSessionIdInvalid
|
return nil, nil, ErrSessionIdInvalid
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ func TestSignIn(t *testing.T) {
|
|||||||
|
|
||||||
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
||||||
|
|
||||||
actualSession, actualUser, err := underTest.SignIn(user.Email, "password")
|
actualSession, actualUser, err := underTest.SignIn(nil, user.Email, "password")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
assert.Equal(t, session, actualSession)
|
assert.Equal(t, session, actualSession)
|
||||||
@@ -78,7 +78,7 @@ func TestSignIn(t *testing.T) {
|
|||||||
|
|
||||||
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
||||||
|
|
||||||
_, _, err := underTest.SignIn("test@test.de", "wrong password")
|
_, _, err := underTest.SignIn(nil, "test@test.de", "wrong password")
|
||||||
|
|
||||||
assert.Equal(t, ErrInvalidCredentials, err)
|
assert.Equal(t, ErrInvalidCredentials, err)
|
||||||
})
|
})
|
||||||
@@ -93,7 +93,7 @@ func TestSignIn(t *testing.T) {
|
|||||||
|
|
||||||
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
||||||
|
|
||||||
_, _, err := underTest.SignIn("test", "test")
|
_, _, err := underTest.SignIn(nil, "test", "test")
|
||||||
assert.Equal(t, ErrInvalidCredentials, err)
|
assert.Equal(t, ErrInvalidCredentials, err)
|
||||||
})
|
})
|
||||||
t.Run("should forward ErrInternal on any other error", func(t *testing.T) {
|
t.Run("should forward ErrInternal on any other error", func(t *testing.T) {
|
||||||
@@ -107,7 +107,7 @@ func TestSignIn(t *testing.T) {
|
|||||||
|
|
||||||
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
|
||||||
|
|
||||||
_, _, err := underTest.SignIn("test", "test")
|
_, _, err := underTest.SignIn(nil, "test", "test")
|
||||||
|
|
||||||
assert.Equal(t, types.ErrInternal, err)
|
assert.Equal(t, types.ErrInternal, err)
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user