diff --git a/db/auth.go b/db/auth.go index 189a43e..b2bb846 100644 --- a/db/auth.go +++ b/db/auth.go @@ -17,90 +17,22 @@ var ( ErrAlreadyExists = errors.New("row already exists") ) -type User struct { - Id uuid.UUID - Email string - EmailVerified bool - EmailVerifiedAt *time.Time - IsAdmin bool - Password []byte - Salt []byte - CreateAt time.Time -} - -func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User { - return &User{ - Id: id, - Email: email, - EmailVerified: emailVerified, - EmailVerifiedAt: emailVerifiedAt, - IsAdmin: isAdmin, - Password: password, - Salt: salt, - CreateAt: createAt, - } -} - -type Session struct { - Id string - UserId uuid.UUID - CreatedAt time.Time - ExpiresAt time.Time -} - -func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session { - return &Session{ - Id: id, - UserId: userId, - CreatedAt: createdAt, - ExpiresAt: expiresAt, - } -} - -type Token struct { - UserId uuid.UUID - SessionId string - Token string - Type TokenType - CreatedAt time.Time - ExpiresAt time.Time -} - -type TokenType string - -var ( - TokenTypeEmailVerify TokenType = "email_verify" - TokenTypePasswordReset TokenType = "password_reset" - TokenTypeCsrf TokenType = "csrf" -) - -func NewToken(userId uuid.UUID, sessionId string, token string, tokenType TokenType, createdAt time.Time, expiresAt time.Time) *Token { - return &Token{ - UserId: userId, - SessionId: sessionId, - Token: token, - Type: tokenType, - CreatedAt: createdAt, - ExpiresAt: expiresAt, - } -} - type Auth interface { - InsertUser(user *User) error - UpdateUser(user *User) error - GetUserByEmail(email string) (*User, error) - GetUser(userId uuid.UUID) (*User, error) + InsertUser(user *types.User) error + UpdateUser(user *types.User) error + GetUserByEmail(email string) (*types.User, error) + GetUser(userId uuid.UUID) (*types.User, error) DeleteUser(userId uuid.UUID) error - InsertToken(token *Token) error - GetToken(token string) (*Token, error) - GetTokensByUserIdAndType(userId uuid.UUID, tokenType TokenType) ([]*Token, error) - GetTokensBySessionIdAndType(sessionId string, tokenType TokenType) ([]*Token, error) + InsertToken(token *types.Token) error + GetToken(token string) (*types.Token, error) + GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) + GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) DeleteToken(token string) error - InsertSession(session *Session) error - GetSession(sessionId string) (*Session, error) - GetSessions(userId uuid.UUID) ([]*Session, error) + InsertSession(session *types.Session) error + GetSession(sessionId string) (*types.Session, error) + GetSessions(userId uuid.UUID) ([]*types.Session, error) DeleteSession(sessionId string) error DeleteOldSessions(userId uuid.UUID) error } @@ -113,7 +45,7 @@ func NewAuthSqlite(db *sql.DB) *AuthSqlite { return &AuthSqlite{db: db} } -func (db AuthSqlite) InsertUser(user *User) error { +func (db AuthSqlite) InsertUser(user *types.User) error { _, err := db.db.Exec(` INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, @@ -131,7 +63,7 @@ func (db AuthSqlite) InsertUser(user *User) error { return nil } -func (db AuthSqlite) UpdateUser(user *User) error { +func (db AuthSqlite) UpdateUser(user *types.User) error { _, err := db.db.Exec(` UPDATE user SET email_verified = ?, email_verified_at = ?, password = ? @@ -146,7 +78,7 @@ func (db AuthSqlite) UpdateUser(user *User) error { return nil } -func (db AuthSqlite) GetUserByEmail(email string) (*User, error) { +func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) { var ( userId uuid.UUID emailVerified bool @@ -170,10 +102,10 @@ func (db AuthSqlite) GetUserByEmail(email string) (*User, error) { } } - return NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil + return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil } -func (db AuthSqlite) GetUser(userId uuid.UUID) (*User, error) { +func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) { var ( email string emailVerified bool @@ -197,7 +129,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*User, error) { } } - return NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil + return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil } func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { @@ -245,7 +177,7 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { return nil } -func (db AuthSqlite) InsertToken(token *Token) error { +func (db AuthSqlite) InsertToken(token *types.Token) error { _, err := db.db.Exec(` INSERT INTO token (user_id, session_id, type, token, created_at, expires_at) VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt) @@ -258,11 +190,11 @@ func (db AuthSqlite) InsertToken(token *Token) error { return nil } -func (db AuthSqlite) GetToken(token string) (*Token, error) { +func (db AuthSqlite) GetToken(token string) (*types.Token, error) { var ( userId uuid.UUID sessionId string - tokenType TokenType + tokenType types.TokenType createdAtStr string expiresAtStr string createdAt time.Time @@ -296,10 +228,10 @@ func (db AuthSqlite) GetToken(token string) (*Token, error) { return nil, types.ErrInternal } - return NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil + return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil } -func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType TokenType) ([]*Token, error) { +func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) { query, err := db.db.Query(` SELECT token, created_at, expires_at @@ -315,7 +247,7 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType TokenT return getTokensFromQuery(query, userId, "", tokenType) } -func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType TokenType) ([]*Token, error) { +func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) { query, err := db.db.Query(` SELECT token, created_at, expires_at @@ -331,8 +263,8 @@ func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType Tok return getTokensFromQuery(query, uuid.Nil, sessionId, tokenType) } -func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tokenType TokenType) ([]*Token, error) { - var tokens []*Token +func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tokenType types.TokenType) ([]*types.Token, error) { + var tokens []*types.Token hasRows := false for query.Next() { @@ -364,7 +296,7 @@ func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tok return nil, types.ErrInternal } - tokens = append(tokens, NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt)) + tokens = append(tokens, types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt)) } if !hasRows { @@ -383,7 +315,7 @@ func (db AuthSqlite) DeleteToken(token string) error { return nil } -func (db AuthSqlite) InsertSession(session *Session) error { +func (db AuthSqlite) InsertSession(session *types.Session) error { _, err := db.db.Exec(` INSERT INTO session (session_id, user_id, created_at, expires_at) @@ -397,7 +329,7 @@ func (db AuthSqlite) InsertSession(session *Session) error { return nil } -func (db AuthSqlite) GetSession(sessionId string) (*Session, error) { +func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) { var ( userId uuid.UUID @@ -414,10 +346,10 @@ func (db AuthSqlite) GetSession(sessionId string) (*Session, error) { return nil, ErrNotFound } - return NewSession(sessionId, userId, createdAt, expiresAt), nil + return types.NewSession(sessionId, userId, createdAt, expiresAt), nil } -func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*Session, error) { +func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) { sessions, err := db.db.Query(` SELECT session_id, created_at, expires_at @@ -428,7 +360,7 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*Session, error) { return nil, types.ErrInternal } - var result []*Session + var result []*types.Session for sessions.Next() { var ( @@ -443,7 +375,7 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*Session, error) { return nil, types.ErrInternal } - session := NewSession(sessionId, userId, createdAt, expiresAt) + session := types.NewSession(sessionId, userId, createdAt, expiresAt) result = append(result, session) } diff --git a/db/auth_test.go b/db/auth_test.go index d096c38..d243290 100644 --- a/db/auth_test.go +++ b/db/auth_test.go @@ -38,7 +38,7 @@ func TestUser(t *testing.T) { verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - expected := NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(expected) assert.Nil(t, err) @@ -68,7 +68,7 @@ func TestUser(t *testing.T) { verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(user) assert.Nil(t, err) @@ -83,7 +83,7 @@ func TestUser(t *testing.T) { underTest := AuthSqlite{db: db} createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) + user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) err := underTest.InsertUser(user) assert.Equal(t, types.ErrInternal, err) @@ -101,7 +101,7 @@ func TestToken(t *testing.T) { createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) expiresAt := createAt.Add(24 * time.Hour) - expected := NewToken(uuid.New(), "sessionId", "token", TokenTypeCsrf, createAt, expiresAt) + expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt) err := underTest.InsertToken(expected) assert.Nil(t, err) @@ -113,13 +113,13 @@ func TestToken(t *testing.T) { expected.SessionId = "" actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type) assert.Nil(t, err) - assert.Equal(t, []*Token{expected}, actuals) + assert.Equal(t, []*types.Token{expected}, actuals) expected.SessionId = "sessionId" expected.UserId = uuid.Nil actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type) assert.Nil(t, err) - assert.Equal(t, []*Token{expected}, actuals) + assert.Equal(t, []*types.Token{expected}, actuals) }) t.Run("should insert and return multiple tokens", func(t *testing.T) { t.Parallel() @@ -130,8 +130,8 @@ func TestToken(t *testing.T) { createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) expiresAt := createAt.Add(24 * time.Hour) userId := uuid.New() - expected1 := NewToken(userId, "sessionId", "token1", TokenTypeCsrf, createAt, expiresAt) - expected2 := NewToken(userId, "sessionId", "token2", TokenTypeCsrf, createAt, expiresAt) + expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt) + expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt) err := underTest.InsertToken(expected1) assert.Nil(t, err) @@ -142,7 +142,7 @@ func TestToken(t *testing.T) { expected2.UserId = uuid.Nil actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type) assert.Nil(t, err) - assert.Equal(t, []*Token{expected1, expected2}, actuals) + assert.Equal(t, []*types.Token{expected1, expected2}, actuals) expected1.SessionId = "" expected2.SessionId = "" @@ -150,7 +150,7 @@ func TestToken(t *testing.T) { expected2.UserId = userId actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type) assert.Nil(t, err) - assert.Equal(t, []*Token{expected1, expected2}, actuals) + assert.Equal(t, []*types.Token{expected1, expected2}, actuals) }) t.Run("should return ErrNotFound", func(t *testing.T) { @@ -162,10 +162,10 @@ func TestToken(t *testing.T) { _, err := underTest.GetToken("nonExistent") assert.Equal(t, ErrNotFound, err) - _, err = underTest.GetTokensByUserIdAndType(uuid.New(), TokenTypeEmailVerify) + _, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify) assert.Equal(t, ErrNotFound, err) - _, err = underTest.GetTokensBySessionIdAndType("sessionId", TokenTypeEmailVerify) + _, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify) assert.Equal(t, ErrNotFound, err) }) t.Run("should return ErrAlreadyExists", func(t *testing.T) { @@ -176,7 +176,7 @@ func TestToken(t *testing.T) { verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) + user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) err := underTest.InsertUser(user) assert.Nil(t, err) @@ -191,7 +191,7 @@ func TestToken(t *testing.T) { underTest := AuthSqlite{db: db} createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) - user := NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) + user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) err := underTest.InsertUser(user) assert.Equal(t, types.ErrInternal, err) diff --git a/handler/auth.go b/handler/auth.go index f533fa7..42db354 100644 --- a/handler/auth.go +++ b/handler/auth.go @@ -77,11 +77,11 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc { func (handler AuthImpl) handleSignIn() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*service.User, error) { + user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) { var email = r.FormValue("email") var password = r.FormValue("password") - session, err := handler.service.SignIn(email, password) + session, user, err := handler.service.SignIn(email, password) if err != nil { return nil, err } @@ -89,7 +89,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc { cookie := middleware.CreateSessionCookie(session.Id) http.SetCookie(w, &cookie) - return session.User, nil + return user, nil }) if err != nil { @@ -294,7 +294,8 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { session := middleware.GetSession(r) - if session == nil || session.User == nil { + user := middleware.GetUser(r) + if session == nil || user == nil { utils.DoRedirect(w, r, "/auth/signin") return } @@ -302,7 +303,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc { currPass := r.FormValue("current-password") newPass := r.FormValue("new-password") - err := handler.service.ChangePassword(session, currPass, newPass) + err := handler.service.ChangePassword(user, session.Id, currPass, newPass) if err != nil { utils.TriggerToast(w, r, "error", "Password not correct", http.StatusUnauthorized) return diff --git a/handler/index_and_404.go b/handler/index_and_404.go index 30c67ee..0c81c6b 100644 --- a/handler/index_and_404.go +++ b/handler/index_and_404.go @@ -32,11 +32,7 @@ func (handler IndexImpl) Handle(router *http.ServeMux) { func (handler IndexImpl) handleIndexAnd404() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - session := middleware.GetSession(r) - var user *service.User - if session != nil { - user = session.User - } + user := middleware.GetUser(r) var comp templ.Component diff --git a/handler/middleware/authenticate.go b/handler/middleware/authenticate.go index 23b585c..9071215 100644 --- a/handler/middleware/authenticate.go +++ b/handler/middleware/authenticate.go @@ -2,24 +2,29 @@ package middleware import ( "context" + "net/http" "me-fit/service" - - "net/http" + "me-fit/types" ) type ContextKey string var SessionKey ContextKey = "session" +var UserKey ContextKey = "user" func Authenticate(service service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + sessionId := getSessionID(r) - session, _ := service.SignInSession(sessionId) + session, user, _ := service.SignInSession(sessionId) if session != nil { - ctx := context.WithValue(r.Context(), SessionKey, session) + + ctx = context.WithValue(ctx, UserKey, user) + ctx = context.WithValue(ctx, SessionKey, session) next.ServeHTTP(w, r.WithContext(ctx)) } else { @@ -29,23 +34,22 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler { } } -func GetUser(r *http.Request) *service.User { - - session := GetSession(r) - if session == nil { +func GetUser(r *http.Request) *types.User { + obj := r.Context().Value(UserKey) + if obj == nil { return nil } - return session.User + return obj.(*types.User) } -func GetSession(r *http.Request) *service.Session { +func GetSession(r *http.Request) *types.Session { obj := r.Context().Value(SessionKey) if obj == nil { return nil } - return obj.(*service.Session) + return obj.(*types.Session) } func getSessionID(r *http.Request) string { diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 1186228..2343cf4 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -6,15 +6,16 @@ import ( "strings" "me-fit/service" + "me-fit/types" ) type csrfResponseWriter struct { http.ResponseWriter auth service.Auth - session *service.Session + session *types.Session } -func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *service.Session) *csrfResponseWriter { +func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter { return &csrfResponseWriter{ ResponseWriter: w, auth: auth, diff --git a/handler/render.go b/handler/render.go index e2bf3e8..c3f007c 100644 --- a/handler/render.go +++ b/handler/render.go @@ -2,7 +2,6 @@ package handler import ( "me-fit/log" - "me-fit/service" "me-fit/template" "me-fit/template/auth" "me-fit/types" @@ -31,14 +30,14 @@ func (render *Render) Render(r *http.Request, w http.ResponseWriter, comp templ. } } -func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *service.User) { +func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *types.User) { userComp := render.getUserComp(user) layout := template.Layout(slot, userComp, render.settings.Environment) render.Render(r, w, layout) } -func (render *Render) getUserComp(user *service.User) templ.Component { +func (render *Render) getUserComp(user *types.User) templ.Component { if user != nil { return auth.UserComp(user.Email) diff --git a/handler/workout.go b/handler/workout.go index 169deff..e0c2ecc 100644 --- a/handler/workout.go +++ b/handler/workout.go @@ -38,22 +38,22 @@ func (handler WorkoutImpl) Handle(router *http.ServeMux) { func (handler WorkoutImpl) handleWorkoutPage() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - session := middleware.GetSession(r) - if session == nil { + user := middleware.GetUser(r) + if user == nil { utils.DoRedirect(w, r, "/auth/signin") return } currentDate := time.Now().Format("2006-01-02") comp := workout.WorkoutComp(currentDate) - handler.render.RenderLayout(r, w, comp, session.User) + handler.render.RenderLayout(r, w, comp, user) } } func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - session := middleware.GetSession(r) - if session == nil { + user := middleware.GetUser(r) + if user == nil { utils.DoRedirect(w, r, "/auth/signin") return } @@ -64,7 +64,7 @@ func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc { var repsStr = r.FormValue("reps") wo := service.NewWorkoutDto("", dateStr, typeStr, setsStr, repsStr) - wo, err := handler.service.AddWorkout(session.User, wo) + wo, err := handler.service.AddWorkout(user, wo) if err != nil { utils.TriggerToast(w, r, "error", "Invalid input values", http.StatusBadRequest) http.Error(w, "Invalid input values", http.StatusBadRequest) @@ -79,13 +79,13 @@ func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc { func (handler WorkoutImpl) handleGetWorkout() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - session := middleware.GetSession(r) - if session == nil { + user := middleware.GetUser(r) + if user == nil { utils.DoRedirect(w, r, "/auth/signin") return } - workouts, err := handler.service.GetWorkouts(session.User) + workouts, err := handler.service.GetWorkouts(user) if err != nil { return } @@ -102,8 +102,8 @@ func (handler WorkoutImpl) handleGetWorkout() http.HandlerFunc { func (handler WorkoutImpl) handleDeleteWorkout() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - session := middleware.GetSession(r) - if session == nil { + user := middleware.GetUser(r) + if user == nil { utils.DoRedirect(w, r, "/auth/signin") return } @@ -120,7 +120,7 @@ func (handler WorkoutImpl) handleDeleteWorkout() http.HandlerFunc { return } - err = handler.service.DeleteWorkout(session.User, rowIdInt) + err = handler.service.DeleteWorkout(user, rowIdInt) if err != nil { utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) return diff --git a/main_test.go b/main_test.go index 40fabf2..da7395b 100644 --- a/main_test.go +++ b/main_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "me-fit/db" "me-fit/service" "me-fit/types" @@ -271,7 +270,7 @@ func TestIntegrationAuth(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) var token string - err = d.QueryRow("SELECT token FROM token WHERE type = ?", db.TokenTypePasswordReset).Scan(&token) + err = d.QueryRow("SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token) assert.Nil(t, err) formData = url.Values{ diff --git a/service/auth.go b/service/auth.go index 04592fc..b907a81 100644 --- a/service/auth.go +++ b/service/auth.go @@ -26,55 +26,25 @@ var ( ErrTokenInvalid = errors.New("token is invalid") ) -type User struct { - Id uuid.UUID - Email string - EmailVerified bool -} - -func NewUser(user *db.User) *User { - return &User{ - Id: user.Id, - Email: user.Email, - EmailVerified: user.EmailVerified, - } -} - -type Session struct { - Id string - CreatedAt time.Time - ExpiresAt time.Time - User *User -} - -func NewSession(session *db.Session, user *User) *Session { - return &Session{ - Id: session.Id, - CreatedAt: session.CreatedAt, - ExpiresAt: session.ExpiresAt, - User: user, - } -} - type Auth interface { - SignUp(email string, password string) (*User, error) + SignUp(email string, password string) (*types.User, error) SendVerificationMail(userId uuid.UUID, email string) VerifyUserEmail(token string) error - SignIn(email string, password string) (*Session, error) - SignInSession(sessionId string) (*Session, error) - SignInAnonymous() (*Session, error) + SignIn(email string, password string) (*types.Session, *types.User, error) + SignInSession(sessionId string) (*types.Session, *types.User, error) + SignInAnonymous() (*types.Session, error) SignOut(sessionId string) error - DeleteAccount(user *User, currPass string) error + DeleteAccount(user *types.User, currPass string) error - ChangePassword(session *Session, currPass, newPass string) error + ChangePassword(user *types.User, sessionId string, currPass, newPass string) error SendForgotPasswordMail(email string) error ForgotPassword(token string, newPass string) error IsCsrfTokenValid(tokenStr string, sessionId string) bool - GetCsrfToken(session *Session) (string, error) + GetCsrfToken(session *types.Session) (string, error) } type AuthImpl struct { @@ -95,69 +65,65 @@ func NewAuthImpl(db db.Auth, random Random, clock Clock, mail Mail, serverSettin } } -func (service AuthImpl) SignIn(email string, password string) (*Session, error) { +func (service AuthImpl) SignIn(email string, password string) (*types.Session, *types.User, error) { user, err := service.db.GetUserByEmail(email) if err != nil { if errors.Is(err, db.ErrNotFound) { - return nil, ErrInvalidCredentials + return nil, nil, ErrInvalidCredentials } else { - return nil, types.ErrInternal + return nil, nil, types.ErrInternal } } hash := GetHashPassword(password, user.Salt) if subtle.ConstantTimeCompare(hash, user.Password) == 0 { - return nil, ErrInvalidCredentials + return nil, nil, ErrInvalidCredentials } session, err := service.createSession(user.Id) if err != nil { - return nil, types.ErrInternal + return nil, nil, types.ErrInternal } - return NewSession(session, NewUser(user)), nil + return session, user, nil } -func (service AuthImpl) SignInSession(sessionId string) (*Session, error) { +func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) { if sessionId == "" { - return nil, ErrSessionIdInvalid + return nil, nil, ErrSessionIdInvalid } - sessionDb, err := service.db.GetSession(sessionId) + session, err := service.db.GetSession(sessionId) + if err != nil { + return nil, nil, types.ErrInternal + } + if session.ExpiresAt.Before(service.clock.Now()) { + return nil, nil, nil + } + + if session.UserId == uuid.Nil { + return session, nil, nil + } + + user, err := service.db.GetUser(session.UserId) + if err != nil { + return nil, nil, types.ErrInternal + } + + return session, user, nil +} + +func (service AuthImpl) SignInAnonymous() (*types.Session, error) { + session, err := service.createSession(uuid.Nil) if err != nil { return nil, types.ErrInternal } - if sessionDb.ExpiresAt.Before(service.clock.Now()) { - return nil, nil - } - - if sessionDb.UserId == uuid.Nil { - return NewSession(sessionDb, nil), nil - } - - userDb, err := service.db.GetUser(sessionDb.UserId) - if err != nil { - return nil, types.ErrInternal - } - - user := NewUser(userDb) - session := NewSession(sessionDb, user) - return session, nil } -func (service AuthImpl) SignInAnonymous() (*Session, error) { - sessionDb, err := service.createSession(uuid.Nil) - if err != nil { - return nil, types.ErrInternal - } - - return NewSession(sessionDb, nil), nil -} - -func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) { +func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) { sessionId, err := service.random.String(32) if err != nil { return nil, types.ErrInternal @@ -172,7 +138,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) { createAt := service.clock.Now() expiresAt := createAt.Add(24 * time.Hour) - session := db.NewSession(sessionId, userId, createAt, expiresAt) + session := types.NewSession(sessionId, userId, createAt, expiresAt) err = service.db.InsertSession(session) if err != nil { @@ -182,7 +148,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) { return session, nil } -func (service AuthImpl) SignUp(email string, password string) (*User, error) { +func (service AuthImpl) SignUp(email string, password string) (*types.User, error) { _, err := mail.ParseAddress(email) if err != nil { return nil, ErrInvalidEmail @@ -204,9 +170,9 @@ func (service AuthImpl) SignUp(email string, password string) (*User, error) { hash := GetHashPassword(password, salt) - dbUser := db.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) + user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) - err = service.db.InsertUser(dbUser) + err = service.db.InsertUser(user) if err != nil { if err == db.ErrAlreadyExists { return nil, ErrAccountExists @@ -215,17 +181,17 @@ func (service AuthImpl) SignUp(email string, password string) (*User, error) { } } - return NewUser(dbUser), nil + return user, nil } func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { - tokens, err := service.db.GetTokensByUserIdAndType(userId, db.TokenTypeEmailVerify) + tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify) if err != nil && err != db.ErrNotFound { return } - var token *db.Token + var token *types.Token if len(tokens) > 0 { token = tokens[0] @@ -237,7 +203,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { return } - token = db.NewToken(userId, "", newTokenStr, db.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) + token = types.NewToken(userId, "", newTokenStr, types.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) err = service.db.InsertToken(token) if err != nil { @@ -271,7 +237,7 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error { return types.ErrInternal } - if token.Type != db.TokenTypeEmailVerify { + if token.Type != types.TokenTypeEmailVerify { return types.ErrInternal } @@ -298,7 +264,7 @@ func (service AuthImpl) SignOut(sessionId string) error { return service.db.DeleteSession(sessionId) } -func (service AuthImpl) DeleteAccount(user *User, currPass string) error { +func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error { userDb, err := service.db.GetUser(user.Id) if err != nil { @@ -320,7 +286,7 @@ func (service AuthImpl) DeleteAccount(user *User, currPass string) error { return nil } -func (service AuthImpl) ChangePassword(session *Session, currPass, newPass string) error { +func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error { if !isPasswordValid(newPass) { return ErrInvalidPassword @@ -330,31 +296,26 @@ func (service AuthImpl) ChangePassword(session *Session, currPass, newPass strin return ErrInvalidPassword } - userDb, err := service.db.GetUser(session.User.Id) - if err != nil { - return err - } + currHash := GetHashPassword(currPass, user.Salt) - currHash := GetHashPassword(currPass, userDb.Salt) - - if subtle.ConstantTimeCompare(currHash, userDb.Password) == 0 { + if subtle.ConstantTimeCompare(currHash, user.Password) == 0 { return ErrInvalidCredentials } - newHash := GetHashPassword(newPass, userDb.Salt) - userDb.Password = newHash + newHash := GetHashPassword(newPass, user.Salt) + user.Password = newHash - err = service.db.UpdateUser(userDb) + err := service.db.UpdateUser(user) if err != nil { return err } - sessions, err := service.db.GetSessions(userDb.Id) + sessions, err := service.db.GetSessions(user.Id) if err != nil { return types.ErrInternal } for _, s := range sessions { - if s.Id != session.Id { + if s.Id != sessionId { err = service.db.DeleteSession(s.Id) if err != nil { return types.ErrInternal @@ -380,7 +341,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error { } } - token := db.NewToken(user.Id, "", tokenStr, db.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute)) + token := types.NewToken(user.Id, "", tokenStr, types.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute)) err = service.db.InsertToken(token) if err != nil { @@ -414,7 +375,7 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error { return err } - if token.Type != db.TokenTypePasswordReset || + if token.Type != types.TokenTypePasswordReset || token.ExpiresAt.Before(service.clock.Now()) { return ErrTokenInvalid } @@ -454,7 +415,7 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool return false } - if token.Type != db.TokenTypeCsrf || + if token.Type != types.TokenTypeCsrf || token.SessionId != sessionId || token.ExpiresAt.Before(service.clock.Now()) { @@ -464,12 +425,12 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool return true } -func (service AuthImpl) GetCsrfToken(session *Session) (string, error) { +func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) { if session == nil { return "", types.ErrInternal } - tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, db.TokenTypeCsrf) + tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf) if len(tokens) > 0 { return tokens[0].Token, nil @@ -480,7 +441,7 @@ func (service AuthImpl) GetCsrfToken(session *Session) (string, error) { return "", types.ErrInternal } - token := db.NewToken(uuid.Nil, session.Id, tokenStr, db.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) + token := types.NewToken(uuid.Nil, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) err = service.db.InsertToken(token) if err != nil { return "", types.ErrInternal diff --git a/service/auth_test.go b/service/auth_test.go index 513dcc8..e3a5464 100644 --- a/service/auth_test.go +++ b/service/auth_test.go @@ -22,7 +22,7 @@ func TestSignIn(t *testing.T) { t.Parallel() salt := []byte("salt") verifiedAt := time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC) - user := db.NewUser( + user := types.NewUser( uuid.New(), "test@test.de", true, @@ -33,12 +33,12 @@ func TestSignIn(t *testing.T) { time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), ) - dbSession := db.NewSession("sessionId", user.Id, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) + session := types.NewSession("sessionId", user.Id, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) mockAuthDb := mocks.NewMockAuth(t) mockAuthDb.EXPECT().GetUserByEmail("test@test.de").Return(user, nil) mockAuthDb.EXPECT().DeleteOldSessions(user.Id).Return(nil) - mockAuthDb.EXPECT().InsertSession(dbSession).Return(nil) + mockAuthDb.EXPECT().InsertSession(session).Return(nil) mockRandom := mocks.NewMockRandom(t) mockRandom.EXPECT().String(32).Return("sessionId", nil) mockClock := mocks.NewMockClock(t) @@ -47,11 +47,11 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - actualSession, err := underTest.SignIn(user.Email, "password") + actualSession, actualUser, err := underTest.SignIn(user.Email, "password") assert.Nil(t, err) - expectedSession := NewSession(dbSession, NewUser(user)) - assert.Equal(t, expectedSession, actualSession) + assert.Equal(t, session, actualSession) + assert.Equal(t, user, actualUser) }) t.Run("should return ErrInvalidCretentials if password is not correct", func(t *testing.T) { @@ -59,7 +59,7 @@ func TestSignIn(t *testing.T) { salt := []byte("salt") verifiedAt := time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC) - user := db.NewUser( + user := types.NewUser( uuid.New(), "test@test.de", true, @@ -78,7 +78,7 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - _, err := underTest.SignIn("test@test.de", "wrong password") + _, _, err := underTest.SignIn("test@test.de", "wrong password") assert.Equal(t, ErrInvalidCredentials, err) }) @@ -93,7 +93,7 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - _, err := underTest.SignIn("test", "test") + _, _, err := underTest.SignIn("test", "test") assert.Equal(t, ErrInvalidCredentials, err) }) t.Run("should forward ErrInternal on any other error", func(t *testing.T) { @@ -107,7 +107,7 @@ func TestSignIn(t *testing.T) { underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - _, err := underTest.SignIn("test", "test") + _, _, err := underTest.SignIn("test", "test") assert.Equal(t, types.ErrInternal, err) }) @@ -159,33 +159,25 @@ func TestSignUp(t *testing.T) { mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) - expected := User{ - Id: uuid.New(), - Email: "some@valid.email", - EmailVerified: false, - } - - random := NewRandomImpl() - salt, err := random.Bytes(16) - assert.Nil(t, err) + userId := uuid.New() + email := "mail@mail.de" password := "SomeStrongPassword123!" - - mockRandom.EXPECT().UUID().Return(expected.Id, nil) - mockRandom.EXPECT().Bytes(16).Return(salt, nil) - + salt := []byte("salt") createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) - mockClock.EXPECT().Now().Return(createTime) + expected := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime) - mockAuthDb.EXPECT().InsertUser(db.NewUser(expected.Id, expected.Email, false, nil, false, GetHashPassword(password, salt), salt, createTime)).Return(nil) + mockRandom.EXPECT().UUID().Return(userId, nil) + mockRandom.EXPECT().Bytes(16).Return(salt, nil) + mockClock.EXPECT().Now().Return(createTime) + mockAuthDb.EXPECT().InsertUser(expected).Return(nil) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - - actual, err := underTest.SignUp(expected.Email, password) + actual, err := underTest.SignUp(email, password) assert.Nil(t, err) - assert.Equal(t, expected, *actual) + assert.Equal(t, expected, actual) }) t.Run("should return ErrAccountExists", func(t *testing.T) { t.Parallel() @@ -195,28 +187,22 @@ func TestSignUp(t *testing.T) { mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) - user := User{ - Id: uuid.New(), - Email: "some@valid.email", - } - - random := NewRandomImpl() - salt, err := random.Bytes(16) - assert.Nil(t, err) + userId := uuid.New() + email := "some@valid.email" + createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) password := "SomeStrongPassword123!" + salt := []byte("salt") + user := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime) mockRandom.EXPECT().UUID().Return(user.Id, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil) - - createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) - mockClock.EXPECT().Now().Return(createTime) - mockAuthDb.EXPECT().InsertUser(db.NewUser(user.Id, user.Email, false, nil, false, GetHashPassword(password, salt), salt, createTime)).Return(db.ErrAlreadyExists) + mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) - _, err = underTest.SignUp(user.Email, password) + _, err := underTest.SignUp(user.Email, password) assert.Equal(t, ErrAccountExists, err) }) } @@ -227,8 +213,8 @@ func TestSendVerificationMail(t *testing.T) { t.Run("should use stored token and send mail", func(t *testing.T) { t.Parallel() - token := db.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", db.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) - tokens := []*db.Token{token} + token := types.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", types.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) + tokens := []*types.Token{token} email := "some@email.de" userId := uuid.New() @@ -238,7 +224,7 @@ func TestSendVerificationMail(t *testing.T) { mockClock := mocks.NewMockClock(t) mockMail := mocks.NewMockMail(t) - mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, db.TokenTypeEmailVerify).Return(tokens, nil) + mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify).Return(tokens, nil) mockMail.EXPECT().SendMail(email, "Welcome to ME-FIT", mock.MatchedBy(func(message string) bool { return strings.Contains(message, token.Token) diff --git a/service/workout.go b/service/workout.go index 1c22aeb..e5f60aa 100644 --- a/service/workout.go +++ b/service/workout.go @@ -10,9 +10,9 @@ import ( ) type Workout interface { - AddWorkout(user *User, workoutDto *WorkoutDto) (*WorkoutDto, error) - DeleteWorkout(user *User, rowId int) error - GetWorkouts(user *User) ([]*WorkoutDto, error) + AddWorkout(user *types.User, workoutDto *WorkoutDto) (*WorkoutDto, error) + DeleteWorkout(user *types.User, rowId int) error + GetWorkouts(user *types.User) ([]*WorkoutDto, error) } type WorkoutImpl struct { @@ -64,7 +64,7 @@ var ( ErrInputValues = errors.New("invalid input values") ) -func (service WorkoutImpl) AddWorkout(user *User, workoutDto *WorkoutDto) (*WorkoutDto, error) { +func (service WorkoutImpl) AddWorkout(user *types.User, workoutDto *WorkoutDto) (*WorkoutDto, error) { if workoutDto.Date == "" || workoutDto.Type == "" || workoutDto.Sets == "" || workoutDto.Reps == "" { return nil, ErrInputValues @@ -95,7 +95,7 @@ func (service WorkoutImpl) AddWorkout(user *User, workoutDto *WorkoutDto) (*Work return NewWorkoutDtoFromDb(workout), nil } -func (service WorkoutImpl) DeleteWorkout(user *User, rowId int) error { +func (service WorkoutImpl) DeleteWorkout(user *types.User, rowId int) error { if user == nil { return types.ErrInternal } @@ -103,7 +103,7 @@ func (service WorkoutImpl) DeleteWorkout(user *User, rowId int) error { return service.db.DeleteWorkout(user.Id, rowId) } -func (service WorkoutImpl) GetWorkouts(user *User) ([]*WorkoutDto, error) { +func (service WorkoutImpl) GetWorkouts(user *types.User) ([]*WorkoutDto, error) { if user == nil { return nil, types.ErrInternal } diff --git a/types/auth.go b/types/auth.go new file mode 100644 index 0000000..9f46957 --- /dev/null +++ b/types/auth.go @@ -0,0 +1,75 @@ +package types + +import ( + "time" + + "github.com/google/uuid" +) + +type User struct { + Id uuid.UUID + Email string + EmailVerified bool + EmailVerifiedAt *time.Time + IsAdmin bool + Password []byte + Salt []byte + CreateAt time.Time +} + +func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User { + return &User{ + Id: id, + Email: email, + EmailVerified: emailVerified, + EmailVerifiedAt: emailVerifiedAt, + IsAdmin: isAdmin, + Password: password, + Salt: salt, + CreateAt: createAt, + } +} + +type Session struct { + Id string + UserId uuid.UUID + CreatedAt time.Time + ExpiresAt time.Time +} + +func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session { + return &Session{ + Id: id, + UserId: userId, + CreatedAt: createdAt, + ExpiresAt: expiresAt, + } +} + +type Token struct { + UserId uuid.UUID + SessionId string + Token string + Type TokenType + CreatedAt time.Time + ExpiresAt time.Time +} + +type TokenType string + +var ( + TokenTypeEmailVerify TokenType = "email_verify" + TokenTypePasswordReset TokenType = "password_reset" + TokenTypeCsrf TokenType = "csrf" +) + +func NewToken(userId uuid.UUID, sessionId string, token string, tokenType TokenType, createdAt time.Time, expiresAt time.Time) *Token { + return &Token{ + UserId: userId, + SessionId: sessionId, + Token: token, + Type: tokenType, + CreatedAt: createdAt, + ExpiresAt: expiresAt, + } +}