diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index dcc81eee2aa97aa18f39d2dc857c441561fb3bcb..bbb8daa61cc4540813365f03fae7e796e1ea4093 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -140,7 +140,7 @@ func (a *Auth) checkAuthenticationResponse(session *hostSession, w http.Response
 		return
 	}
 
-	decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r))
+	decryptedCode, err := a.DecryptCode(r.URL.Query().Get("code"), getRequestDomain(r, getNamespaceInPathFromSession(session)))
 	if err != nil {
 		logRequest(r).WithError(err).Error("failed to decrypt secure code")
 		errortracking.CaptureErrWithReqAndStackTrace(err, r)
@@ -317,11 +317,29 @@ func getRequestAddress(r *http.Request) string {
 	return "http://" + r.Host + r.RequestURI
 }
 
-func getRequestDomain(r *http.Request) string {
+// getRequestDomain returns the scheme-qualified domain of r. When namespace
+// is non-empty and r.Host starts with "<namespace>.", the namespace label is
+// moved from the host into the path: host "ns.example.com" with namespace
+// "ns" yields "example.com/ns".
+func getRequestDomain(r *http.Request, namespace string) string {
+	requestDomain := r.Host
+	if prefix := namespace + "."; namespace != "" && strings.HasPrefix(r.Host, prefix) {
+		requestDomain = strings.TrimPrefix(r.Host, prefix) + "/" + namespace
+	}
+
 	if request.IsHTTPS(r) {
-		return "https://" + r.Host
+		return "https://" + requestDomain
+	}
+	return "http://" + requestDomain
+}
+
+// getNamespaceInPathFromSession returns the namespace encoded in the session
+// cookie path ("/<namespace>"), or "" when the path carries no namespace.
+func getNamespaceInPathFromSession(session *hostSession) string {
+	if len(session.Options.Path) > 1 && session.Options.Path[0] == '/' {
+		return session.Options.Path[1:]
 	}
-	return "http://" + r.Host
+	return ""
 }
 
 func shouldProxyAuthToGitlab(r *http.Request) bool {
@@ -440,15 +458,15 @@ func (a *Auth) checkTokenExists(session *hostSession, w http.ResponseWriter, r *
 
 		// Because the pages domain might be in public suffix list, we have to
 		// redirect to pages domain to trigger authorization flow
-		http.Redirect(w, r, a.getProxyAddress(r, session.Values["state"].(string)), http.StatusFound)
+		http.Redirect(w, r, a.getProxyAddress(r, session.Values["state"].(string), getNamespaceInPathFromSession(session)), http.StatusFound)
 
 		return true
 	}
 	return false
 }
 
-func (a *Auth) getProxyAddress(r *http.Request, state string) string {
-	return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, getRequestDomain(r), state)
+func (a *Auth) getProxyAddress(r *http.Request, state string, namespace string) string {
+	return fmt.Sprintf(authorizeProxyTemplate, a.redirectURI, getRequestDomain(r, namespace), state)
 }
 
 func destroySession(session *hostSession, w http.ResponseWriter, r *http.Request) {
diff --git a/internal/auth/session.go b/internal/auth/session.go
index 5bfc8e0349bb6d95f87445f4208cc8daec634d04..e2a7c3c1ccd3a126350c35925c9322a57cea635d 100644
--- a/internal/auth/session.go
+++ b/internal/auth/session.go
@@ -26,12 +26,15 @@ func (a *Auth) getSessionFromStore(r *http.Request) (*hostSession, error) {
 	session, err := a.store.Get(r, "gitlab-pages")
 
 	if session != nil {
+		// Empty unless X-Gitlab-Namespace-In-Path was validated against r.Host.
+		namespaceInPath := request.GetNamespaceInPathFromRequest(r, a.pagesDomain)
+
 		// Cookie just for this domain
-		session.Options.Path = "/"
+		session.Options.Path = "/" + namespaceInPath
 		session.Options.HttpOnly = true
 		session.Options.Secure = request.IsHTTPS(r)
 		session.Options.MaxAge = int(a.cookieSessionTimeout.Seconds())
 
 		if session.Values[sessionHostKey] == nil || session.Values[sessionHostKey] != r.Host {
 			session.Values = make(map[interface{}]interface{})
 		}
diff --git a/internal/request/request.go b/internal/request/request.go
index f98b081955f575369f9faf1bf36ae74fbe0d0a33..4a7e1e5202623b01524b63cfe636f4ddebf3be28 100644
--- a/internal/request/request.go
+++ b/internal/request/request.go
@@ -38,3 +38,31 @@ func GetRemoteAddrWithoutPort(r *http.Request) string {
 
 	return remoteAddr
 }
+
+// GetNamespaceInPathFromRequest returns the value of the
+// X-Gitlab-Namespace-In-Path header of r, but only if pagesDomain is set and
+// r.Host (ignoring any port) is exactly "<namespace>.<pagesDomain>".
+// Otherwise, including when the header is missing, it returns "".
+func GetNamespaceInPathFromRequest(r *http.Request, pagesDomain string) string {
+	if pagesDomain == "" {
+		return ""
+	}
+
+	namespaceInPath := r.Header.Get("X-Gitlab-Namespace-In-Path")
+	if namespaceInPath == "" {
+		return ""
+	}
+
+	host := r.Host
+	if h, _, err := net.SplitHostPort(r.Host); err == nil {
+		host = h
+	}
+
+	// Require an exact match: a prefix check would also accept hosts such as
+	// "<namespace>.<pagesDomain>.attacker.example".
+	if host != namespaceInPath+"."+pagesDomain {
+		return ""
+	}
+
+	return namespaceInPath
+}
diff --git a/internal/request/request_test.go b/internal/request/request_test.go
index 9e71db37e66e9b2be51dccf1b914abd14de7bc95..8ee16f8c4a0d31c8640dafa25079f8f604b3ec94 100644
--- a/internal/request/request_test.go
+++ b/internal/request/request_test.go
@@ -87,3 +87,64 @@ func TestGetRemoteAddrWithoutPort(t *testing.T) {
 		})
 	}
 }
+
+func TestGetNamespaceInPathFromRequest(t *testing.T) {
+	tests := map[string]struct {
+		pagesDomain string
+		u           string
+		namespace   string
+		expected    string
+	}{
+		"when valid X-Gitlab-Namespace-In-Path is provided in request header": {
+			pagesDomain: "example.com",
+			u:           "https://namespace.example.com/myProject",
+			namespace:   "namespace",
+			expected:    "namespace",
+		},
+		"when valid X-Gitlab-Namespace-In-Path with '.' in between in request header": {
+			pagesDomain: "example.com",
+			u:           "https://namespace.test.example.com/myProject",
+			namespace:   "namespace.test",
+			expected:    "namespace.test",
+		},
+		"when request host includes a port": {
+			pagesDomain: "example.com",
+			u:           "https://namespace.example.com:8443/myProject",
+			namespace:   "namespace",
+			expected:    "namespace",
+		},
+		"when forged X-Gitlab-Namespace-In-Path is provided in request header": {
+			pagesDomain: "example.com",
+			u:           "https://namespace.example.com/myProject",
+			namespace:   "namespace-forged",
+			expected:    "",
+		},
+		"when request host only starts with the namespaced pages domain": {
+			pagesDomain: "example.com",
+			u:           "https://namespace.example.com.attacker.io/myProject",
+			namespace:   "namespace",
+			expected:    "",
+		},
+		"when X-Gitlab-Namespace-In-Path is missing": {
+			pagesDomain: "example.com",
+			u:           "https://namespace.example.com/myProject",
+			namespace:   "",
+			expected:    "",
+		},
+		"when pages domain is not configured": {
+			pagesDomain: "",
+			u:           "https://namespace.example.com/myProject",
+			namespace:   "namespace",
+			expected:    "",
+		},
+	}
+	for name, test := range tests {
+		t.Run(name, func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodGet, test.u, nil)
+			req.Header.Set("X-Gitlab-Namespace-In-Path", test.namespace)
+
+			namespace := GetNamespaceInPathFromRequest(req, test.pagesDomain)
+			require.Equal(t, test.expected, namespace)
+		})
+	}
+}