feat: #337 unify types for auth module #338

Merged
tim merged 1 commits from 337-share-db-types into prod 2024-12-18 22:48:21 +00:00
13 changed files with 259 additions and 305 deletions

View File

@@ -17,90 +17,22 @@ var (
ErrAlreadyExists = errors.New("row already exists") ErrAlreadyExists = errors.New("row already exists")
) )
type User struct {
Id uuid.UUID
Email string
EmailVerified bool
EmailVerifiedAt *time.Time
IsAdmin bool
Password []byte
Salt []byte
CreateAt time.Time
}
func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User {
return &User{
Id: id,
Email: email,
EmailVerified: emailVerified,
EmailVerifiedAt: emailVerifiedAt,
IsAdmin: isAdmin,
Password: password,
Salt: salt,
CreateAt: createAt,
}
}
type Session struct {
Id string
UserId uuid.UUID
CreatedAt time.Time
ExpiresAt time.Time
}
func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session {
return &Session{
Id: id,
UserId: userId,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
}
}
type Token struct {
UserId uuid.UUID
SessionId string
Token string
Type TokenType
CreatedAt time.Time
ExpiresAt time.Time
}
type TokenType string
var (
TokenTypeEmailVerify TokenType = "email_verify"
TokenTypePasswordReset TokenType = "password_reset"
TokenTypeCsrf TokenType = "csrf"
)
func NewToken(userId uuid.UUID, sessionId string, token string, tokenType TokenType, createdAt time.Time, expiresAt time.Time) *Token {
return &Token{
UserId: userId,
SessionId: sessionId,
Token: token,
Type: tokenType,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
}
}
type Auth interface { type Auth interface {
InsertUser(user *User) error InsertUser(user *types.User) error
UpdateUser(user *User) error UpdateUser(user *types.User) error
GetUserByEmail(email string) (*User, error) GetUserByEmail(email string) (*types.User, error)
GetUser(userId uuid.UUID) (*User, error) GetUser(userId uuid.UUID) (*types.User, error)
DeleteUser(userId uuid.UUID) error DeleteUser(userId uuid.UUID) error
InsertToken(token *Token) error InsertToken(token *types.Token) error
GetToken(token string) (*Token, error) GetToken(token string) (*types.Token, error)
GetTokensByUserIdAndType(userId uuid.UUID, tokenType TokenType) ([]*Token, error) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error)
GetTokensBySessionIdAndType(sessionId string, tokenType TokenType) ([]*Token, error) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error)
DeleteToken(token string) error DeleteToken(token string) error
InsertSession(session *Session) error InsertSession(session *types.Session) error
GetSession(sessionId string) (*Session, error) GetSession(sessionId string) (*types.Session, error)
GetSessions(userId uuid.UUID) ([]*Session, error) GetSessions(userId uuid.UUID) ([]*types.Session, error)
DeleteSession(sessionId string) error DeleteSession(sessionId string) error
DeleteOldSessions(userId uuid.UUID) error DeleteOldSessions(userId uuid.UUID) error
} }
@@ -113,7 +45,7 @@ func NewAuthSqlite(db *sql.DB) *AuthSqlite {
return &AuthSqlite{db: db} return &AuthSqlite{db: db}
} }
func (db AuthSqlite) InsertUser(user *User) error { func (db AuthSqlite) InsertUser(user *types.User) error {
_, err := db.db.Exec(` _, err := db.db.Exec(`
INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at) INSERT INTO user (user_id, email, email_verified, email_verified_at, is_admin, password, salt, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
@@ -131,7 +63,7 @@ func (db AuthSqlite) InsertUser(user *User) error {
return nil return nil
} }
func (db AuthSqlite) UpdateUser(user *User) error { func (db AuthSqlite) UpdateUser(user *types.User) error {
_, err := db.db.Exec(` _, err := db.db.Exec(`
UPDATE user UPDATE user
SET email_verified = ?, email_verified_at = ?, password = ? SET email_verified = ?, email_verified_at = ?, password = ?
@@ -146,7 +78,7 @@ func (db AuthSqlite) UpdateUser(user *User) error {
return nil return nil
} }
func (db AuthSqlite) GetUserByEmail(email string) (*User, error) { func (db AuthSqlite) GetUserByEmail(email string) (*types.User, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
emailVerified bool emailVerified bool
@@ -170,10 +102,10 @@ func (db AuthSqlite) GetUserByEmail(email string) (*User, error) {
} }
} }
return NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
} }
func (db AuthSqlite) GetUser(userId uuid.UUID) (*User, error) { func (db AuthSqlite) GetUser(userId uuid.UUID) (*types.User, error) {
var ( var (
email string email string
emailVerified bool emailVerified bool
@@ -197,7 +129,7 @@ func (db AuthSqlite) GetUser(userId uuid.UUID) (*User, error) {
} }
} }
return NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil return types.NewUser(userId, email, emailVerified, emailVerifiedAt, isAdmin, password, salt, createdAt), nil
} }
func (db AuthSqlite) DeleteUser(userId uuid.UUID) error { func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
@@ -245,7 +177,7 @@ func (db AuthSqlite) DeleteUser(userId uuid.UUID) error {
return nil return nil
} }
func (db AuthSqlite) InsertToken(token *Token) error { func (db AuthSqlite) InsertToken(token *types.Token) error {
_, err := db.db.Exec(` _, err := db.db.Exec(`
INSERT INTO token (user_id, session_id, type, token, created_at, expires_at) INSERT INTO token (user_id, session_id, type, token, created_at, expires_at)
VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt) VALUES (?, ?, ?, ?, ?, ?)`, token.UserId, token.SessionId, token.Type, token.Token, token.CreatedAt, token.ExpiresAt)
@@ -258,11 +190,11 @@ func (db AuthSqlite) InsertToken(token *Token) error {
return nil return nil
} }
func (db AuthSqlite) GetToken(token string) (*Token, error) { func (db AuthSqlite) GetToken(token string) (*types.Token, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
sessionId string sessionId string
tokenType TokenType tokenType types.TokenType
createdAtStr string createdAtStr string
expiresAtStr string expiresAtStr string
createdAt time.Time createdAt time.Time
@@ -296,10 +228,10 @@ func (db AuthSqlite) GetToken(token string) (*Token, error) {
return nil, types.ErrInternal return nil, types.ErrInternal
} }
return NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil return types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt), nil
} }
func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType TokenType) ([]*Token, error) { func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(` query, err := db.db.Query(`
SELECT token, created_at, expires_at SELECT token, created_at, expires_at
@@ -315,7 +247,7 @@ func (db AuthSqlite) GetTokensByUserIdAndType(userId uuid.UUID, tokenType TokenT
return getTokensFromQuery(query, userId, "", tokenType) return getTokensFromQuery(query, userId, "", tokenType)
} }
func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType TokenType) ([]*Token, error) { func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
query, err := db.db.Query(` query, err := db.db.Query(`
SELECT token, created_at, expires_at SELECT token, created_at, expires_at
@@ -331,8 +263,8 @@ func (db AuthSqlite) GetTokensBySessionIdAndType(sessionId string, tokenType Tok
return getTokensFromQuery(query, uuid.Nil, sessionId, tokenType) return getTokensFromQuery(query, uuid.Nil, sessionId, tokenType)
} }
func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tokenType TokenType) ([]*Token, error) { func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tokenType types.TokenType) ([]*types.Token, error) {
var tokens []*Token var tokens []*types.Token
hasRows := false hasRows := false
for query.Next() { for query.Next() {
@@ -364,7 +296,7 @@ func getTokensFromQuery(query *sql.Rows, userId uuid.UUID, sessionId string, tok
return nil, types.ErrInternal return nil, types.ErrInternal
} }
tokens = append(tokens, NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt)) tokens = append(tokens, types.NewToken(userId, sessionId, token, tokenType, createdAt, expiresAt))
} }
if !hasRows { if !hasRows {
@@ -383,7 +315,7 @@ func (db AuthSqlite) DeleteToken(token string) error {
return nil return nil
} }
func (db AuthSqlite) InsertSession(session *Session) error { func (db AuthSqlite) InsertSession(session *types.Session) error {
_, err := db.db.Exec(` _, err := db.db.Exec(`
INSERT INTO session (session_id, user_id, created_at, expires_at) INSERT INTO session (session_id, user_id, created_at, expires_at)
@@ -397,7 +329,7 @@ func (db AuthSqlite) InsertSession(session *Session) error {
return nil return nil
} }
func (db AuthSqlite) GetSession(sessionId string) (*Session, error) { func (db AuthSqlite) GetSession(sessionId string) (*types.Session, error) {
var ( var (
userId uuid.UUID userId uuid.UUID
@@ -414,10 +346,10 @@ func (db AuthSqlite) GetSession(sessionId string) (*Session, error) {
return nil, ErrNotFound return nil, ErrNotFound
} }
return NewSession(sessionId, userId, createdAt, expiresAt), nil return types.NewSession(sessionId, userId, createdAt, expiresAt), nil
} }
func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*Session, error) { func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*types.Session, error) {
sessions, err := db.db.Query(` sessions, err := db.db.Query(`
SELECT session_id, created_at, expires_at SELECT session_id, created_at, expires_at
@@ -428,7 +360,7 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*Session, error) {
return nil, types.ErrInternal return nil, types.ErrInternal
} }
var result []*Session var result []*types.Session
for sessions.Next() { for sessions.Next() {
var ( var (
@@ -443,7 +375,7 @@ func (db AuthSqlite) GetSessions(userId uuid.UUID) ([]*Session, error) {
return nil, types.ErrInternal return nil, types.ErrInternal
} }
session := NewSession(sessionId, userId, createdAt, expiresAt) session := types.NewSession(sessionId, userId, createdAt, expiresAt)
result = append(result, session) result = append(result, session)
} }

View File

@@ -38,7 +38,7 @@ func TestUser(t *testing.T) {
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expected := NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) expected := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(expected) err := underTest.InsertUser(expected)
assert.Nil(t, err) assert.Nil(t, err)
@@ -68,7 +68,7 @@ func TestUser(t *testing.T) {
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(user)
assert.Nil(t, err) assert.Nil(t, err)
@@ -83,7 +83,7 @@ func TestUser(t *testing.T) {
underTest := AuthSqlite{db: db} underTest := AuthSqlite{db: db}
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(user)
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, types.ErrInternal, err)
@@ -101,7 +101,7 @@ func TestToken(t *testing.T) {
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
expected := NewToken(uuid.New(), "sessionId", "token", TokenTypeCsrf, createAt, expiresAt) expected := types.NewToken(uuid.New(), "sessionId", "token", types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(expected) err := underTest.InsertToken(expected)
assert.Nil(t, err) assert.Nil(t, err)
@@ -113,13 +113,13 @@ func TestToken(t *testing.T) {
expected.SessionId = "" expected.SessionId = ""
actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type) actuals, err := underTest.GetTokensByUserIdAndType(expected.UserId, expected.Type)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []*Token{expected}, actuals) assert.Equal(t, []*types.Token{expected}, actuals)
expected.SessionId = "sessionId" expected.SessionId = "sessionId"
expected.UserId = uuid.Nil expected.UserId = uuid.Nil
actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type) actuals, err = underTest.GetTokensBySessionIdAndType(expected.SessionId, expected.Type)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []*Token{expected}, actuals) assert.Equal(t, []*types.Token{expected}, actuals)
}) })
t.Run("should insert and return multiple tokens", func(t *testing.T) { t.Run("should insert and return multiple tokens", func(t *testing.T) {
t.Parallel() t.Parallel()
@@ -130,8 +130,8 @@ func TestToken(t *testing.T) {
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
userId := uuid.New() userId := uuid.New()
expected1 := NewToken(userId, "sessionId", "token1", TokenTypeCsrf, createAt, expiresAt) expected1 := types.NewToken(userId, "sessionId", "token1", types.TokenTypeCsrf, createAt, expiresAt)
expected2 := NewToken(userId, "sessionId", "token2", TokenTypeCsrf, createAt, expiresAt) expected2 := types.NewToken(userId, "sessionId", "token2", types.TokenTypeCsrf, createAt, expiresAt)
err := underTest.InsertToken(expected1) err := underTest.InsertToken(expected1)
assert.Nil(t, err) assert.Nil(t, err)
@@ -142,7 +142,7 @@ func TestToken(t *testing.T) {
expected2.UserId = uuid.Nil expected2.UserId = uuid.Nil
actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type) actuals, err := underTest.GetTokensBySessionIdAndType(expected1.SessionId, expected1.Type)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []*Token{expected1, expected2}, actuals) assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
expected1.SessionId = "" expected1.SessionId = ""
expected2.SessionId = "" expected2.SessionId = ""
@@ -150,7 +150,7 @@ func TestToken(t *testing.T) {
expected2.UserId = userId expected2.UserId = userId
actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type) actuals, err = underTest.GetTokensByUserIdAndType(userId, expected1.Type)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, []*Token{expected1, expected2}, actuals) assert.Equal(t, []*types.Token{expected1, expected2}, actuals)
}) })
t.Run("should return ErrNotFound", func(t *testing.T) { t.Run("should return ErrNotFound", func(t *testing.T) {
@@ -162,10 +162,10 @@ func TestToken(t *testing.T) {
_, err := underTest.GetToken("nonExistent") _, err := underTest.GetToken("nonExistent")
assert.Equal(t, ErrNotFound, err) assert.Equal(t, ErrNotFound, err)
_, err = underTest.GetTokensByUserIdAndType(uuid.New(), TokenTypeEmailVerify) _, err = underTest.GetTokensByUserIdAndType(uuid.New(), types.TokenTypeEmailVerify)
assert.Equal(t, ErrNotFound, err) assert.Equal(t, ErrNotFound, err)
_, err = underTest.GetTokensBySessionIdAndType("sessionId", TokenTypeEmailVerify) _, err = underTest.GetTokensBySessionIdAndType("sessionId", types.TokenTypeEmailVerify)
assert.Equal(t, ErrNotFound, err) assert.Equal(t, ErrNotFound, err)
}) })
t.Run("should return ErrAlreadyExists", func(t *testing.T) { t.Run("should return ErrAlreadyExists", func(t *testing.T) {
@@ -176,7 +176,7 @@ func TestToken(t *testing.T) {
verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 5, 13, 0, 0, 0, time.UTC)
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt) user := types.NewUser(uuid.New(), "some@email.de", true, &verifiedAt, false, []byte("somePass"), []byte("someSalt"), createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(user)
assert.Nil(t, err) assert.Nil(t, err)
@@ -191,7 +191,7 @@ func TestToken(t *testing.T) {
underTest := AuthSqlite{db: db} underTest := AuthSqlite{db: db}
createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC) createAt := time.Date(2020, 1, 5, 12, 0, 0, 0, time.UTC)
user := NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt) user := types.NewUser(uuid.New(), "some@email.de", false, nil, false, []byte("somePass"), nil, createAt)
err := underTest.InsertUser(user) err := underTest.InsertUser(user)
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, types.ErrInternal, err)

View File

@@ -77,11 +77,11 @@ func (handler AuthImpl) handleSignInPage() http.HandlerFunc {
func (handler AuthImpl) handleSignIn() http.HandlerFunc { func (handler AuthImpl) handleSignIn() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*service.User, error) { user, err := utils.WaitMinimumTime(securityWaitDuration, func() (*types.User, error) {
var email = r.FormValue("email") var email = r.FormValue("email")
var password = r.FormValue("password") var password = r.FormValue("password")
session, err := handler.service.SignIn(email, password) session, user, err := handler.service.SignIn(email, password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -89,7 +89,7 @@ func (handler AuthImpl) handleSignIn() http.HandlerFunc {
cookie := middleware.CreateSessionCookie(session.Id) cookie := middleware.CreateSessionCookie(session.Id)
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)
return session.User, nil return user, nil
}) })
if err != nil { if err != nil {
@@ -294,7 +294,8 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := middleware.GetSession(r) session := middleware.GetSession(r)
if session == nil || session.User == nil { user := middleware.GetUser(r)
if session == nil || user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return
} }
@@ -302,7 +303,7 @@ func (handler AuthImpl) handleChangePasswordComp() http.HandlerFunc {
currPass := r.FormValue("current-password") currPass := r.FormValue("current-password")
newPass := r.FormValue("new-password") newPass := r.FormValue("new-password")
err := handler.service.ChangePassword(session, currPass, newPass) err := handler.service.ChangePassword(user, session.Id, currPass, newPass)
if err != nil { if err != nil {
utils.TriggerToast(w, r, "error", "Password not correct", http.StatusUnauthorized) utils.TriggerToast(w, r, "error", "Password not correct", http.StatusUnauthorized)
return return

View File

@@ -32,11 +32,7 @@ func (handler IndexImpl) Handle(router *http.ServeMux) {
func (handler IndexImpl) handleIndexAnd404() http.HandlerFunc { func (handler IndexImpl) handleIndexAnd404() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := middleware.GetSession(r) user := middleware.GetUser(r)
var user *service.User
if session != nil {
user = session.User
}
var comp templ.Component var comp templ.Component

View File

@@ -2,24 +2,29 @@ package middleware
import ( import (
"context" "context"
"net/http"
"me-fit/service" "me-fit/service"
"me-fit/types"
"net/http"
) )
type ContextKey string type ContextKey string
var SessionKey ContextKey = "session" var SessionKey ContextKey = "session"
var UserKey ContextKey = "user"
func Authenticate(service service.Auth) func(http.Handler) http.Handler { func Authenticate(service service.Auth) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
sessionId := getSessionID(r) sessionId := getSessionID(r)
session, _ := service.SignInSession(sessionId) session, user, _ := service.SignInSession(sessionId)
if session != nil { if session != nil {
ctx := context.WithValue(r.Context(), SessionKey, session)
ctx = context.WithValue(ctx, UserKey, user)
ctx = context.WithValue(ctx, SessionKey, session)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} else { } else {
@@ -29,23 +34,22 @@ func Authenticate(service service.Auth) func(http.Handler) http.Handler {
} }
} }
func GetUser(r *http.Request) *service.User { func GetUser(r *http.Request) *types.User {
obj := r.Context().Value(UserKey)
session := GetSession(r) if obj == nil {
if session == nil {
return nil return nil
} }
return session.User return obj.(*types.User)
} }
func GetSession(r *http.Request) *service.Session { func GetSession(r *http.Request) *types.Session {
obj := r.Context().Value(SessionKey) obj := r.Context().Value(SessionKey)
if obj == nil { if obj == nil {
return nil return nil
} }
return obj.(*service.Session) return obj.(*types.Session)
} }
func getSessionID(r *http.Request) string { func getSessionID(r *http.Request) string {

View File

@@ -6,15 +6,16 @@ import (
"strings" "strings"
"me-fit/service" "me-fit/service"
"me-fit/types"
) )
type csrfResponseWriter struct { type csrfResponseWriter struct {
http.ResponseWriter http.ResponseWriter
auth service.Auth auth service.Auth
session *service.Session session *types.Session
} }
func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *service.Session) *csrfResponseWriter { func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *types.Session) *csrfResponseWriter {
return &csrfResponseWriter{ return &csrfResponseWriter{
ResponseWriter: w, ResponseWriter: w,
auth: auth, auth: auth,

View File

@@ -2,7 +2,6 @@ package handler
import ( import (
"me-fit/log" "me-fit/log"
"me-fit/service"
"me-fit/template" "me-fit/template"
"me-fit/template/auth" "me-fit/template/auth"
"me-fit/types" "me-fit/types"
@@ -31,14 +30,14 @@ func (render *Render) Render(r *http.Request, w http.ResponseWriter, comp templ.
} }
} }
func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *service.User) { func (render *Render) RenderLayout(r *http.Request, w http.ResponseWriter, slot templ.Component, user *types.User) {
userComp := render.getUserComp(user) userComp := render.getUserComp(user)
layout := template.Layout(slot, userComp, render.settings.Environment) layout := template.Layout(slot, userComp, render.settings.Environment)
render.Render(r, w, layout) render.Render(r, w, layout)
} }
func (render *Render) getUserComp(user *service.User) templ.Component { func (render *Render) getUserComp(user *types.User) templ.Component {
if user != nil { if user != nil {
return auth.UserComp(user.Email) return auth.UserComp(user.Email)

View File

@@ -38,22 +38,22 @@ func (handler WorkoutImpl) Handle(router *http.ServeMux) {
func (handler WorkoutImpl) handleWorkoutPage() http.HandlerFunc { func (handler WorkoutImpl) handleWorkoutPage() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := middleware.GetSession(r) user := middleware.GetUser(r)
if session == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return
} }
currentDate := time.Now().Format("2006-01-02") currentDate := time.Now().Format("2006-01-02")
comp := workout.WorkoutComp(currentDate) comp := workout.WorkoutComp(currentDate)
handler.render.RenderLayout(r, w, comp, session.User) handler.render.RenderLayout(r, w, comp, user)
} }
} }
func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc { func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := middleware.GetSession(r) user := middleware.GetUser(r)
if session == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return
} }
@@ -64,7 +64,7 @@ func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc {
var repsStr = r.FormValue("reps") var repsStr = r.FormValue("reps")
wo := service.NewWorkoutDto("", dateStr, typeStr, setsStr, repsStr) wo := service.NewWorkoutDto("", dateStr, typeStr, setsStr, repsStr)
wo, err := handler.service.AddWorkout(session.User, wo) wo, err := handler.service.AddWorkout(user, wo)
if err != nil { if err != nil {
utils.TriggerToast(w, r, "error", "Invalid input values", http.StatusBadRequest) utils.TriggerToast(w, r, "error", "Invalid input values", http.StatusBadRequest)
http.Error(w, "Invalid input values", http.StatusBadRequest) http.Error(w, "Invalid input values", http.StatusBadRequest)
@@ -79,13 +79,13 @@ func (handler WorkoutImpl) handleAddWorkout() http.HandlerFunc {
func (handler WorkoutImpl) handleGetWorkout() http.HandlerFunc { func (handler WorkoutImpl) handleGetWorkout() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := middleware.GetSession(r) user := middleware.GetUser(r)
if session == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return
} }
workouts, err := handler.service.GetWorkouts(session.User) workouts, err := handler.service.GetWorkouts(user)
if err != nil { if err != nil {
return return
} }
@@ -102,8 +102,8 @@ func (handler WorkoutImpl) handleGetWorkout() http.HandlerFunc {
func (handler WorkoutImpl) handleDeleteWorkout() http.HandlerFunc { func (handler WorkoutImpl) handleDeleteWorkout() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := middleware.GetSession(r) user := middleware.GetUser(r)
if session == nil { if user == nil {
utils.DoRedirect(w, r, "/auth/signin") utils.DoRedirect(w, r, "/auth/signin")
return return
} }
@@ -120,7 +120,7 @@ func (handler WorkoutImpl) handleDeleteWorkout() http.HandlerFunc {
return return
} }
err = handler.service.DeleteWorkout(session.User, rowIdInt) err = handler.service.DeleteWorkout(user, rowIdInt)
if err != nil { if err != nil {
utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError) utils.TriggerToast(w, r, "error", "Internal Server Error", http.StatusInternalServerError)
return return

View File

@@ -11,7 +11,6 @@ import (
"testing" "testing"
"time" "time"
"me-fit/db"
"me-fit/service" "me-fit/service"
"me-fit/types" "me-fit/types"
@@ -271,7 +270,7 @@ func TestIntegrationAuth(t *testing.T) {
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
var token string var token string
err = d.QueryRow("SELECT token FROM token WHERE type = ?", db.TokenTypePasswordReset).Scan(&token) err = d.QueryRow("SELECT token FROM token WHERE type = ?", types.TokenTypePasswordReset).Scan(&token)
assert.Nil(t, err) assert.Nil(t, err)
formData = url.Values{ formData = url.Values{

View File

@@ -26,55 +26,25 @@ var (
ErrTokenInvalid = errors.New("token is invalid") ErrTokenInvalid = errors.New("token is invalid")
) )
type User struct {
Id uuid.UUID
Email string
EmailVerified bool
}
func NewUser(user *db.User) *User {
return &User{
Id: user.Id,
Email: user.Email,
EmailVerified: user.EmailVerified,
}
}
type Session struct {
Id string
CreatedAt time.Time
ExpiresAt time.Time
User *User
}
func NewSession(session *db.Session, user *User) *Session {
return &Session{
Id: session.Id,
CreatedAt: session.CreatedAt,
ExpiresAt: session.ExpiresAt,
User: user,
}
}
type Auth interface { type Auth interface {
SignUp(email string, password string) (*User, error) SignUp(email string, password string) (*types.User, error)
SendVerificationMail(userId uuid.UUID, email string) SendVerificationMail(userId uuid.UUID, email string)
VerifyUserEmail(token string) error VerifyUserEmail(token string) error
SignIn(email string, password string) (*Session, error) SignIn(email string, password string) (*types.Session, *types.User, error)
SignInSession(sessionId string) (*Session, error) SignInSession(sessionId string) (*types.Session, *types.User, error)
SignInAnonymous() (*Session, error) SignInAnonymous() (*types.Session, error)
SignOut(sessionId string) error SignOut(sessionId string) error
DeleteAccount(user *User, currPass string) error DeleteAccount(user *types.User, currPass string) error
ChangePassword(session *Session, currPass, newPass string) error ChangePassword(user *types.User, sessionId string, currPass, newPass string) error
SendForgotPasswordMail(email string) error SendForgotPasswordMail(email string) error
ForgotPassword(token string, newPass string) error ForgotPassword(token string, newPass string) error
IsCsrfTokenValid(tokenStr string, sessionId string) bool IsCsrfTokenValid(tokenStr string, sessionId string) bool
GetCsrfToken(session *Session) (string, error) GetCsrfToken(session *types.Session) (string, error)
} }
type AuthImpl struct { type AuthImpl struct {
@@ -95,69 +65,65 @@ func NewAuthImpl(db db.Auth, random Random, clock Clock, mail Mail, serverSettin
} }
} }
func (service AuthImpl) SignIn(email string, password string) (*Session, error) { func (service AuthImpl) SignIn(email string, password string) (*types.Session, *types.User, error) {
user, err := service.db.GetUserByEmail(email) user, err := service.db.GetUserByEmail(email)
if err != nil { if err != nil {
if errors.Is(err, db.ErrNotFound) { if errors.Is(err, db.ErrNotFound) {
return nil, ErrInvalidCredentials return nil, nil, ErrInvalidCredentials
} else { } else {
return nil, types.ErrInternal return nil, nil, types.ErrInternal
} }
} }
hash := GetHashPassword(password, user.Salt) hash := GetHashPassword(password, user.Salt)
if subtle.ConstantTimeCompare(hash, user.Password) == 0 { if subtle.ConstantTimeCompare(hash, user.Password) == 0 {
return nil, ErrInvalidCredentials return nil, nil, ErrInvalidCredentials
} }
session, err := service.createSession(user.Id) session, err := service.createSession(user.Id)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, nil, types.ErrInternal
} }
return NewSession(session, NewUser(user)), nil return session, user, nil
} }
func (service AuthImpl) SignInSession(sessionId string) (*Session, error) { func (service AuthImpl) SignInSession(sessionId string) (*types.Session, *types.User, error) {
if sessionId == "" { if sessionId == "" {
return nil, ErrSessionIdInvalid return nil, nil, ErrSessionIdInvalid
} }
sessionDb, err := service.db.GetSession(sessionId) session, err := service.db.GetSession(sessionId)
if err != nil {
return nil, nil, types.ErrInternal
}
if session.ExpiresAt.Before(service.clock.Now()) {
return nil, nil, nil
}
if session.UserId == uuid.Nil {
return session, nil, nil
}
user, err := service.db.GetUser(session.UserId)
if err != nil {
return nil, nil, types.ErrInternal
}
return session, user, nil
}
func (service AuthImpl) SignInAnonymous() (*types.Session, error) {
session, err := service.createSession(uuid.Nil)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, types.ErrInternal
} }
if sessionDb.ExpiresAt.Before(service.clock.Now()) {
return nil, nil
}
if sessionDb.UserId == uuid.Nil {
return NewSession(sessionDb, nil), nil
}
userDb, err := service.db.GetUser(sessionDb.UserId)
if err != nil {
return nil, types.ErrInternal
}
user := NewUser(userDb)
session := NewSession(sessionDb, user)
return session, nil return session, nil
} }
func (service AuthImpl) SignInAnonymous() (*Session, error) { func (service AuthImpl) createSession(userId uuid.UUID) (*types.Session, error) {
sessionDb, err := service.createSession(uuid.Nil)
if err != nil {
return nil, types.ErrInternal
}
return NewSession(sessionDb, nil), nil
}
func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) {
sessionId, err := service.random.String(32) sessionId, err := service.random.String(32)
if err != nil { if err != nil {
return nil, types.ErrInternal return nil, types.ErrInternal
@@ -172,7 +138,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) {
createAt := service.clock.Now() createAt := service.clock.Now()
expiresAt := createAt.Add(24 * time.Hour) expiresAt := createAt.Add(24 * time.Hour)
session := db.NewSession(sessionId, userId, createAt, expiresAt) session := types.NewSession(sessionId, userId, createAt, expiresAt)
err = service.db.InsertSession(session) err = service.db.InsertSession(session)
if err != nil { if err != nil {
@@ -182,7 +148,7 @@ func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) {
return session, nil return session, nil
} }
func (service AuthImpl) SignUp(email string, password string) (*User, error) { func (service AuthImpl) SignUp(email string, password string) (*types.User, error) {
_, err := mail.ParseAddress(email) _, err := mail.ParseAddress(email)
if err != nil { if err != nil {
return nil, ErrInvalidEmail return nil, ErrInvalidEmail
@@ -204,9 +170,9 @@ func (service AuthImpl) SignUp(email string, password string) (*User, error) {
hash := GetHashPassword(password, salt) hash := GetHashPassword(password, salt)
dbUser := db.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now()) user := types.NewUser(userId, email, false, nil, false, hash, salt, service.clock.Now())
err = service.db.InsertUser(dbUser) err = service.db.InsertUser(user)
if err != nil { if err != nil {
if err == db.ErrAlreadyExists { if err == db.ErrAlreadyExists {
return nil, ErrAccountExists return nil, ErrAccountExists
@@ -215,17 +181,17 @@ func (service AuthImpl) SignUp(email string, password string) (*User, error) {
} }
} }
return NewUser(dbUser), nil return user, nil
} }
func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) { func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
tokens, err := service.db.GetTokensByUserIdAndType(userId, db.TokenTypeEmailVerify) tokens, err := service.db.GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify)
if err != nil && err != db.ErrNotFound { if err != nil && err != db.ErrNotFound {
return return
} }
var token *db.Token var token *types.Token
if len(tokens) > 0 { if len(tokens) > 0 {
token = tokens[0] token = tokens[0]
@@ -237,7 +203,7 @@ func (service AuthImpl) SendVerificationMail(userId uuid.UUID, email string) {
return return
} }
token = db.NewToken(userId, "", newTokenStr, db.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) token = types.NewToken(userId, "", newTokenStr, types.TokenTypeEmailVerify, service.clock.Now(), service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token) err = service.db.InsertToken(token)
if err != nil { if err != nil {
@@ -271,7 +237,7 @@ func (service AuthImpl) VerifyUserEmail(tokenStr string) error {
return types.ErrInternal return types.ErrInternal
} }
if token.Type != db.TokenTypeEmailVerify { if token.Type != types.TokenTypeEmailVerify {
return types.ErrInternal return types.ErrInternal
} }
@@ -298,7 +264,7 @@ func (service AuthImpl) SignOut(sessionId string) error {
return service.db.DeleteSession(sessionId) return service.db.DeleteSession(sessionId)
} }
func (service AuthImpl) DeleteAccount(user *User, currPass string) error { func (service AuthImpl) DeleteAccount(user *types.User, currPass string) error {
userDb, err := service.db.GetUser(user.Id) userDb, err := service.db.GetUser(user.Id)
if err != nil { if err != nil {
@@ -320,7 +286,7 @@ func (service AuthImpl) DeleteAccount(user *User, currPass string) error {
return nil return nil
} }
func (service AuthImpl) ChangePassword(session *Session, currPass, newPass string) error { func (service AuthImpl) ChangePassword(user *types.User, sessionId string, currPass, newPass string) error {
if !isPasswordValid(newPass) { if !isPasswordValid(newPass) {
return ErrInvalidPassword return ErrInvalidPassword
@@ -330,31 +296,26 @@ func (service AuthImpl) ChangePassword(session *Session, currPass, newPass strin
return ErrInvalidPassword return ErrInvalidPassword
} }
userDb, err := service.db.GetUser(session.User.Id) currHash := GetHashPassword(currPass, user.Salt)
if err != nil {
return err
}
currHash := GetHashPassword(currPass, userDb.Salt) if subtle.ConstantTimeCompare(currHash, user.Password) == 0 {
if subtle.ConstantTimeCompare(currHash, userDb.Password) == 0 {
return ErrInvalidCredentials return ErrInvalidCredentials
} }
newHash := GetHashPassword(newPass, userDb.Salt) newHash := GetHashPassword(newPass, user.Salt)
userDb.Password = newHash user.Password = newHash
err = service.db.UpdateUser(userDb) err := service.db.UpdateUser(user)
if err != nil { if err != nil {
return err return err
} }
sessions, err := service.db.GetSessions(userDb.Id) sessions, err := service.db.GetSessions(user.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
} }
for _, s := range sessions { for _, s := range sessions {
if s.Id != session.Id { if s.Id != sessionId {
err = service.db.DeleteSession(s.Id) err = service.db.DeleteSession(s.Id)
if err != nil { if err != nil {
return types.ErrInternal return types.ErrInternal
@@ -380,7 +341,7 @@ func (service AuthImpl) SendForgotPasswordMail(email string) error {
} }
} }
token := db.NewToken(user.Id, "", tokenStr, db.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute)) token := types.NewToken(user.Id, "", tokenStr, types.TokenTypePasswordReset, service.clock.Now(), service.clock.Now().Add(15*time.Minute))
err = service.db.InsertToken(token) err = service.db.InsertToken(token)
if err != nil { if err != nil {
@@ -414,7 +375,7 @@ func (service AuthImpl) ForgotPassword(tokenStr string, newPass string) error {
return err return err
} }
if token.Type != db.TokenTypePasswordReset || if token.Type != types.TokenTypePasswordReset ||
token.ExpiresAt.Before(service.clock.Now()) { token.ExpiresAt.Before(service.clock.Now()) {
return ErrTokenInvalid return ErrTokenInvalid
} }
@@ -454,7 +415,7 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
return false return false
} }
if token.Type != db.TokenTypeCsrf || if token.Type != types.TokenTypeCsrf ||
token.SessionId != sessionId || token.SessionId != sessionId ||
token.ExpiresAt.Before(service.clock.Now()) { token.ExpiresAt.Before(service.clock.Now()) {
@@ -464,12 +425,12 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool
return true return true
} }
func (service AuthImpl) GetCsrfToken(session *Session) (string, error) { func (service AuthImpl) GetCsrfToken(session *types.Session) (string, error) {
if session == nil { if session == nil {
return "", types.ErrInternal return "", types.ErrInternal
} }
tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, db.TokenTypeCsrf) tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, types.TokenTypeCsrf)
if len(tokens) > 0 { if len(tokens) > 0 {
return tokens[0].Token, nil return tokens[0].Token, nil
@@ -480,7 +441,7 @@ func (service AuthImpl) GetCsrfToken(session *Session) (string, error) {
return "", types.ErrInternal return "", types.ErrInternal
} }
token := db.NewToken(uuid.Nil, session.Id, tokenStr, db.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(24*time.Hour)) token := types.NewToken(uuid.Nil, session.Id, tokenStr, types.TokenTypeCsrf, service.clock.Now(), service.clock.Now().Add(24*time.Hour))
err = service.db.InsertToken(token) err = service.db.InsertToken(token)
if err != nil { if err != nil {
return "", types.ErrInternal return "", types.ErrInternal

View File

@@ -22,7 +22,7 @@ func TestSignIn(t *testing.T) {
t.Parallel() t.Parallel()
salt := []byte("salt") salt := []byte("salt")
verifiedAt := time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)
user := db.NewUser( user := types.NewUser(
uuid.New(), uuid.New(),
"test@test.de", "test@test.de",
true, true,
@@ -33,12 +33,12 @@ func TestSignIn(t *testing.T) {
time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
) )
dbSession := db.NewSession("sessionId", user.Id, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) session := types.NewSession("sessionId", user.Id, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
mockAuthDb := mocks.NewMockAuth(t) mockAuthDb := mocks.NewMockAuth(t)
mockAuthDb.EXPECT().GetUserByEmail("test@test.de").Return(user, nil) mockAuthDb.EXPECT().GetUserByEmail("test@test.de").Return(user, nil)
mockAuthDb.EXPECT().DeleteOldSessions(user.Id).Return(nil) mockAuthDb.EXPECT().DeleteOldSessions(user.Id).Return(nil)
mockAuthDb.EXPECT().InsertSession(dbSession).Return(nil) mockAuthDb.EXPECT().InsertSession(session).Return(nil)
mockRandom := mocks.NewMockRandom(t) mockRandom := mocks.NewMockRandom(t)
mockRandom.EXPECT().String(32).Return("sessionId", nil) mockRandom.EXPECT().String(32).Return("sessionId", nil)
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
@@ -47,11 +47,11 @@ func TestSignIn(t *testing.T) {
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
actualSession, err := underTest.SignIn(user.Email, "password") actualSession, actualUser, err := underTest.SignIn(user.Email, "password")
assert.Nil(t, err) assert.Nil(t, err)
expectedSession := NewSession(dbSession, NewUser(user)) assert.Equal(t, session, actualSession)
assert.Equal(t, expectedSession, actualSession) assert.Equal(t, user, actualUser)
}) })
t.Run("should return ErrInvalidCretentials if password is not correct", func(t *testing.T) { t.Run("should return ErrInvalidCretentials if password is not correct", func(t *testing.T) {
@@ -59,7 +59,7 @@ func TestSignIn(t *testing.T) {
salt := []byte("salt") salt := []byte("salt")
verifiedAt := time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC) verifiedAt := time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)
user := db.NewUser( user := types.NewUser(
uuid.New(), uuid.New(),
"test@test.de", "test@test.de",
true, true,
@@ -78,7 +78,7 @@ func TestSignIn(t *testing.T) {
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
_, err := underTest.SignIn("test@test.de", "wrong password") _, _, err := underTest.SignIn("test@test.de", "wrong password")
assert.Equal(t, ErrInvalidCredentials, err) assert.Equal(t, ErrInvalidCredentials, err)
}) })
@@ -93,7 +93,7 @@ func TestSignIn(t *testing.T) {
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
_, err := underTest.SignIn("test", "test") _, _, err := underTest.SignIn("test", "test")
assert.Equal(t, ErrInvalidCredentials, err) assert.Equal(t, ErrInvalidCredentials, err)
}) })
t.Run("should forward ErrInternal on any other error", func(t *testing.T) { t.Run("should forward ErrInternal on any other error", func(t *testing.T) {
@@ -107,7 +107,7 @@ func TestSignIn(t *testing.T) {
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
_, err := underTest.SignIn("test", "test") _, _, err := underTest.SignIn("test", "test")
assert.Equal(t, types.ErrInternal, err) assert.Equal(t, types.ErrInternal, err)
}) })
@@ -159,33 +159,25 @@ func TestSignUp(t *testing.T) {
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
expected := User{ userId := uuid.New()
Id: uuid.New(), email := "mail@mail.de"
Email: "some@valid.email",
EmailVerified: false,
}
random := NewRandomImpl()
salt, err := random.Bytes(16)
assert.Nil(t, err)
password := "SomeStrongPassword123!" password := "SomeStrongPassword123!"
salt := []byte("salt")
mockRandom.EXPECT().UUID().Return(expected.Id, nil)
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
mockClock.EXPECT().Now().Return(createTime) expected := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime)
mockAuthDb.EXPECT().InsertUser(db.NewUser(expected.Id, expected.Email, false, nil, false, GetHashPassword(password, salt), salt, createTime)).Return(nil) mockRandom.EXPECT().UUID().Return(userId, nil)
mockRandom.EXPECT().Bytes(16).Return(salt, nil)
mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(expected).Return(nil)
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
actual, err := underTest.SignUp(email, password)
actual, err := underTest.SignUp(expected.Email, password)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, expected, *actual) assert.Equal(t, expected, actual)
}) })
t.Run("should return ErrAccountExists", func(t *testing.T) { t.Run("should return ErrAccountExists", func(t *testing.T) {
t.Parallel() t.Parallel()
@@ -195,28 +187,22 @@ func TestSignUp(t *testing.T) {
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
user := User{ userId := uuid.New()
Id: uuid.New(), email := "some@valid.email"
Email: "some@valid.email", createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
}
random := NewRandomImpl()
salt, err := random.Bytes(16)
assert.Nil(t, err)
password := "SomeStrongPassword123!" password := "SomeStrongPassword123!"
salt := []byte("salt")
user := types.NewUser(userId, email, false, nil, false, GetHashPassword(password, salt), salt, createTime)
mockRandom.EXPECT().UUID().Return(user.Id, nil) mockRandom.EXPECT().UUID().Return(user.Id, nil)
mockRandom.EXPECT().Bytes(16).Return(salt, nil) mockRandom.EXPECT().Bytes(16).Return(salt, nil)
createTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
mockClock.EXPECT().Now().Return(createTime) mockClock.EXPECT().Now().Return(createTime)
mockAuthDb.EXPECT().InsertUser(db.NewUser(user.Id, user.Email, false, nil, false, GetHashPassword(password, salt), salt, createTime)).Return(db.ErrAlreadyExists) mockAuthDb.EXPECT().InsertUser(user).Return(db.ErrAlreadyExists)
underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{}) underTest := NewAuthImpl(mockAuthDb, mockRandom, mockClock, mockMail, &types.Settings{})
_, err = underTest.SignUp(user.Email, password) _, err := underTest.SignUp(user.Email, password)
assert.Equal(t, ErrAccountExists, err) assert.Equal(t, ErrAccountExists, err)
}) })
} }
@@ -227,8 +213,8 @@ func TestSendVerificationMail(t *testing.T) {
t.Run("should use stored token and send mail", func(t *testing.T) { t.Run("should use stored token and send mail", func(t *testing.T) {
t.Parallel() t.Parallel()
token := db.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", db.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC)) token := types.NewToken(uuid.New(), "sessionId", "someRandomTokenToUse", types.TokenTypeEmailVerify, time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2020, 1, 2, 0, 0, 0, 0, time.UTC))
tokens := []*db.Token{token} tokens := []*types.Token{token}
email := "some@email.de" email := "some@email.de"
userId := uuid.New() userId := uuid.New()
@@ -238,7 +224,7 @@ func TestSendVerificationMail(t *testing.T) {
mockClock := mocks.NewMockClock(t) mockClock := mocks.NewMockClock(t)
mockMail := mocks.NewMockMail(t) mockMail := mocks.NewMockMail(t)
mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, db.TokenTypeEmailVerify).Return(tokens, nil) mockAuthDb.EXPECT().GetTokensByUserIdAndType(userId, types.TokenTypeEmailVerify).Return(tokens, nil)
mockMail.EXPECT().SendMail(email, "Welcome to ME-FIT", mock.MatchedBy(func(message string) bool { mockMail.EXPECT().SendMail(email, "Welcome to ME-FIT", mock.MatchedBy(func(message string) bool {
return strings.Contains(message, token.Token) return strings.Contains(message, token.Token)

View File

@@ -10,9 +10,9 @@ import (
) )
type Workout interface { type Workout interface {
AddWorkout(user *User, workoutDto *WorkoutDto) (*WorkoutDto, error) AddWorkout(user *types.User, workoutDto *WorkoutDto) (*WorkoutDto, error)
DeleteWorkout(user *User, rowId int) error DeleteWorkout(user *types.User, rowId int) error
GetWorkouts(user *User) ([]*WorkoutDto, error) GetWorkouts(user *types.User) ([]*WorkoutDto, error)
} }
type WorkoutImpl struct { type WorkoutImpl struct {
@@ -64,7 +64,7 @@ var (
ErrInputValues = errors.New("invalid input values") ErrInputValues = errors.New("invalid input values")
) )
func (service WorkoutImpl) AddWorkout(user *User, workoutDto *WorkoutDto) (*WorkoutDto, error) { func (service WorkoutImpl) AddWorkout(user *types.User, workoutDto *WorkoutDto) (*WorkoutDto, error) {
if workoutDto.Date == "" || workoutDto.Type == "" || workoutDto.Sets == "" || workoutDto.Reps == "" { if workoutDto.Date == "" || workoutDto.Type == "" || workoutDto.Sets == "" || workoutDto.Reps == "" {
return nil, ErrInputValues return nil, ErrInputValues
@@ -95,7 +95,7 @@ func (service WorkoutImpl) AddWorkout(user *User, workoutDto *WorkoutDto) (*Work
return NewWorkoutDtoFromDb(workout), nil return NewWorkoutDtoFromDb(workout), nil
} }
func (service WorkoutImpl) DeleteWorkout(user *User, rowId int) error { func (service WorkoutImpl) DeleteWorkout(user *types.User, rowId int) error {
if user == nil { if user == nil {
return types.ErrInternal return types.ErrInternal
} }
@@ -103,7 +103,7 @@ func (service WorkoutImpl) DeleteWorkout(user *User, rowId int) error {
return service.db.DeleteWorkout(user.Id, rowId) return service.db.DeleteWorkout(user.Id, rowId)
} }
func (service WorkoutImpl) GetWorkouts(user *User) ([]*WorkoutDto, error) { func (service WorkoutImpl) GetWorkouts(user *types.User) ([]*WorkoutDto, error) {
if user == nil { if user == nil {
return nil, types.ErrInternal return nil, types.ErrInternal
} }

75
types/auth.go Normal file
View File

@@ -0,0 +1,75 @@
package types
import (
"time"
"github.com/google/uuid"
)
type User struct {
Id uuid.UUID
Email string
EmailVerified bool
EmailVerifiedAt *time.Time
IsAdmin bool
Password []byte
Salt []byte
CreateAt time.Time
}
func NewUser(id uuid.UUID, email string, emailVerified bool, emailVerifiedAt *time.Time, isAdmin bool, password []byte, salt []byte, createAt time.Time) *User {
return &User{
Id: id,
Email: email,
EmailVerified: emailVerified,
EmailVerifiedAt: emailVerifiedAt,
IsAdmin: isAdmin,
Password: password,
Salt: salt,
CreateAt: createAt,
}
}
type Session struct {
Id string
UserId uuid.UUID
CreatedAt time.Time
ExpiresAt time.Time
}
func NewSession(id string, userId uuid.UUID, createdAt time.Time, expiresAt time.Time) *Session {
return &Session{
Id: id,
UserId: userId,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
}
}
type Token struct {
UserId uuid.UUID
SessionId string
Token string
Type TokenType
CreatedAt time.Time
ExpiresAt time.Time
}
type TokenType string
var (
TokenTypeEmailVerify TokenType = "email_verify"
TokenTypePasswordReset TokenType = "password_reset"
TokenTypeCsrf TokenType = "csrf"
)
func NewToken(userId uuid.UUID, sessionId string, token string, tokenType TokenType, createdAt time.Time, expiresAt time.Time) *Token {
return &Token{
UserId: userId,
SessionId: sessionId,
Token: token,
Type: tokenType,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
}
}