From 04c6d0e71d06e090d6a3c9dff92ac9d24b06469e Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Fri, 6 Dec 2024 21:54:59 +0100 Subject: [PATCH] tbs --- db/auth.go | 75 +++++++++---- db/auth_test.go | 104 +++++++++++++++--- .../middleware/cross_site_request_forgery.go | 7 +- service/auth.go | 47 +++++++- service/auth_test.go | 4 +- 5 files changed, 192 insertions(+), 45 deletions(-) diff --git a/db/auth.go b/db/auth.go index 1ab40ad..579faad 100644 --- a/db/auth.go +++ b/db/auth.go @@ -13,8 +13,8 @@ import ( ) var ( - ErrNotFound = errors.New("value not found") - ErrUserExists = errors.New("user already exists") + ErrNotFound = errors.New("value not found") + ErrAlreadyExists = errors.New("row already exists") ) type User struct { @@ -59,20 +59,25 @@ func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time type Token struct { UserId uuid.UUID + SessionId string Token string - Type string + Type TokenType CreatedAt time.Time ExpiresAt time.Time } +type TokenType string + var ( - TokenTypeEmailVerify = "email_verify" - TokenTypePasswordReset = "password_reset" + TokenTypeEmailVerify TokenType = "email_verify" + TokenTypePasswordReset TokenType = "password_reset" + TokenTypeCsrf TokenType = "csrf" ) -func NewToken(userId uuid.UUID, token string, tokenType string, createdAt time.Time, expiresAt time.Time) *Token { +func NewToken(userId uuid.UUID, sessionId string, token string, tokenType TokenType, createdAt time.Time, expiresAt time.Time) *Token { return &Token{ UserId: userId, + SessionId: sessionId, Token: token, Type: tokenType, CreatedAt: createdAt, @@ -89,7 +94,8 @@ type Auth interface { InsertToken(token *Token) error GetToken(token string) (*Token, error) - GetTokensByUserIdAndType(userId uuid.UUID, tokenType string) ([]*Token, error) + GetTokensByUserIdAndType(userId uuid.UUID, tokenType TokenType) ([]*Token, error) + GetTokensBySessionIdAndType(sessionId string, tokenType TokenType) ([]*Token, error) DeleteToken(token string) error InsertSession(session *Session) error @@ -114,7 +120,7 @@ func (db AuthSqlite) InsertUser(user *User) error { if err != nil { if strings.Contains(err.Error(), "email") { - return ErrUserExists + return ErrAlreadyExists } log.Error("SQL error InsertUser: %v", err) @@ -208,7 +214,7 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { return types.ErrInternal } - _, err = tx.Exec("DELETE FROM user_token WHERE user_id = ?", userId) + _, err = tx.Exec("DELETE FROM token WHERE user_id = ?", userId) if err != nil { _ = tx.Rollback() log.Error("Could not delete user tokens: %v", err) @@ -240,8 +246,8 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { func (db AuthSqlite) InsertToken(token *Token) error { _, err := db.db.Exec(` - INSERT INTO user_token (user_id, type, token, created_at, expires_at) - VALUES (?, ?, ?, ?, ?)`, token.UserId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt) + INSERT INTO token (user_id, session_id, type, token, created_at, expires_at) + VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt) if err != nil { log.Error("Could not insert token: %v", err) @@ -254,7 +260,8 @@ func (db AuthSqlite) InsertToken(token *Token) error { func (db AuthSqlite) GetToken(token string) (*Token, error) { var ( userId uuid.UUID - tokenType string + sessionId string + tokenType TokenType createdAtStr string expiresAtStr string createdAt time.Time @@ -262,10 +269,9 @@ func (db AuthSqlite) GetToken(token string) (*Token, error) { ) err := db.db.QueryRow(` - SELECT user_id, type, created_at, expires_at - FROM user_token - WHERE token = ? - AND type = 'email_verify'`, token).Scan(&userId, &tokenType, &createdAtStr, &expiresAtStr) + SELECT user_id, session_id, type, created_at, expires_at + FROM token + WHERE token = ?`, token).Scan(&userId, &sessionId, &tokenType, &createdAtStr, &expiresAtStr) if err != nil { if err == sql.ErrNoRows { @@ -289,14 +295,14 @@ func (db AuthSqlite) GetToken(token string) (*Token, error) { return nil, types.ErrInternal } - return NewToken(userId, token, tokenType, createdAt, expiresAt), nil + return NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil } -func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType string) ([]*Token, error) { +func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType TokenType) ([]*Token, error) { query, err := db.db.Query(` SELECT token, created_at, expires_at - FROM user_token + FROM token WHERE user_id = ? AND type = ?`, userId, tokenType) @@ -305,9 +311,32 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType string return nil, types.ErrInternal } + return getTokensFromQuery(query, userId, "", tokenType) +} + +func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType TokenType) ([]*Token, error) { + + query, err := db.db.Query(` + SELECT token, created_at, expires_at + FROM token + WHERE session_id = ? + AND type = ?`, sessionId, tokenType) + + if err != nil { + log.Error("Could not get token: %v", err) + return nil, types.ErrInternal + } + + return getTokensFromQuery(query, uuid.Nil, sessionId, tokenType) +} + +func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tokenType TokenType) ([]*Token, error) { var tokens []*Token + hasRows := false for query.Next() { + hasRows = true + var ( token string createdAtStr string @@ -334,14 +363,18 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType string return nil, types.ErrInternal } - tokens = append(tokens, NewToken(userId, token, tokenType, createdAt, expiresAt)) + tokens = append(tokens, NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt)) + } + + if !hasRows { + return nil, ErrNotFound } return tokens, nil } func (db AuthSqlite) DeleteToken(token string) error { - _, err := db.db.Exec("DELETE FROM user_token WHERE token = ?", token) + _, err := db.db.Exec("DELETE FROM token WHERE token = ?", token) if err != nil { log.Error("Could not delete token: %v", err) return types.ErrInternal diff --git a/db/auth_test.go b/db/auth_test.go index 810712b..d096c38 100644 --- a/db/auth_test.go +++ b/db/auth_test.go @@ -51,7 +51,7 @@ func TestUser(t *testing.T) { assert.Nil(t, err) assert.Equal(t, expected, actual) }) - t.Run("should return UserNotFound", func(t *testing.T) { + t.Run("should return ErrNotFound", func(t *testing.T) { t.Parallel() db := setupDb(t) @@ -74,7 +74,7 @@ func TestUser(t *testing.T) { assert.Nil(t, err) err = underTest.InsertUser(user) - assert.Equal(t, ErrUserExists, err) + assert.Equal(t, ErrAlreadyExists, err) }) t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { t.Parallel() @@ -90,38 +90,110 @@ func TestUser(t *testing.T) { }) } -func TestEmailVerification(t *testing.T) { +func TestToken(t *testing.T) { t.Parallel() - t.Run("should return NotFound", func(t *testing.T) { + t.Run("should insert and get the same", func(t *testing.T) { t.Parallel() db := setupDb(t) underTest := AuthSqlite{db: db} - token, err := underTest.GetToken("someNonExistentToken") + createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) + expiresAt := createAt.Add(24 * time.Hour) + expected := NewToken(uuid.New(), "sessionId", "token", TokenTypeCsrf, createAt, expiresAt) - assert.Equal(t, ErrNotFound, err) - assert.Nil(t, token) + err := underTest.InsertToken(expected) + assert.Nil(t, err) + + actual, err := underTest.GetToken(expected.Token) + assert.Nil(t, err) + assert.Equal(t, expected, actual) + + expected.SessionId = "" + actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type) + assert.Nil(t, err) + assert.Equal(t, []*Token{expected}, actuals) + + expected.SessionId = "sessionId" + expected.UserId = uuid.Nil + actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type) + assert.Nil(t, err) + assert.Equal(t, []*Token{expected}, actuals) }) - t.Run("should insert and return token", func(t *testing.T) { + t.Run("should insert and return multiple tokens", func(t *testing.T) { t.Parallel() db := setupDb(t) underTest := AuthSqlite{db: db} - tokenStr := "some secure token" - createdAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) - expectedToken := NewToken(uuid.New(), tokenStr, TokenTypeEmailVerify, createdAt, createdAt.Add(24*time.Hour)) + createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) + expiresAt := createAt.Add(24 * time.Hour) + userId := uuid.New() + expected1 := NewToken(userId, "sessionId", "token1", TokenTypeCsrf, createAt, expiresAt) + expected2 := NewToken(userId, "sessionId", "token2", TokenTypeCsrf, createAt, expiresAt) - err := underTest.InsertToken(expectedToken) + err := underTest.InsertToken(expected1) + assert.Nil(t, err) + err = underTest.InsertToken(expected2) assert.Nil(t, err) - actualToken, err := underTest.GetToken(tokenStr) + expected1.UserId = uuid.Nil + expected2.UserId = uuid.Nil + actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type) + assert.Nil(t, err) + assert.Equal(t, []*Token{expected1, expected2}, actuals) + + expected1.SessionId = "" + expected2.SessionId = "" + expected1.UserId = userId + expected2.UserId = userId + actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type) + assert.Nil(t, err) + assert.Equal(t, []*Token{expected1, expected2}, actuals) + + }) + t.Run("should return ErrNotFound", func(t *testing.T) { + t.Parallel() + db := setupDb(t) + + underTest := AuthSqlite{db: db} + + _, err := underTest.GetToken("nonExistent") + assert.Equal(t, ErrNotFound, err) + + _, err = underTest.GetTokensByUserIdAndType(uuid.New(), TokenTypeEmailVerify) + assert.Equal(t, ErrNotFound, err) + + _, err = underTest.GetTokensBySessionIdAndType("sessionId", TokenTypeEmailVerify) + assert.Equal(t, ErrNotFound, err) + }) + t.Run("should return ErrAlreadyExists", func(t *testing.T) { + t.Parallel() + db := setupDb(t) + + underTest := AuthSqlite{db: db} + + verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) + createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) + user := NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + + err := underTest.InsertUser(user) assert.Nil(t, err) - t.Logf("expectedToken: %v", expectedToken) - t.Logf("actualToken: %v", actualToken) - assert.Equal(t, expectedToken, actualToken) + err = underTest.InsertUser(user) + assert.Equal(t, ErrAlreadyExists, err) + }) + t.Run("should return ErrInternal on missing NOT NULL fields", func(t *testing.T) { + t.Parallel() + db := setupDb(t) + + underTest := AuthSqlite{db: db} + + createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) + user := NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) + + err := underTest.InsertUser(user) + assert.Equal(t, types.ErrInternal, err) }) } diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 3eb9d41..cfb45e7 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -11,12 +11,17 @@ func CrossSiteRequestForgery(auth *service.Auth) func(http.Handler) http.Handler // session := r.Context().Value(SessionKey) - if r.Method == "POST" { + if r.Method == http.MethodPost || + r.Method == http.MethodPut || + r.Method == http.MethodDelete || + r.Method == http.MethodPatch { + csrfToken := r.FormValue("csrf-token") if csrfToken == "" { http.Error(w, "", http.StatusForbidden) return } + } next.ServeHTTP(w, r) diff --git a/service/auth.go b/service/auth.go index eb565d9..5f71614 100644 --- a/service/auth.go +++ b/service/auth.go @@ -71,8 +71,8 @@ type Auth interface { SendForgotPasswordMail(email string) error ForgotPassword(token string, newPass string) error - // IsCsrfTokenValid(token string, user *User) bool - // GetCsrfToken(token string, user *User) bool + IsCsrfTokenValid(tokenStr string, userId uuid.UUID) bool + GetCsrfToken(session *Session) (string, error) } type AuthImpl struct { @@ -193,7 +193,7 @@ func (service AuthImpl) SignUp(email string, password string) (*User, error) { err = service.db.InsertUser(dbUser) if err != nil { - if err == db.ErrUserExists { + if err == db.ErrAlreadyExists { return nil, ErrAccountExists } else { return nil, types.ErrInternal @@ -222,7 +222,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { return } - token = db.NewToken(userId, newTokenStr, db.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) + token = db.NewToken(userId, "", newTokenStr, db.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) err = service.db.InsertToken(token) if err != nil { @@ -343,7 +343,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error { } } - token := db.NewToken(user.Id, tokenStr, db.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute)) + token := db.NewToken(user.Id, "", tokenStr, db.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute)) err = service.db.InsertToken(token) if err != nil { @@ -394,6 +394,43 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { return nil } +func (service AuthImpl) IsCsrfTokenValid(tokenStr string, userId uuid.UUID) bool { + token, err := service.db.GetToken(tokenStr) + if err != nil { + return false + } + + if token.Type != db.TokenTypeCsrf || + token.UserId != userId || + token.ExpiresAt.Before(service.clock.Now()) { + + return false + } + + return true +} + +func (service AuthImpl) GetCsrfToken(session *Session) (string, error) { + tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, db.TokenTypeCsrf) + + if len(tokens) > 0 { + return tokens[0].Token, nil + } + + tokenStr, err := service.random.String(32) + if err != nil { + return "", types.ErrInternal + } + + token := db.NewToken(uuid.Nil, session.Id, tokenStr, db.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) + err = service.db.InsertToken(token) + if err != nil { + return "", types.ErrInternal + } + + return tokenStr, nil +} + func GetHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) } diff --git a/service/auth_test.go b/service/auth_test.go index ddb2f33..96ac697 100644 --- a/service/auth_test.go +++ b/service/auth_test.go @@ -212,7 +212,7 @@ func TestSignUp(t *testing.T) { mockClock.EXPECT().Now().Return(createTime) - mockAuthDb.EXPECT().InsertUser(db.NewUser(user.Id, user.Email, false, nil, false, GetHashPassword(password, salt), salt, createTime)).Return(db.ErrUserExists) + mockAuthDb.EXPECT().InsertUser(db.NewUser(user.Id, user.Email, false, nil, false, GetHashPassword(password, salt), salt, createTime)).Return(db.ErrAlreadyExists) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) @@ -227,7 +227,7 @@ func TestSendVerificationMail(t *testing.T) { t.Run("should use stored token and send mail", func(t *testing.T) { t.Parallel() - token := db.NewToken(uuid.New(), "someRandomTokenToUse", db.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) + token := db.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", db.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) tokens := []*db.Token{token} email := "some@email.de"