feat(security): #286 implement csrf middleware

This commit is contained in:
2024-12-04 23:15:40 +01:00
parent bbcdbf7a01
commit 57989c9b03
18 changed files with 484 additions and 204 deletions

View File

@@ -42,6 +42,7 @@ func NewUser(user *db.User) *User {
type Session struct {
Id string
CreatedAt time.Time
ExpiresAt time.Time
User *User
}
@@ -49,6 +50,7 @@ func NewSession(session *db.Session, user *User) *Session {
return &Session{
Id: session.Id,
CreatedAt: session.CreatedAt,
ExpiresAt: session.ExpiresAt,
User: user,
}
}
@@ -59,6 +61,7 @@ type Auth interface {
VerifyUserEmail(token string) error
SignIn(email string, password string) (*Session, error)
SignInSession(sessionId string) (*Session, error)
SignOut(sessionId string) error
DeleteAccount(user *User) error
@@ -68,7 +71,8 @@ type Auth interface {
SendForgotPasswordMail(email string) error
ForgotPassword(token string, newPass string) error
GetUserFromSessionId(sessionId string) (*User, error)
IsCsrfTokenValid(tokenStr string, sessionId string) bool
GetCsrfToken(session *Session) (string, error)
}
type AuthImpl struct {
@@ -113,6 +117,31 @@ func (service AuthImpl) SignIn(email string, password string) (*Session, error)
return NewSession(session, NewUser(user)), nil
}
func (service AuthImpl) SignInSession(sessionId string) (*Session, error) {
if sessionId == "" {
return nil, ErrSessionIdInvalid
}
sessionDb, err := service.db.GetSession(sessionId)
if err != nil {
return nil, types.ErrInternal
}
if sessionDb.ExpiresAt.After(service.clock.Now()) {
return nil, nil
}
userDb, err := service.db.GetUser(sessionDb.UserId)
if err != nil {
return nil, types.ErrInternal
}
user := NewUser(userDb)
session := NewSession(sessionDb, user)
return session, nil
}
func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) {
sessionId, err := service.random.String(32)
if err != nil {
@@ -125,7 +154,10 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) {
return nil, types.ErrInternal
}
session := db.NewSession(sessionId, userId, service.clock.Now())
createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour)
session := db.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session)
if err != nil {
@@ -161,7 +193,7 @@ func (service AuthImpl) SignUp(email string, password string) (*User, error) {
err = service.db.InsertUser(dbUser)
if err != nil {
if err == db.ErrUserExists {
if err == db.ErrAlreadyExists {
return nil, ErrAccountExists
} else {
return nil, types.ErrInternal
@@ -190,7 +222,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
return
}
token = db.NewToken(userId, newTokenStr, db.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour))
token = db.NewToken(userId, "", newTokenStr, db.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token)
if err != nil {
@@ -251,28 +283,6 @@ func (service AuthImpl) SignOut(sessionId string) error {
return service.db.DeleteSession(sessionId)
}
func (service AuthImpl) GetUserFromSessionId(sessionId string) (*User, error) {
if sessionId == "" {
return nil, ErrSessionIdInvalid
}
session, err := service.db.GetSession(sessionId)
if err != nil {
return nil, types.ErrInternal
}
user, err := service.db.GetUser(session.UserId)
if err != nil {
return nil, types.ErrInternal
}
if session.CreatedAt.Add(time.Duration(8 * time.Hour)).Before(service.clock.Now()) {
return nil, nil
} else {
return NewUser(user), nil
}
}
func (service AuthImpl) DeleteAccount(user *User) error {
err := service.db.DeleteUser(user.Id)
@@ -333,7 +343,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
}
}
token := db.NewToken(user.Id, tokenStr, db.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute))
token := db.NewToken(user.Id, "", tokenStr, db.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(token)
if err != nil {
@@ -384,6 +394,43 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
return nil
}
func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool {
token, err := service.db.GetToken(tokenStr)
if err != nil {
return false
}
if token.Type != db.TokenTypeCsrf ||
token.SessionId != sessionId ||
token.ExpiresAt.Before(service.clock.Now()) {
return false
}
return true
}
func (service AuthImpl) GetCsrfToken(session *Session) (string, error) {
tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, db.TokenTypeCsrf)
if len(tokens) > 0 {
return tokens[0].Token, nil
}
tokenStr, err := service.random.String(32)
if err != nil {
return "", types.ErrInternal
}
token := db.NewToken(uuid.Nil, session.Id, tokenStr, db.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token)
if err != nil {
return "", types.ErrInternal
}
return tokenStr, nil
}
func GetHashPassword(password string, salt []byte) []byte {
return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16)
}

View File

@@ -33,7 +33,7 @@ func TestSignIn(t *testing.T) {
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
)
dbSession := db.NewSession("sessionId", user.Id, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC))
dbSession := db.NewSession("sessionId", user.Id, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
mockAuthDb := mocks.NewMockAuth(t)
mockAuthDb.EXPECT().GetUserByEmail("test@test.de").Return(user, nil)
@@ -212,7 +212,7 @@ func TestSignUp(t *testing.T) {
mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(db.NewUser(user.Id, user.Email, false, nil, false, GetHashPassword(password, salt), salt, createTime)).Return(db.ErrUserExists)
mockAuthDb.EXPECT().InsertUser(db.NewUser(user.Id, user.Email, false, nil, false, GetHashPassword(password, salt), salt, createTime)).Return(db.ErrAlreadyExists)
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
@@ -227,7 +227,7 @@ func TestSendVerificationMail(t *testing.T) {
t.Run("should use stored token and send mail", func(t *testing.T) {
t.Parallel()
token := db.NewToken(uuid.New(), "someRandomTokenToUse", db.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
token := db.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", db.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
tokens := []*db.Token{token}
email := "some@email.de"