diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..fb24b7d --- /dev/null +++ b/auth_test.go @@ -0,0 +1,136 @@ +package main + +import ( + "me-fit/service" + "me-fit/utils" + + "context" + "database/sql" + "fmt" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/google/uuid" +) + +func TestHandleSignIn(t *testing.T) { + t.Parallel() + t.Run("should signIn and return session cookie", func(t *testing.T) { + t.Parallel() + ctx, done := context.WithCancel(context.Background()) + t.Cleanup(done) + + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatalf("Could not open Database data.db: %v", err) + } + t.Cleanup(func() { + db.Close() + }) + err = utils.RunMigrations(db, "") + if err != nil { + t.Fatalf("Could not run migrations: %v", err) + } + + pass := service.GetHashPassword("password", []byte("salt")) + _, err = db.Exec(` + INSERT INTO user (user_uuid, email, email_verified, is_admin, password, salt, created_at) + VALUES (?, "mail@mail.de", FALSE, FALSE, ?, ?, datetime())`, uuid.New(), pass, []byte("salt")) + if err != nil { + t.Fatalf("Error inserting user: %v", err) + } + + go run(ctx, db, func(key string) string { + if key == "PORT" { + return "8080" + } else if key == "SMTP_ENABLED" { + return "false" + } else if key == "PROMETHEUS_ENABLED" { + return "false" + } else if key == "BASE_URL" { + return "https://localhost:8080" + } else if key == "ENVIRONMENT" { + return "test" + } else { + return "" + } + }) + + err = waitForReady(ctx, 5*time.Second, "http://localhost:8080") + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + formData := url.Values{ + "email": {"mail@mail.de"}, + "password": {"password"}, + } + + req, err := http.NewRequestWithContext(ctx, "POST", "http://localhost:8080/api/auth/signin", strings.NewReader(formData.Encode())) + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Set the content type to application/x-www-form-urlencoded + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Error making request: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("Expected status code 200, got %d", resp.StatusCode) + } + + }) +} + +// waitForReady calls the specified endpoint until it gets a 200 +// response or until the context is cancelled or the timeout is +// reached. +func waitForReady( + ctx context.Context, + timeout time.Duration, + endpoint string, +) error { + client := http.Client{} + startTime := time.Now() + for { + req, err := http.NewRequestWithContext( + ctx, + http.MethodGet, + endpoint, + nil, + ) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + fmt.Printf("Error making request: %s\n", err.Error()) + continue + } + if resp.StatusCode == http.StatusOK { + fmt.Println("Endpoint is ready!") + resp.Body.Close() + return nil + } + resp.Body.Close() + + select { + case <-ctx.Done(): + return ctx.Err() + default: + if time.Since(startTime) >= timeout { + return fmt.Errorf("timeout reached while waiting for endpoint") + } + // wait a little while between checks + time.Sleep(250 * time.Millisecond) + } + } +} diff --git a/db/auth.go b/db/auth.go index c557a7d..1789dbb 100644 --- a/db/auth.go +++ b/db/auth.go @@ -18,14 +18,14 @@ type User struct { Id uuid.UUID Email string EmailVerified bool - EmailVerifiedAt time.Time + EmailVerifiedAt *time.Time IsAdmin bool Password []byte Salt []byte CreateAt time.Time } -func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User { +func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User { return &User{ Id: id, Email: email, @@ -54,7 +54,7 @@ func (db DbAuthSqlite) GetUser(email string) (*User, error) { var ( userId uuid.UUID emailVerified bool - emailVerifiedAt time.Time + emailVerifiedAt *time.Time isAdmin bool password []byte salt []byte diff --git a/db/auth_test.go b/db/auth_test.go index bc2729e..8cb34d4 100644 --- a/db/auth_test.go +++ b/db/auth_test.go @@ -16,7 +16,10 @@ func setupDb(t *testing.T) *sql.DB { t.Fatalf("Error opening database: %v", err) } - utils.MustRunMigrations(db, "../") + err = utils.RunMigrations(db, "../") + if err != nil { + t.Fatalf("Error running migrations: %v", err) + } return db } @@ -46,7 +49,7 @@ func TestGetUser(t *testing.T) { verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := NewUser(uuid.New(), "some@email.de", true, verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + user := NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) _, err := db.Exec(` INSERT INTO user (user_uuid, email, email_verified, email_verified_at, is_admin, password, salt, created_at) diff --git a/handler/auth_test.go b/handler/auth_test.go deleted file mode 100644 index 921902a..0000000 --- a/handler/auth_test.go +++ /dev/null @@ -1,11 +0,0 @@ -package handler - -import ( - "testing" -) - -func TestHandleSignIn(t *testing.T) { - t.Parallel() - t.Run("should signIn and return session cookie", func(t *testing.T) { - }) -} diff --git a/main.go b/main.go index ecaae18..4a15f3d 100644 --- a/main.go +++ b/main.go @@ -27,10 +27,16 @@ func main() { log.Fatal("Error loading .env file") } - run(context.Background(), os.Getenv) + db, err := sql.Open("sqlite3", "./data.db") + if err != nil { + log.Fatal("Could not open Database data.db: ", err) + } + defer db.Close() + + run(context.Background(), db, os.Getenv) } -func run(ctx context.Context, env func(string) string) { +func run(ctx context.Context, db *sql.DB, env func(string) string) { ctx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) defer cancel() @@ -40,15 +46,13 @@ func run(ctx context.Context, env func(string) string) { serverSettings := types.NewServerSettingsFromEnv(env) // init db - db, err := sql.Open("sqlite3", serverSettings.DbPath) + err := utils.RunMigrations(db, "") if err != nil { - log.Fatal("Could not open Database data.db: ", err) + slog.Error("Could not run migrations: " + err.Error()) + os.Exit(1) } - defer db.Close() - utils.MustRunMigrations(db, "") // init servers - var prometheusServer *http.Server if serverSettings.PrometheusEnabled { prometheusServer := &http.Server{ diff --git a/service/auth.go b/service/auth.go index 24edf01..655b492 100644 --- a/service/auth.go +++ b/service/auth.go @@ -68,7 +68,7 @@ func (service ServiceAuthImpl) SignIn(email string, password string) (*User, err } } - hash := getHashPassword(password, user.Salt) + hash := GetHashPassword(password, user.Salt) if subtle.ConstantTimeCompare(hash, user.Password) == 0 { return nil, ErrInvaidCredentials @@ -279,7 +279,7 @@ func HandleSignUpComp(db *sql.DB, serverSettings *types.ServerSettings) http.Han return } - hash := getHashPassword(password, salt) + hash := GetHashPassword(password, salt) _, err = db.Exec("INSERT INTO user (user_uuid, email, email_verified, is_admin, password, salt, created_at) VALUES (?, ?, FALSE, FALSE, ?, ?, datetime())", userId, email, hash, salt) if err != nil { @@ -366,7 +366,7 @@ func HandleDeleteAccountComp(db *sql.DB, serverSettings *types.ServerSettings) h return } - currHash := getHashPassword(password, salt) + currHash := GetHashPassword(password, salt) if subtle.ConstantTimeCompare(currHash, storedHash) == 0 { utils.TriggerToast(w, r, "error", "Password is not correct") return @@ -455,13 +455,13 @@ func HandleChangePasswordComp(db *sql.DB) http.HandlerFunc { return } - currHash := getHashPassword(currPass, salt) + currHash := GetHashPassword(currPass, salt) if subtle.ConstantTimeCompare(currHash, storedHash) == 0 { utils.TriggerToast(w, r, "error", "Current Password is not correct") return } - newHash := getHashPassword(newPass, salt) + newHash := GetHashPassword(newPass, salt) _, err = db.Exec("UPDATE user SET password = ? WHERE user_uuid = ?", newHash, user.Id) if err != nil { @@ -524,7 +524,7 @@ func HandleActualResetPasswordComp(db *sql.DB) http.HandlerFunc { return } - passHash := getHashPassword(newPass, salt) + passHash := GetHashPassword(newPass, salt) _, err = db.Exec("UPDATE user SET password = ? WHERE user_uuid = ?", passHash, userId) if err != nil { @@ -653,7 +653,7 @@ func TryCreateSessionAndSetCookie(r *http.Request, w http.ResponseWriter, db *sq return nil } -func getHashPassword(password string, salt []byte) []byte { +func GetHashPassword(password string, salt []byte) []byte { return argon2.IDKey([]byte(password), salt, 1, 64*1024, 1, 16) } diff --git a/service/auth_test.go b/service/auth_test.go index 295ada2..1ee4250 100644 --- a/service/auth_test.go +++ b/service/auth_test.go @@ -25,14 +25,15 @@ func TestSignIn(t *testing.T) { t.Run("should return user if password is correct", func(t *testing.T) { t.Parallel() salt := []byte("salt") + verifiedAt := time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC) stub := DbAuthStub{ user: db.NewUser( uuid.New(), "test@test.de", true, - time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), + &verifiedAt, false, - getHashPassword("password", salt), + GetHashPassword("password", salt), salt, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), ), @@ -58,14 +59,15 @@ func TestSignIn(t *testing.T) { t.Run("should return ErrInvalidCretentials if password is not correct", func(t *testing.T) { t.Parallel() salt := []byte("salt") + verifiedAt := time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC) stub := DbAuthStub{ user: db.NewUser( uuid.New(), "test@test.de", true, - time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC), + &verifiedAt, false, - getHashPassword("password", salt), + GetHashPassword("password", salt), salt, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), ), diff --git a/types/server_settings.go b/types/server_settings.go index 0ca858f..812a5a0 100644 --- a/types/server_settings.go +++ b/types/server_settings.go @@ -11,7 +11,6 @@ type ServerSettings struct { BaseUrl string Environment string - DbPath string Smtp *SmtpSettings } @@ -61,15 +60,10 @@ func NewServerSettingsFromEnv(env func(string) string) *ServerSettings { Port: env("PORT"), PrometheusEnabled: env("PROMETHEUS_ENABLED") == "true", BaseUrl: env("BASE_URL"), - DbPath: env("DB_PATH"), Environment: env("ENVIRONMENT"), Smtp: smtp, } - if settings.DbPath == "" { - settings.DbPath = "./data.db" - } - if settings.BaseUrl == "" { log.Fatal("BASE_URL must be set") } diff --git a/utils/db.go b/utils/db.go index 7ff44e4..5cac76d 100644 --- a/utils/db.go +++ b/utils/db.go @@ -2,17 +2,20 @@ package utils import ( "database/sql" - "log" + "errors" + "log/slog" + "me-fit/types" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/golang-migrate/migrate/v4/source/file" ) -func MustRunMigrations(db *sql.DB, pathPrefix string) { +func RunMigrations(db *sql.DB, pathPrefix string) error { driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { - log.Fatal(err) + slog.Error("Could not create Migration instance: " + err.Error()) + return types.ErrInternal } m, err := migrate.NewWithDatabaseInstance( @@ -20,13 +23,17 @@ func MustRunMigrations(db *sql.DB, pathPrefix string) { "", driver) if err != nil { - log.Fatal("Could not create migrations instance: ", err) + slog.Error("Could not create migrations instance: " + err.Error()) + return types.ErrInternal } err = m.Up() if err != nil { - if err.Error() != "no change" { - log.Fatal("Could not run migrations: ", err) + if !errors.Is(err, migrate.ErrNoChange) { + slog.Error("Could not run migrations: " + err.Error()) + return types.ErrInternal } } + + return nil }