From bc24d5a705f3c96d5e71ad14372bd46f6a38ca09 Mon Sep 17 00:00:00 2001 From: Tim Wundenberg Date: Sat, 5 Oct 2024 13:49:43 +0200 Subject: [PATCH] fix: create RandomGenerator interface and struct for testing purpose #181 --- .mockery.yaml | 3 ++ handler/default.go | 3 +- service/auth.go | 62 +++++++++++++++++-------------------- service/auth_test.go | 48 +++++++++++++++++++++++++--- service/random_generator.go | 48 ++++++++++++++++++++++++++++ utils/ctypto.go | 16 ---------- 6 files changed, 125 insertions(+), 55 deletions(-) create mode 100644 service/random_generator.go delete mode 100644 utils/ctypto.go diff --git a/.mockery.yaml b/.mockery.yaml index 9adcc90..b6a04fb 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -2,6 +2,9 @@ with-expecter: True dir: mocks/ outpkg: mocks packages: + me-fit/service: + interfaces: + RandomGenerator: me-fit/db: interfaces: DbAuth: diff --git a/handler/default.go b/handler/default.go index af9eb74..c19cc5f 100644 --- a/handler/default.go +++ b/handler/default.go @@ -15,8 +15,9 @@ func GetHandler(d *sql.DB, serverSettings *types.ServerSettings) http.Handler { router.HandleFunc("/", service.HandleIndexAnd404(d, serverSettings)) + randomGenerator := service.NewRandomGeneratorImpl() dbAuth := db.NewDbAuthSqlite(d) - serviceAuth := service.NewServiceAuthImpl(dbAuth, serverSettings) + serviceAuth := service.NewServiceAuthImpl(dbAuth, randomGenerator, serverSettings) handlerAuth := NewHandlerAuth(d, serviceAuth, serverSettings) // Serve static files (CSS, JS and images) diff --git a/service/auth.go b/service/auth.go index 10edde2..99dcda5 100644 --- a/service/auth.go +++ b/service/auth.go @@ -2,7 +2,6 @@ package service import ( "context" - "crypto/rand" "crypto/subtle" "database/sql" "errors" @@ -26,10 +25,10 @@ import ( ) var ( - ErrInvaidCredentials = errors.New("Invalid email or password") - ErrPasswordComplexity = errors.New("Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character") - ErrInvalidEmail = errors.New("Invalid email") - ErrAccountExists = errors.New("Account already exists") + ErrInvaidCredentials = errors.New("Invalid email or password") + ErrInvalidPassword = errors.New("Password needs to be 8 characters long, contain at least one number, one special, one uppercase and one lowercase character") + ErrInvalidEmail = errors.New("Invalid email") + ErrAccountExists = errors.New("Account already exists") ) type User struct { @@ -53,16 +52,18 @@ type ServiceAuth interface { } type ServiceAuthImpl struct { - dbAuth db.DbAuth - serverSettings *types.ServerSettings - mailService MailService + dbAuth db.DbAuth + randomGenerator RandomGenerator + serverSettings *types.ServerSettings + mailService MailService } -func NewServiceAuthImpl(dbAuth db.DbAuth, serverSettings *types.ServerSettings) *ServiceAuthImpl { +func NewServiceAuthImpl(dbAuth db.DbAuth, randomGenerator RandomGenerator, serverSettings *types.ServerSettings) *ServiceAuthImpl { return &ServiceAuthImpl{ - dbAuth: dbAuth, - serverSettings: serverSettings, - mailService: NewMailService(serverSettings), + dbAuth: dbAuth, + randomGenerator: randomGenerator, + serverSettings: serverSettings, + mailService: NewMailService(serverSettings), } } @@ -91,9 +92,8 @@ func (service ServiceAuthImpl) SignUp(email string, password string) (*User, err return nil, ErrInvalidEmail } - err = checkPassword(password) - if err != nil { - return nil, err + if !isPasswordValid(password) { + return nil, ErrInvalidPassword } userId, err := uuid.NewRandom() @@ -102,10 +102,8 @@ func (service ServiceAuthImpl) SignUp(email string, password string) (*User, err return nil, types.ErrInternal } - salt := make([]byte, 16) - _, err = rand.Read(salt) + salt, err := service.randomGenerator.Bytes(16) if err != nil { - utils.LogError("Could not generate salt", err) return nil, types.ErrInternal } @@ -134,9 +132,8 @@ func (service ServiceAuthImpl) SendVerificationMail(user *User) { } if token == "" { - token, err := utils.RandomToken() + token, err := service.randomGenerator.String(32) if err != nil { - utils.LogError("Could not generate token", err) return } @@ -407,9 +404,8 @@ func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc { currPass := r.FormValue("current-password") newPass := r.FormValue("new-password") - err := checkPassword(newPass) - if err != nil { - utils.TriggerToast(w, r, "error", err.Error()) + if !isPasswordValid(newPass) { + utils.TriggerToast(w, r, "error", ErrInvalidPassword.Error()) return } @@ -423,7 +419,7 @@ func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc { salt []byte ) - err = db.QueryRow("SELECT password, salt FROM user WHERE user_uuid = ?", user.Id).Scan(&storedHash, &salt) + err := db.QueryRow("SELECT password, salt FROM user WHERE user_uuid = ?", user.Id).Scan(&storedHash, &salt) if err != nil { utils.LogError("Could not get password", err) utils.TriggerToast(w, r, "error", "Internal Server Error") @@ -467,9 +463,8 @@ func HandleActualResetPasswordComp(db *sql.DB) http.HandlerFunc { newPass := r.FormValue("new-password") - err = checkPassword(newPass) - if err != nil { - utils.TriggerToast(w, r, "error", err.Error()) + if !isPasswordValid(newPass) { + utils.TriggerToast(w, r, "error", ErrInvalidPassword.Error()) return } @@ -511,6 +506,7 @@ func HandleActualResetPasswordComp(db *sql.DB) http.HandlerFunc { utils.TriggerToast(w, r, "success", "Password changed") } } + func HandleResetPasswordComp(db *sql.DB, serverSettings *types.ServerSettings) http.HandlerFunc { mailService := NewMailService(serverSettings) return func(w http.ResponseWriter, r *http.Request) { @@ -521,9 +517,8 @@ func HandleResetPasswordComp(db *sql.DB, serverSettings *types.ServerSettings) h return } - token, err := utils.RandomToken() + token, err := NewRandomGeneratorImpl().String(32) if err != nil { - utils.LogError("Could not generate token", err) return } @@ -562,9 +557,8 @@ func HandleResetPasswordComp(db *sql.DB, serverSettings *types.ServerSettings) h } func TryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sql.DB, user_uuid uuid.UUID) error { - sessionId, err := utils.RandomToken() + sessionId, err := NewRandomGeneratorImpl().String(32) if err != nil { - utils.LogError("Could not generate session ID", err) return types.ErrInternal } @@ -599,15 +593,15 @@ func GetHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) } -func checkPassword(password string) error { +func isPasswordValid(password string) bool { if len(password) < 8 || !strings.ContainsAny(password, "0123456789") || !strings.ContainsAny(password, "ABCDEFGHIJKLMNOPQRSTUVWXYZ") || !strings.ContainsAny(password, "abcdefghijklmnopqrstuvwxyz") || !strings.ContainsAny(password, "!@#$%^&*()_+-=[]{}\\|;:'\",.<>/?") { - return ErrPasswordComplexity + return false } else { - return nil + return true } } diff --git a/service/auth_test.go b/service/auth_test.go index 691fb7e..2b2ebb2 100644 --- a/service/auth_test.go +++ b/service/auth_test.go @@ -32,8 +32,9 @@ func TestSignIn(t *testing.T) { mockDbAuth := mocks.NewMockDbAuth(t) mockDbAuth.EXPECT().GetUser("test@test.de").Return(user, nil) + mockRandom := mocks.NewMockRandomGenerator(t) - underTest := NewServiceAuthImpl(mockDbAuth, &types.ServerSettings{}) + underTest := NewServiceAuthImpl(mockDbAuth, mockRandom, &types.ServerSettings{}) actualUser, err := underTest.SignIn(user.Email, "password") assert.Nil(t, err) @@ -65,8 +66,9 @@ func TestSignIn(t *testing.T) { mockDbAuth := mocks.NewMockDbAuth(t) mockDbAuth.EXPECT().GetUser(user.Email).Return(user, nil) + mockRandom := mocks.NewMockRandomGenerator(t) - underTest := NewServiceAuthImpl(mockDbAuth, &types.ServerSettings{}) + underTest := NewServiceAuthImpl(mockDbAuth, mockRandom, &types.ServerSettings{}) _, err := underTest.SignIn("test@test.de", "wrong password") @@ -77,8 +79,9 @@ func TestSignIn(t *testing.T) { mockDbAuth := mocks.NewMockDbAuth(t) mockDbAuth.EXPECT().GetUser("test").Return(nil, db.ErrUserNotFound) + mockRandom := mocks.NewMockRandomGenerator(t) - underTest := NewServiceAuthImpl(mockDbAuth, &types.ServerSettings{}) + underTest := NewServiceAuthImpl(mockDbAuth, mockRandom, &types.ServerSettings{}) _, err := underTest.SignIn("test", "test") assert.Equal(t, ErrInvaidCredentials, err) @@ -88,11 +91,48 @@ func TestSignIn(t *testing.T) { mockDbAuth := mocks.NewMockDbAuth(t) mockDbAuth.EXPECT().GetUser("test").Return(nil, errors.New("Some undefined error")) + mockRandom := mocks.NewMockRandomGenerator(t) - underTest := NewServiceAuthImpl(mockDbAuth, &types.ServerSettings{}) + underTest := NewServiceAuthImpl(mockDbAuth, mockRandom, &types.ServerSettings{}) _, err := underTest.SignIn("test", "test") assert.Equal(t, types.ErrInternal, err) }) } + +func TestSignUp(t *testing.T) { + t.Parallel() + t.Run("should check for correct email address", func(t *testing.T) { + t.Parallel() + + mockDbAuth := mocks.NewMockDbAuth(t) + mockRandom := mocks.NewMockRandomGenerator(t) + + underTest := NewServiceAuthImpl(mockDbAuth, mockRandom, &types.ServerSettings{}) + + _, err := underTest.SignUp("invalid email address", "SomeStrongPassword123!") + + assert.Equal(t, ErrInvalidEmail, err) + }) + t.Run("should check for password complexity", func(t *testing.T) { + t.Parallel() + + mockDbAuth := mocks.NewMockDbAuth(t) + mockRandom := mocks.NewMockRandomGenerator(t) + + underTest := NewServiceAuthImpl(mockDbAuth, mockRandom, &types.ServerSettings{}) + + weakPasswords := []string{ + "123!ab", // too short + "no_upper_case_123", + "NO_LOWER_CASE_123", + "noSpecialChar123", + } + + for _, password := range weakPasswords { + _, err := underTest.SignUp("some@valid.email", password) + assert.Equal(t, ErrInvalidPassword, err) + } + }) +} diff --git a/service/random_generator.go b/service/random_generator.go new file mode 100644 index 0000000..4e4e4c5 --- /dev/null +++ b/service/random_generator.go @@ -0,0 +1,48 @@ +package service + +import ( + "me-fit/types" + + "crypto/rand" + "encoding/base64" + "log/slog" + + "github.com/google/uuid" +) + +type RandomGenerator interface { + Bytes(size int) ([]byte, error) + String(size int) (string, error) + UUID() (uuid.UUID, error) +} + +type RandomGeneratorImpl struct { +} + +func NewRandomGeneratorImpl() *RandomGeneratorImpl { + return &RandomGeneratorImpl{} +} + +func (r *RandomGeneratorImpl) Bytes(size int) ([]byte, error) { + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + slog.Error("Error generating random bytes: " + err.Error()) + return []byte{}, types.ErrInternal + } + + return b, nil +} + +func (r *RandomGeneratorImpl) String(size int) (string, error) { + bytes, err := r.Bytes(size) + if err != nil { + return "", types.ErrInternal + } + + return base64.StdEncoding.EncodeToString(bytes), nil +} + +func (r *RandomGeneratorImpl) UUID() (uuid.UUID, error) { + return uuid.NewRandom() +} diff --git a/utils/ctypto.go b/utils/ctypto.go deleted file mode 100644 index 4f3cffe..0000000 --- a/utils/ctypto.go +++ /dev/null @@ -1,16 +0,0 @@ -package utils - -import ( - "crypto/rand" - "encoding/base64" -) - -func RandomToken() (string, error) { - b := make([]byte, 32) - _, err := rand.Read(b) - if err != nil { - return "", err - } - - return base64.StdEncoding.EncodeToString(b), nil -}