diff --git a/app.go b/app.go
index 22264eb8e22d3806f1bcee6e6157a5a5dfe61a59..8a87f10525ea7b1133cc2d644059ae37f1a9c8a4 100644
--- a/app.go
+++ b/app.go
@@ -24,6 +24,7 @@ import (
 	"gitlab.com/gitlab-org/gitlab-pages/internal/logging"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/netutil"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/request"
+	"gitlab.com/gitlab-org/gitlab-pages/internal/singlehost"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/source"
 )
 
@@ -336,6 +337,13 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) {
 
 	handler = a.routingMiddleware(handler)
 
+	if a.appConfig.SingleHost {
+		handler, err = singlehost.NewMiddleware(handler, a.appConfig.Domain)
+		if err != nil {
+			return nil, err
+		}
+	}
+
 	// Health Check
 	handler, err = a.healthCheckMiddleware(handler)
 	if err != nil {
diff --git a/app_config.go b/app_config.go
index 245a9e0d5de37d27001157f1220a759ca4b1d326..3c6a8d31d9cd3b5273d8487eee1c3c8a3f1f9d81 100644
--- a/app_config.go
+++ b/app_config.go
@@ -4,6 +4,7 @@ import "time"
 
 type appConfig struct {
 	Domain                 string
+	SingleHost             bool
 	ArtifactsServer        string
 	ArtifactsServerTimeout int
 	RootCertificate        []byte
diff --git a/internal/singlehost/middleware.go b/internal/singlehost/middleware.go
new file mode 100644
index 0000000000000000000000000000000000000000..820626c82d493d25f7c9ddc3b3b9209901ff2aa8
--- /dev/null
+++ b/internal/singlehost/middleware.go
@@ -0,0 +1,97 @@
+package singlehost
+
+import (
+	"errors"
+	"net"
+	"net/http"
+	"strings"
+
+	log "github.com/sirupsen/logrus"
+)
+
+// middleware rewrites requests addressed to the top-level pages domain so
+// that the first path segment becomes a namespace subdomain, and wraps the
+// response writer so redirects are mapped back to single-host form.
+type middleware struct {
+	next        http.Handler
+	pagesDomain string
+}
+
+// NewMiddleware returns new single host middleware
+// which substitutes first path segment for host, e.g.:
+// pages.example.com/group becomes group.pages.example.com/
+// pages.example.com/group/subgroup/path/index.html becomes group.pages.example.com/subgroup/path/index.html
+//
+// Requests for other hosts and for the bare pages domain root are passed
+// through unchanged. It returns an error if next is nil.
+func NewMiddleware(next http.Handler, pagesDomain string) (http.Handler, error) {
+	if next == nil {
+		return nil, errors.New("Can't build singlehost middleware: next middleware is empty")
+	}
+	return middleware{next: next, pagesDomain: pagesDomain}, nil
+}
+
+// ServeHTTP rewrites r in place (see extractHostFromPath) and calls the next
+// handler with a writer that maps redirect locations back to single-host form.
+func (m middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	m.extractHostFromPath(r)
+
+	ww := newResponseWriter(w, m.pagesDomain)
+
+	m.next.ServeHTTP(ww, r)
+}
+
+// extractHostFromPath moves the first path segment of a request addressed to
+// the top-level pages domain into the host, keeping any port, e.g.
+// pages.example.com:8080/group/a/b -> group.pages.example.com:8080/a/b.
+// Requests for other hosts, or with an empty first segment, are left untouched.
+func (m middleware) extractHostFromPath(r *http.Request) {
+	logger := log.WithFields(log.Fields{
+		"orig_host":    r.Host,
+		"orig_path":    r.URL.Path,
+		"pages_domain": m.pagesDomain,
+	})
+
+	if !m.isTopPagesDomain(r.Host) {
+		logger.Debug("Incoming request does not match pages domain")
+		return
+	}
+
+	path := strings.TrimLeft(r.URL.Path, "/")
+	segments := strings.SplitN(path, "/", 2)
+	// SplitN never returns an empty slice: an empty path yields [""]. Check
+	// the segment itself so the bare domain root is not rewritten to ".<host>".
+	if segments[0] == "" {
+		logger.Debug("can't extract group from path because first segment is empty")
+		return
+	}
+
+	namespace := segments[0]
+	newPath := "/"
+
+	if len(segments) > 1 {
+		newPath += segments[1]
+	}
+
+	// r.Host still includes the original port, if any, so it is preserved.
+	newHost := namespace + "." + r.Host
+
+	logger.WithFields(log.Fields{
+		"old_path": r.URL.Path,
+		"new_path": newPath,
+	}).Debug("Rewrite namespace host")
+
+	r.Host = newHost
+	r.URL.Path = newPath
+}
+
+// isTopPagesDomain reports whether host, with or without a port, is the pages
+// domain itself. Host names are case-insensitive, so the comparison is too.
+func (m middleware) isTopPagesDomain(host string) bool {
+	hostWithoutPort, _, err := net.SplitHostPort(host)
+	if err != nil {
+		// no port present: compare the host as-is
+		hostWithoutPort = host
+	}
+	return strings.EqualFold(hostWithoutPort, m.pagesDomain)
+}
diff --git a/internal/singlehost/middleware_test.go b/internal/singlehost/middleware_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..ee5269ce3183799972c771dae8450e73da687a50
--- /dev/null
+++ b/internal/singlehost/middleware_test.go
@@ -0,0 +1,171 @@
+package singlehost
+
+import (
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"gitlab.com/gitlab-org/gitlab-pages/internal/testhelpers"
+)
+
+// writeURLhandler echoes the (possibly rewritten) host and path. Fprint is
+// used rather than Fprintf so a '%' in the path is not read as a verb.
+var writeURLhandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	fmt.Fprint(w, r.Host+r.URL.Path)
+})
+
+func TestServeHTTP(t *testing.T) {
+	handler, err := NewMiddleware(writeURLhandler, "pages.example.com")
+	require.NoError(t, err)
+
+	tests := []struct {
+		name        string
+		URL         string
+		expectedURL string
+	}{
+		{
+			name:        "custom domain",
+			URL:         "http://mydomain.example.com",
+			expectedURL: "mydomain.example.com",
+		},
+		{
+			name:        "namespace root",
+			URL:         "http://pages.example.com/group",
+			expectedURL: "group.pages.example.com/",
+		},
+		{
+			name:        "namespace root with port",
+			URL:         "http://pages.example.com:8080/group",
+			expectedURL: "group.pages.example.com:8080/",
+		},
+		{
+			name:        "namespace root with trailing slash",
+			URL:         "http://pages.example.com/group/",
+			expectedURL: "group.pages.example.com/",
+		},
+		{
+			name:        "namespace with path",
+			URL:         "http://pages.example.com/group/path/to/file",
+			expectedURL: "group.pages.example.com/path/to/file",
+		},
+		{
+			name:        "namespace with path does not remove trailing slash",
+			URL:         "http://pages.example.com/group/path/to/file/",
+			expectedURL: "group.pages.example.com/path/to/file/",
+		},
+		{
+			name:        "namespace with path and port",
+			URL:         "http://pages.example.com:8080/group/path/to/file",
+			expectedURL: "group.pages.example.com:8080/path/to/file",
+		},
+		{
+			name:        "pages domain root is not rewritten",
+			URL:         "http://pages.example.com/",
+			expectedURL: "pages.example.com/",
+		},
+		{
+			name:        "pages domain is matched case-insensitively",
+			URL:         "http://PAGES.example.com/group",
+			expectedURL: "group.PAGES.example.com/",
+		},
+		{
+			name:        "path containing a percent sign",
+			URL:         "http://pages.example.com/group/100%25",
+			expectedURL: "group.pages.example.com/100%",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			req := httptest.NewRequest("GET", tt.URL, nil)
+			recorder := httptest.NewRecorder()
+
+			handler.ServeHTTP(recorder, req)
+
+			body, err := ioutil.ReadAll(recorder.Body)
+			require.NoError(t, err)
+
+			require.Equal(t, tt.expectedURL, string(body))
+		})
+	}
+}
+
+func TestServeHTTPWithRedirect(t *testing.T) {
+	tests := []struct {
+		name                string
+		redirectURL         string
+		expectedRedirectURL string
+	}{
+		{
+			name:                "redirecting to non-group domain",
+			redirectURL:         "//example.com:8080/test",
+			expectedRedirectURL: "//example.com:8080/test",
+		},
+		{
+			name:                "redirecting to group domain",
+			redirectURL:         "//group.pages.example.com:8080/test",
+			expectedRedirectURL: "//pages.example.com:8080/group/test",
+		},
+		{
+			name:                "redirecting to upper-case group domain",
+			redirectURL:         "//GROUP.pages.example.com/test",
+			expectedRedirectURL: "//pages.example.com/group/test",
+		},
+		{
+			name:                "redirecting keeps escaped path",
+			redirectURL:         "//group.pages.example.com/a%2Fb",
+			expectedRedirectURL: "//pages.example.com/group/a%2Fb",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			redirectHandler := http.RedirectHandler(tt.redirectURL, 302)
+			handler, err := NewMiddleware(redirectHandler, "pages.example.com")
+			require.NoError(t, err)
+
+			testhelpers.AssertRedirectTo(t, handler.ServeHTTP, "GET", "/", nil, tt.expectedRedirectURL)
+		})
+	}
+}
+
+func TestServeHTTPRewritesAllRedirectCodes(t *testing.T) {
+	codes := []int{
+		http.StatusMovedPermanently,
+		http.StatusFound,
+		http.StatusSeeOther,
+		http.StatusTemporaryRedirect,
+		http.StatusPermanentRedirect,
+	}
+
+	for _, code := range codes {
+		t.Run(http.StatusText(code), func(t *testing.T) {
+			redirectHandler := http.RedirectHandler("//group.pages.example.com/test", code)
+			handler, err := NewMiddleware(redirectHandler, "pages.example.com")
+			require.NoError(t, err)
+
+			recorder := httptest.NewRecorder()
+			handler.ServeHTTP(recorder, httptest.NewRequest("GET", "/", nil))
+
+			require.Equal(t, code, recorder.Code)
+			require.Equal(t, "//pages.example.com/group/test", recorder.Result().Header.Get("Location"))
+		})
+	}
+}
+
+func TestServeHTTPRedirectWithoutLocation(t *testing.T) {
+	noLocationHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusFound)
+	})
+	handler, err := NewMiddleware(noLocationHandler, "pages.example.com")
+	require.NoError(t, err)
+
+	recorder := httptest.NewRecorder()
+	handler.ServeHTTP(recorder, httptest.NewRequest("GET", "/", nil))
+
+	_, hasLocation := recorder.Result().Header["Location"]
+	require.False(t, hasLocation, "an empty Location header must not be added")
+}
diff --git a/internal/singlehost/responsewriter.go b/internal/singlehost/responsewriter.go
new file mode 100644
index 0000000000000000000000000000000000000000..a4bb7ea724e9c2dd20a09a00158c65a78c5df55f
--- /dev/null
+++ b/internal/singlehost/responsewriter.go
@@ -0,0 +1,88 @@
+package singlehost
+
+import (
+	"net/http"
+	"net/url"
+	"strings"
+
+	log "github.com/sirupsen/logrus"
+)
+
+// responseWriter wraps an http.ResponseWriter and, for redirect responses,
+// rewrites a Location that points at a namespace subdomain of pagesDomain
+// back to the single-host form, e.g. //group.pages.example.com/x becomes
+// //pages.example.com/group/x.
+type responseWriter struct {
+	http.ResponseWriter
+	pagesDomain string
+}
+
+func newResponseWriter(original http.ResponseWriter, pagesDomain string) http.ResponseWriter {
+	return responseWriter{ResponseWriter: original, pagesDomain: pagesDomain}
+}
+
+// WriteHeader rewrites the Location header of 301, 302, 303, 307 and 308
+// responses before forwarding the status code. No Location header is added
+// when the wrapped handler did not set one.
+func (w responseWriter) WriteHeader(statusCode int) {
+	switch statusCode {
+	case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther,
+		http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
+		header := w.ResponseWriter.Header()
+		if location := header.Get("Location"); location != "" {
+			header.Set("Location", w.transformLocation(location))
+		}
+	}
+	w.ResponseWriter.WriteHeader(statusCode)
+}
+
+// transformLocation maps a location whose host is <namespace>.<pagesDomain>
+// to <pagesDomain>/<namespace>/<path>, keeping scheme, port, query and
+// fragment. Locations that cannot be parsed or that point elsewhere are
+// returned unchanged.
+func (w responseWriter) transformLocation(location string) string {
+	URL, err := url.Parse(location)
+	if err != nil {
+		log.WithField("location", location).WithError(err).Error("Can't parse redirected location")
+
+		return location
+	}
+
+	// Host names are case-insensitive, so match on lower-cased forms.
+	hostname := strings.ToLower(URL.Hostname())
+	suffix := "." + strings.ToLower(w.pagesDomain)
+
+	if !strings.HasSuffix(hostname, suffix) {
+		log.WithFields(log.Fields{
+			"hostname":     URL.Hostname(),
+			"pages_domain": w.pagesDomain,
+		}).Debug("Redirected URL doesn't match pages domain")
+
+		return location
+	}
+
+	namespace := strings.TrimSuffix(hostname, suffix)
+
+	host := w.pagesDomain
+
+	if URL.Port() != "" {
+		host += ":" + URL.Port()
+	}
+
+	URL.Host = host
+	URL.Path = "/" + namespace + URL.Path
+	// Keep an explicitly escaped path (e.g. one containing %2F) in sync with
+	// Path; otherwise URL.String() would discard it and re-escape Path.
+	if URL.RawPath != "" {
+		URL.RawPath = "/" + url.PathEscape(namespace) + URL.RawPath
+	}
+
+	newLocation := URL.String()
+
+	log.WithFields(log.Fields{
+		"orig_location": location,
+		"new_location":  newLocation,
+	}).Debug("Changing redirected location")
+
+	return newLocation
+}
diff --git a/main.go b/main.go
index 9a316c5e512d752315509c4b66533373fd0c6d72..55573bec7d05b835671fa3f68d910af8d3ac2802 100644
--- a/main.go
+++ b/main.go
@@ -73,6 +73,7 @@ var (
 	tlsMaxVersion = flag.String("tls-max-version", "", tlsconfig.FlagUsage("max"))
 
 	disableCrossOriginRequests = flag.Bool("disable-cross-origin-requests", false, "Disable cross-origin requests")
+	singleHost                 = flag.Bool("single-host", false, "EXPERIMENTAL: can be removed without notice")
 
 	// See init()
 	listenHTTP MultiStringFlag
@@ -147,6 +148,7 @@ func configFromFlags() appConfig {
 	var config appConfig
 
 	config.Domain = strings.ToLower(*pagesDomain)
+	config.SingleHost = *singleHost
 	config.RedirectHTTP = *redirectHTTP
 	config.HTTP2 = *useHTTP2
 	config.DisableCrossOriginRequests = *disableCrossOriginRequests