diff --git a/handler/default.go b/handler/default.go index 34a1cf7..3810adc 100644 --- a/handler/default.go +++ b/handler/default.go @@ -1,6 +1,7 @@ package handler import ( + "me-fit/db" "me-fit/middleware" "me-fit/service" @@ -8,17 +9,17 @@ import ( "net/http" ) -func GetHandler(db *sql.DB) http.Handler { +func GetHandler(d *sql.DB) http.Handler { var router = http.NewServeMux() - router.HandleFunc("/", service.HandleIndexAnd404(db)) + router.HandleFunc("/", service.HandleIndexAnd404(d)) - handlerAuth := NewHandlerAuth(db, service.NewServiceAuthImpl(db)) + handlerAuth := NewHandlerAuth(d, service.NewServiceAuthImpl(db.NewDbAuthSqlite(d))) // Serve static files (CSS, JS and images) router.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("./static/")))) - handleWorkout(db, router) + handleWorkout(d, router) handlerAuth.handle(router) diff --git a/service/auth.go b/service/auth.go index 4340c1f..e8b2eaa 100644 --- a/service/auth.go +++ b/service/auth.go @@ -51,9 +51,9 @@ type ServiceAuthImpl struct { dbAuth db.DbAuth } -func NewServiceAuthImpl(d *sql.DB) *ServiceAuthImpl { +func NewServiceAuthImpl(dbAuth db.DbAuth) *ServiceAuthImpl { return &ServiceAuthImpl{ - dbAuth: db.NewDbAuthSqlite(d), + dbAuth: dbAuth, } } @@ -64,7 +64,7 @@ func (service ServiceAuthImpl) SignIn(email string, password string) (*User, err if errors.Is(err, db.ErrUserNotFound) { return nil, ErrInvaidCredentials } else { - return nil, err + return nil, types.ErrInternal } } diff --git a/service/auth_test.go b/service/auth_test.go new file mode 100644 index 0000000..295ada2 --- /dev/null +++ b/service/auth_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "me-fit/db" + "me-fit/types" + + "errors" + "testing" + "time" + + "github.com/google/uuid" +) + +type DbAuthStub struct { + user *db.User + err error +} + +func (d DbAuthStub) GetUser(email string) (*db.User, error) { + return d.user, d.err +} + +func TestSignIn(t *testing.T) { + t.Parallel() + t.Run("should return user if password is correct", func(t *testing.T) { + t.Parallel() + salt := []byte("salt") + stub := DbAuthStub{ + user: db.NewUser( + uuid.New(), + "test@test.de", + true, + time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), + false, + getHashPassword("password", salt), + salt, + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + ), + err: nil, + } + underTest := NewServiceAuthImpl(stub) + + actualUser, err := underTest.SignIn("test@test.de", "password") + if err != nil { + t.Errorf("Expected nil, got %v", err) + } + + expectedUser := User{ + Id: stub.user.Id, + Email: stub.user.Email, + EmailVerified: stub.user.EmailVerified, + } + if *actualUser != expectedUser { + t.Errorf("Expected %v, got %v", expectedUser, actualUser) + } + }) + + t.Run("should return ErrInvalidCretentials if password is not correct", func(t *testing.T) { + t.Parallel() + salt := []byte("salt") + stub := DbAuthStub{ + user: db.NewUser( + uuid.New(), + "test@test.de", + true, + time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), + false, + getHashPassword("password", salt), + salt, + time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + ), + err: nil, + } + underTest := NewServiceAuthImpl(stub) + + _, err := underTest.SignIn("test@test.de", "wrong password") + if err != ErrInvaidCredentials { + t.Errorf("Expected %v, got %v", ErrInvaidCredentials, err) + } + }) + t.Run("should return ErrInvalidCretentials if user has not been found", func(t *testing.T) { + t.Parallel() + stub := DbAuthStub{ + user: nil, + err: db.ErrUserNotFound, + } + underTest := NewServiceAuthImpl(stub) + + _, err := underTest.SignIn("test", "test") + if err != ErrInvaidCredentials { + t.Errorf("Expected %v, got %v", ErrInvaidCredentials, err) + } + }) + t.Run("should forward ErrInternal on any other error", func(t *testing.T) { + t.Parallel() + stub := DbAuthStub{ + user: nil, + err: errors.New("Some error"), + } + underTest := NewServiceAuthImpl(stub) + + _, err := underTest.SignIn("test", "test") + if err != types.ErrInternal { + t.Errorf("Expected %v, got %v", types.ErrInternal, err) + } + }) +}