diff --git a/app.go b/app.go
index ac3071a9cddb3cb07676c28f5a20bcde79765304..f0760a50d56a637a60e6df14fbbada40b67b84b6 100644
--- a/app.go
+++ b/app.go
@@ -34,6 +34,7 @@ import (
 	"gitlab.com/gitlab-org/gitlab-pages/internal/netutil"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/rejectmethods"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/request"
+	app_router "gitlab.com/gitlab-org/gitlab-pages/internal/router"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/routing"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/serving/disk/zip"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/source"
@@ -152,36 +153,44 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) {

 	handler = handlers.Ratelimiter(handler, &a.config.RateLimit)

-	// Health Check
-	handler = health.NewMiddleware(handler, a.config.General.StatusPath)
-
-	// Custom response headers
-	handler = customheaders.NewMiddleware(handler, a.CustomHeaders)
-
-	// Correlation ID injection middleware
+	metricsMiddleware := labmetrics.NewHandlerFactory(labmetrics.WithNamespace("gitlab_pages"))
 	var correlationOpts []correlation.InboundHandlerOption
 	if a.config.General.PropagateCorrelationID {
 		correlationOpts = append(correlationOpts, correlation.WithPropagation())
 	}
-	handler = handlePanicMiddleware(handler)

-	// Access logs and metrics
-	handler, err := logging.BasicAccessLogger(handler, a.config.Log.Format)
+	// TODO: ideally, this should not be here. Instead, this check should be done before creating the routes.
+	accessLogger, err := logging.GetAccessLogger(a.config.Log.Format)
 	if err != nil {
 		return nil, err
 	}

-	metricsMiddleware := labmetrics.NewHandlerFactory(labmetrics.WithNamespace("gitlab_pages"))
-	handler = metricsMiddleware(handler)
-
-	handler = correlation.InjectCorrelationID(handler, correlationOpts...)
-	// These middlewares MUST be added in the end.
-	// Being last means they will be evaluated first
-	// preventing any operation on bogus requests.
-	handler = urilimiter.NewMiddleware(handler, a.config.General.MaxURILength)
-	handler = rejectmethods.NewMiddleware(handler)
+	router := app_router.NewRouter(
+		rejectmethods.NewMiddleware,
+		func(next http.Handler) http.Handler {
+			return urilimiter.NewMiddleware(next, a.config.General.MaxURILength)
+		},
+		func(next http.Handler) http.Handler {
+			return correlation.InjectCorrelationID(next, correlationOpts...)
+		},
+		func(next http.Handler) http.Handler {
+			return metricsMiddleware(next)
+		},
+		func(next http.Handler) http.Handler {
+			return logging.BasicAccessLogger(next, accessLogger)
+		},
+		handlePanicMiddleware,
+		func(next http.Handler) http.Handler {
+			return customheaders.NewMiddleware(next, a.CustomHeaders)
+		},
+	)
+
+	if a.config.General.StatusPath != "" {
+		router.Handle(a.config.General.StatusPath, health.Handler())
+	}
+	router.Handle("/", handler)

-	return handler, nil
+	return router, nil
 }

 // nolint: gocyclo // ignore this
diff --git a/internal/healthcheck/handler.go b/internal/healthcheck/handler.go
new file mode 100644
index 0000000000000000000000000000000000000000..e65d5f95de99f0b16fdf61a121a326f7b73fe45a
--- /dev/null
+++ b/internal/healthcheck/handler.go
@@ -0,0 +1,11 @@
+package healthcheck
+
+import "net/http"
+
+// Handler returns an http.Handler that serves the application status check.
+func Handler() http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Cache-Control", "no-store")
+		w.Write([]byte("success\n"))
+	})
+}
diff --git a/internal/healthcheck/handler_test.go b/internal/healthcheck/handler_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..7823c1a4ed74d4a960c1ae74bfc5a3a5380169de
--- /dev/null
+++ b/internal/healthcheck/handler_test.go
@@ -0,0 +1,22 @@
+package healthcheck_test
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"gitlab.com/gitlab-org/gitlab-pages/internal/healthcheck"
+)
+
+func TestHealthCheckHandler(t *testing.T) {
+	u := "https://example.com/-/healthcheck"
+
+	require.HTTPStatusCode(t, healthcheck.Handler().ServeHTTP, http.MethodGet, u, nil, http.StatusOK)
+	require.HTTPBodyContains(t, healthcheck.Handler().ServeHTTP, http.MethodGet, u, nil, "success\n")
+
+	rec := httptest.NewRecorder()
+	healthcheck.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, u, nil))
+	require.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
+}
diff --git a/internal/healthcheck/middleware.go b/internal/healthcheck/middleware.go
deleted file mode 100644
index 2ddd35b70bd70e6f20f53ac2929710cd3eb164ab..0000000000000000000000000000000000000000
--- a/internal/healthcheck/middleware.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package healthcheck
-
-import (
-	"net/http"
-)
-
-// NewMiddleware is serving the application status check
-func NewMiddleware(handler http.Handler, statusPath string) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		if r.URL.Path == statusPath {
-			w.Header().Set("Cache-Control", "no-store")
-			w.Write([]byte("success\n"))
-
-			return
-		}
-
-		handler.ServeHTTP(w, r)
-	})
-}
diff --git a/internal/healthcheck/middleware_test.go b/internal/healthcheck/middleware_test.go
deleted file mode 100644
index 124bb0ff5f0fb0c4cdc00f6db1b0aaae53919ef4..0000000000000000000000000000000000000000
--- a/internal/healthcheck/middleware_test.go
+++ /dev/null
@@ -1,50 +0,0 @@
-package healthcheck_test
-
-import (
-	"io"
-	"net/http"
-	"testing"
-
-	"github.com/stretchr/testify/require"
-
-	"gitlab.com/gitlab-org/gitlab-pages/internal/config"
-	"gitlab.com/gitlab-org/gitlab-pages/internal/healthcheck"
-)
-
-func TestHealthCheckMiddleware(t *testing.T) {
-	tests := map[string]struct {
-		path string
-		body string
-	}{
-		"Not a healthcheck request": {
-			path: "/foo/bar",
-			body: "Hello from inner handler",
-		},
-		"Healthcheck request": {
-			path: "/-/healthcheck",
-			body: "success\n",
-		},
-	}
-
-	cfg := config.Config{
-		General: config.General{
-			StatusPath: "/-/healthcheck",
-		},
-	}
-
-	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(http.StatusOK)
-		io.WriteString(w, "Hello from inner handler")
-	})
-
-	for name, tc := range tests {
-		t.Run(name, func(t *testing.T) {
-			middleware := healthcheck.NewMiddleware(handler, cfg.General.StatusPath)
-
-			u := "https://example.com" + tc.path
-
-			require.HTTPStatusCode(t, middleware.ServeHTTP, http.MethodGet, u, nil, http.StatusOK)
-			require.HTTPBodyContains(t, middleware.ServeHTTP, http.MethodGet, u, nil, tc.body)
-		})
-	}
-}
diff --git a/internal/logging/logging.go b/internal/logging/logging.go
index 6e8533d87a4a77c31d7760158c2f87ad2f69b5d7..c47c8d3a009e0377fe6ba6629749c24b4a79d8a9 100644
--- a/internal/logging/logging.go
+++ b/internal/logging/logging.go
@@ -34,7 +34,7 @@ func ConfigureLogging(format string, verbose bool) error {
-// getAccessLogger will return the default logger, except when
+// GetAccessLogger returns the default logger, except when
 // the log format is text, in which case a combined HTTP access
 // logger will be configured. This behaviour matches Workhorse
-func getAccessLogger(format string) (*logrus.Logger, error) {
+func GetAccessLogger(format string) (*logrus.Logger, error) {
 	if format != "text" && format != "" {
 		return logrus.StandardLogger(), nil
 	}
@@ -52,17 +52,13 @@ func getAccessLogger(format string) (*logrus.Logger, error) {
 }

-// BasicAccessLogger configures the GitLab pages basic HTTP access logger middleware
-func BasicAccessLogger(handler http.Handler, format string) (http.Handler, error) {
-	accessLogger, err := getAccessLogger(format)
-	if err != nil {
-		return nil, err
-	}
-
-	return log.AccessLogger(handler,
+// BasicAccessLogger wraps handler with the GitLab Pages HTTP access logger middleware, logging to accessLogger.
+func BasicAccessLogger(handler http.Handler, accessLogger *logrus.Logger) http.Handler {
+	return log.AccessLogger(
+		handler,
 		log.WithExtraFields(extraFields),
 		log.WithAccessLogger(accessLogger),
 		log.WithXFFAllowed(func(sip string) bool { return false }),
-	), nil
+	)
 }

 func extraFields(r *http.Request) log.Fields {
diff --git a/internal/router/router.go b/internal/router/router.go
new file mode 100644
index 0000000000000000000000000000000000000000..5c353e4b7a42ecac2599e23032f4a580af5bb7f6
--- /dev/null
+++ b/internal/router/router.go
@@ -0,0 +1,62 @@
+// Package router provides an http.ServeMux that wraps every registered
+// handler with a common set of middlewares.
+package router
+
+import "net/http"
+
+// middleware is an alias that keeps the handler-wrapping signature short.
+type middleware = func(http.Handler) http.Handler
+
+// Router is an http.ServeMux whose Handle and HandleFunc methods wrap every
+// registered handler with the default middlewares given to NewRouter.
+//
+// Routing is delegated to http.ServeMux, which cleans request paths and
+// redirects non-canonical ones (e.g. "//a" or "/a/../b") to the clean URL.
+type Router struct {
+	*http.ServeMux
+	defaultMiddlewares []middleware
+}
+
+// NewRouter creates a new Router. The given middlewares are executed in the
+// given order: the first one is the outermost wrapper and sees every request
+// first.
+func NewRouter(middlewares ...middleware) Router {
+	// Keep a private copy so later changes to the caller's slice cannot
+	// affect the router.
+	defaults := make([]middleware, len(middlewares))
+	copy(defaults, middlewares)
+
+	return Router{
+		ServeMux:           http.NewServeMux(),
+		defaultMiddlewares: defaults,
+	}
+}
+
+// Handle registers a new handler for the given pattern. The handler is wrapped
+// by the optional per-route middlewares, which are in turn wrapped by the
+// router's default middlewares; each list is executed in the given order.
+// Like http.ServeMux.Handle, it panics if the pattern is already registered.
+func (r Router) Handle(route string, handler http.Handler, middlewares ...middleware) {
+	// The lists are applied separately rather than joined with append, which
+	// could write into (and race on) the backing array of defaultMiddlewares.
+	handler = wrap(handler, middlewares)
+	handler = wrap(handler, r.defaultMiddlewares)
+
+	r.ServeMux.Handle(route, handler)
+}
+
+// HandleFunc registers fn for the given pattern with the same middlewares as
+// Handle. It shadows http.ServeMux.HandleFunc, which would otherwise register
+// fn without the default middlewares.
+func (r Router) HandleFunc(route string, fn func(http.ResponseWriter, *http.Request), middlewares ...middleware) {
+	r.Handle(route, http.HandlerFunc(fn), middlewares...)
+}
+
+// wrap applies ms to handler so that ms[0] becomes the outermost wrapper.
+func wrap(handler http.Handler, ms []middleware) http.Handler {
+	for i := len(ms) - 1; i >= 0; i-- {
+		handler = ms[i](handler)
+	}
+
+	return handler
+}