diff --git a/go.mod b/go.mod index 81b7149..f31f233 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/prometheus/client_golang v1.20.5 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.30.0 + golang.org/x/net v0.29.0 ) require ( diff --git a/go.sum b/go.sum index f8209bf..b673f51 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y= github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -21,14 +19,6 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -43,8 +33,6 @@ github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -53,12 +41,12 @@ go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.30.0 h1:RwoQn3GkWiMkzlX562cLB7OxWvjH1L8xutO2WoJcRoY= golang.org/x/crypto v0.30.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handler/middleware/authenticate.go b/handler/middleware/authenticate.go index fa540ef..bb6df5e 100644 --- a/handler/middleware/authenticate.go +++ b/handler/middleware/authenticate.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "me-fit/service" "net/http" @@ -43,5 +44,5 @@ func getSessionID(r *http.Request) string { return "" } - return cookie.Name + return cookie.Value } diff --git a/handler/middleware/cross_site_request_forgery.go b/handler/middleware/cross_site_request_forgery.go index 2edbaa9..eae2bc1 100644 --- a/handler/middleware/cross_site_request_forgery.go +++ b/handler/middleware/cross_site_request_forgery.go @@ -2,9 +2,10 @@ package middleware import ( "fmt" - "me-fit/service" "strings" + "me-fit/service" + "net/http" ) @@ -22,9 +23,6 @@ func newCsrfResponseWriter(w http.ResponseWriter, auth service.Auth, session *se } } - -TODO: Create session for CSRF token - func (rr *csrfResponseWriter) Write(data []byte) (int, error) { dataStr := string(data) if strings.Contains(dataStr, "") { @@ -38,6 +36,10 @@ func (rr *csrfResponseWriter) Write(data []byte) (int, error) { return rr.ResponseWriter.Write([]byte(dataStr)) } +func (rr *csrfResponseWriter) WriteHeader(statusCode int) { + rr.ResponseWriter.WriteHeader(statusCode) +} + func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -56,6 +58,25 @@ func CrossSiteRequestForgery(auth service.Auth) func(http.Handler) http.Handler } } + if session == nil { + var err error + session, err = auth.SignInAnonymous() + if err != nil { + http.Error(w, "", http.StatusInternalServerError) + return + } + } + cookie := http.Cookie{ + Name: "id", + Value: session.Id, + MaxAge: 60 * 60 * 8, // 8 hours + Secure: true, + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Path: "/", + } + http.SetCookie(w, &cookie) + responseWriter := newCsrfResponseWriter(w, auth, session) next.ServeHTTP(responseWriter, r) }) diff --git a/main.go b/main.go index 56e6b79..5fbcec2 100644 --- a/main.go +++ b/main.go @@ -130,6 +130,7 @@ func createHandler(d *sql.DB, serverSettings *types.Settings) http.Handler { middleware.Log, middleware.ContentSecurityPolicy, middleware.Cors(serverSettings), + middleware.Authenticate(authService), middleware.CrossSiteRequestForgery(authService), middleware.Corp, middleware.Coop, diff --git a/main_test.go b/main_test.go index a419bb7..e24737d 100644 --- a/main_test.go +++ b/main_test.go @@ -14,6 +14,8 @@ import ( "time" "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "golang.org/x/net/html" ) func TestHandleSignIn(t *testing.T) { @@ -39,25 +41,35 @@ func TestHandleSignIn(t *testing.T) { t.Fatalf("Error inserting user: %v", err) } - formData := url.Values{ - "email": {"mail@mail.de"}, - "password": {"password"}, - } - - req, err := http.NewRequestWithContext(ctx, "POST", "http://localhost:8080/api/auth/signin", strings.NewReader(formData.Encode())) - if err != nil { - t.Fatalf("Error creating request: %v", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost:8080/auth/signin", nil) + assert.Nil(t, err) resp, err := httpClient.Do(req) - if err != nil { - t.Fatalf("Error making request: %v", err) + assert.Nil(t, err) + + html, err := html.Parse(resp.Body) + assert.Nil(t, err) + + csrfToken := findCsrfToken(html) + assert.NotEqual(t, "", csrfToken) + anonymousSession := findCookie(resp, "id") + assert.NotNil(t, anonymousSession) + + formData := url.Values{ + "email": {"mail@mail.de"}, + "password": {"password"}, + "csrf-token": {csrfToken}, } - if resp.StatusCode != http.StatusSeeOther { - t.Fatalf("Expected status code 303, got %d", resp.StatusCode) - } + req, err = http.NewRequestWithContext(ctx, "POST", "http://localhost:8080/api/auth/signin", strings.NewReader(formData.Encode())) + assert.Nil(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Cookie", anonymousSession.Name+"="+anonymousSession.Value) + + resp, err = httpClient.Do(req) + assert.Nil(t, err) + + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) cookie := findCookie(resp, "id") if cookie == nil { @@ -165,3 +177,44 @@ func waitForReady( } } } + +func findCsrfToken(data *html.Node) string { + attr := getTokenAttribute(data) + if attr != nil { + return attr.Val + } + + if data.FirstChild != nil { + if token := findCsrfToken(data.FirstChild); token != "" { + return token + } + } + if data.NextSibling != nil { + if token := findCsrfToken(data.NextSibling); token != "" { + return token + } + } + + return "" +} + +func getTokenAttribute(data *html.Node) *html.Attribute { + returnValue := false + for _, attr := range data.Attr { + if attr.Key == "name" && attr.Val == "csrf-token" { + returnValue = true + } + } + + if !returnValue { + return nil + } + + for _, attr := range data.Attr { + if attr.Key == "value" { + return &attr + } + } + + return nil +} diff --git a/service/auth.go b/service/auth.go index 8c18a08..5c81af0 100644 --- a/service/auth.go +++ b/service/auth.go @@ -62,6 +62,7 @@ type Auth interface { SignIn(email string, password string) (*Session, error) SignInSession(sessionId string) (*Session, error) + SignInAnonymous() (*Session, error) SignOut(sessionId string) error DeleteAccount(user *User) error @@ -127,10 +128,14 @@ func (service AuthImpl) SignInSession(sessionId string) (*Session, error) { return nil, types.ErrInternal } - if sessionDb.ExpiresAt.After(service.clock.Now()) { + if sessionDb.ExpiresAt.Before(service.clock.Now()) { return nil, nil } + if sessionDb.UserId == uuid.Nil { + return NewSession(sessionDb, nil), nil + } + userDb, err := service.db.GetUser(sessionDb.UserId) if err != nil { return nil, types.ErrInternal @@ -142,6 +147,15 @@ func (service AuthImpl) SignInSession(sessionId string) (*Session, error) { return session, nil } +func (service AuthImpl) SignInAnonymous() (*Session, error) { + sessionDb, err := service.createSession(uuid.Nil) + if err != nil { + return nil, types.ErrInternal + } + + return NewSession(sessionDb, nil), nil +} + func (service AuthImpl) createSession(userId uuid.UUID) (*db.Session, error) { sessionId, err := service.random.String(32) if err != nil { @@ -411,6 +425,10 @@ func (service AuthImpl) IsCsrfTokenValid(tokenStr string, sessionId string) bool } func (service AuthImpl) GetCsrfToken(session *Session) (string, error) { + if session == nil { + return "", types.ErrInternal + } + tokens, _ := service.db.GetTokensBySessionIdAndType(session.Id, db.TokenTypeCsrf) if len(tokens) > 0 {