diff --git a/app.go b/app.go index 22264eb8e22d3806f1bcee6e6157a5a5dfe61a59..de6b3f56f7307c550fe122b1261d6bc835f7c9ea 100644 --- a/app.go +++ b/app.go @@ -214,17 +214,6 @@ func (a *theApp) acmeMiddleware(handler http.Handler) http.Handler { }) } -// authMiddleware handles authentication requests -func (a *theApp) authMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if a.Auth.TryAuthenticate(w, r, a.domains) { - return - } - - handler.ServeHTTP(w, r) - }) -} - // auxiliaryMiddleware will handle status updates, not-ready requests and other // not static-content responses func (a *theApp) auxiliaryMiddleware(handler http.Handler) http.Handler { @@ -323,7 +312,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { } handler = a.accessControlMiddleware(handler) handler = a.auxiliaryMiddleware(handler) - handler = a.authMiddleware(handler) + handler = a.Auth.Middleware(handler) handler = a.acmeMiddleware(handler) handler, err := logging.AccessLogger(handler, a.LogFormat) if err != nil { @@ -457,7 +446,7 @@ func runApp(config appConfig) { if config.ClientID != "" { a.Auth = auth.New(config.Domain, config.StoreSecret, config.ClientID, config.ClientSecret, - config.RedirectURI, config.GitLabServer) + config.RedirectURI, config.GitLabServer, domains) } a.Handlers = handlers.New(a.Auth, a.Artifact) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index f30c7407000f3d0d6680c8630ceaa948f30af61b..454fbd16a54b42131ee839d9f3a49d48a46c4267 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -50,6 +50,7 @@ var ( // Auth handles authenticating users with GitLab API type Auth struct { pagesDomain string + domains source.Source clientID string clientSecret string redirectURI string @@ -106,8 +107,19 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S return session, nil } -// TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth -func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains source.Source) bool { +// Middleware handles authentication requests +func (a *Auth) Middleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if a.tryAuthenticate(w, r) { + return + } + + handler.ServeHTTP(w, r) + }) +} + +// tryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth +func (a *Auth) tryAuthenticate(w http.ResponseWriter, r *http.Request) bool { if a == nil { return false } @@ -124,7 +136,7 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domains s logRequest(r).Info("Receive OAuth authentication callback") - if a.handleProxyingAuth(session, w, r, domains) { + if a.handleProxyingAuth(session, w, r) { return true } @@ -198,20 +210,20 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res http.Redirect(w, r, redirectURI, 302) } -func (a *Auth) domainAllowed(name string, domains source.Source) bool { +func (a *Auth) domainAllowed(name string) bool { isConfigured := (name == a.pagesDomain) || strings.HasSuffix("."+name, a.pagesDomain) if isConfigured { return true } - domain, err := domains.GetDomain(name) + domain, err := a.domains.GetDomain(name) // domain exists and there is no error return (domain != nil && err == nil) } -func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domains source.Source) bool { +func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request) bool { // If request is for authenticating via custom domain if shouldProxyAuth(r) { domain := r.URL.Query().Get("domain") @@ -230,7 +242,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit host = proxyurl.Host } - if !a.domainAllowed(host, domains) { + if !a.domainAllowed(host) { logRequest(r).WithField("domain", host).Warn("Domain is not configured") httperrors.Serve401(w) return true @@ -609,9 +621,10 @@ func createCookieStore(storeSecret string) sessions.Store { // New when authentication supported this will be used to create authentication handler func New(pagesDomain string, storeSecret string, clientID string, clientSecret string, - redirectURI string, gitLabServer string) *Auth { + redirectURI string, gitLabServer string, domains source.Source) *Auth { return &Auth{ pagesDomain: pagesDomain, + domains: domains, clientID: clientID, clientSecret: clientSecret, redirectURI: redirectURI, diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 4a5d63fa9680d80a7e2cdc5004af20ded50b7943..2cedcef8b318bac9fbf0a900c4c0197fa249bac6 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -22,7 +22,8 @@ func createAuth(t *testing.T) *Auth { "id", "secret", "http://pages.gitlab-example.com/auth", - "http://gitlab-example.com") + "http://gitlab-example.com", + source.NewMockSource()) } func defaultCookieStore() sessions.Store { @@ -57,7 +58,7 @@ func TestTryAuthenticate(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - require.Equal(t, false, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, false, auth.tryAuthenticate(result, r)) } func TestTryAuthenticateWithError(t *testing.T) { @@ -69,7 +70,7 @@ func TestTryAuthenticateWithError(t *testing.T) { reqURL.Scheme = request.SchemeHTTPS r := &http.Request{URL: reqURL} - require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, true, auth.tryAuthenticate(result, r)) require.Equal(t, 401, result.Code) } @@ -87,7 +88,7 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { session.Values["state"] = "state" session.Save(r, result) - require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, true, auth.tryAuthenticate(result, r)) require.Equal(t, 401, result.Code) } @@ -116,7 +117,9 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { "id", "secret", "http://pages.gitlab-example.com/auth", - apiServer.URL) + apiServer.URL, + source.NewMockSource(), + ) r, err := http.NewRequest("GET", "/auth?code=1&state=state", nil) require.NoError(t, err) @@ -132,7 +135,7 @@ func testTryAuthenticateWithCodeAndState(t *testing.T, https bool) { }) result := httptest.NewRecorder() - require.Equal(t, true, auth.TryAuthenticate(result, r, source.NewMockSource())) + require.Equal(t, true, auth.tryAuthenticate(result, r)) require.Equal(t, 302, result.Code) require.Equal(t, "https://pages.gitlab-example.com/project/", result.Header().Get("Location")) require.Equal(t, 600, result.Result().Cookies()[0].MaxAge) @@ -169,7 +172,8 @@ func TestCheckAuthenticationWhenAccess(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - apiServer.URL) + apiServer.URL, + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -207,7 +211,8 @@ func TestCheckAuthenticationWhenNoAccess(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - apiServer.URL) + apiServer.URL, + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -246,7 +251,8 @@ func TestCheckAuthenticationWhenInvalidToken(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - apiServer.URL) + apiServer.URL, + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -283,7 +289,8 @@ func TestCheckAuthenticationWithoutProject(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - apiServer.URL) + apiServer.URL, + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -322,7 +329,8 @@ func TestCheckAuthenticationWithoutProjectWhenInvalidToken(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - apiServer.URL) + apiServer.URL, + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("/auth?code=1&state=state") @@ -350,7 +358,8 @@ func TestGetTokenIfExistsWhenTokenExists(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - "") + "", + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("/") @@ -372,7 +381,8 @@ func TestGetTokenIfExistsWhenTokenDoesNotExist(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - "") + "", + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -393,7 +403,8 @@ func TestCheckResponseForInvalidTokenWhenInvalidToken(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - "") + "", + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("http://pages.gitlab-example.com/test") @@ -413,7 +424,8 @@ func TestCheckResponseForInvalidTokenWhenNotInvalidToken(t *testing.T) { "id", "secret", "http://pages.gitlab-example.com/auth", - "") + "", + source.NewMockSource()) result := httptest.NewRecorder() reqURL, err := url.Parse("/something")