From 68e927c3347e90412ca40411d9a77e4d1bf0a09e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Trzci=C5=84ski?= Date: Wed, 27 Feb 2019 15:23:05 +0100 Subject: [PATCH 1/3] Refactor GitLab Pages to use API with additional Storage layer --- acceptance_test.go | 46 +-- app.go | 47 +-- app_config.go | 4 + helpers_test.go | 10 + internal/auth/auth.go | 33 +- internal/auth/auth_test.go | 13 +- internal/client/api.go | 8 + internal/client/cached_api.go | 118 +++++++ internal/client/domain_response.go | 26 ++ internal/client/gitlab.go | 63 ++++ internal/client/lookup_path.go | 26 ++ internal/domain/domain.go | 248 ++++----------- internal/domain/domain_test.go | 205 ++++++------ internal/domain/group.go | 38 --- internal/domain/group_test.go | 97 ------ internal/domain/map.go | 299 ------------------ internal/domain/map_test.go | 240 -------------- internal/fixture/mock_api.go | 24 ++ internal/fixture/mock_server.go | 23 ++ internal/fixture/shared_pages_config.go | 262 +++++++++++++++ internal/storage/file_system.go | 78 +++++ internal/storage/storage.go | 29 ++ main.go | 76 +++-- metrics/metrics.go | 27 -- shared/pages/nested/project/public/index.html | 1 + .../nested/sub1/project/public/index.html | 1 + .../sub1/sub2/project/public/index.html | 1 + .../sub1/sub2/sub3/project/public/index.html | 1 + .../sub2/sub3/sub4/project/public/index.html | 1 + .../sub3/sub4/sub5/project/public/index.html | 1 + 30 files changed, 946 insertions(+), 1100 deletions(-) create mode 100644 internal/client/api.go create mode 100644 internal/client/cached_api.go create mode 100644 internal/client/domain_response.go create mode 100644 internal/client/gitlab.go create mode 100644 internal/client/lookup_path.go delete mode 100644 internal/domain/group.go delete mode 100644 internal/domain/group_test.go delete mode 100644 internal/domain/map.go delete mode 100644 internal/domain/map_test.go create mode 100644 internal/fixture/mock_api.go create mode 100644 internal/fixture/mock_server.go create mode 100644 internal/fixture/shared_pages_config.go create mode 100644 internal/storage/file_system.go create mode 100644 internal/storage/storage.go create mode 100644 shared/pages/nested/project/public/index.html create mode 100644 shared/pages/nested/sub1/project/public/index.html create mode 100644 shared/pages/nested/sub1/sub2/project/public/index.html create mode 100644 shared/pages/nested/sub1/sub2/sub3/project/public/index.html create mode 100644 shared/pages/nested/sub1/sub2/sub3/sub4/project/public/index.html create mode 100644 shared/pages/nested/sub1/sub2/sub3/sub4/sub5/project/public/index.html diff --git a/acceptance_test.go b/acceptance_test.go index 55a838812..3173b9bc0 100644 --- a/acceptance_test.go +++ b/acceptance_test.go @@ -10,7 +10,6 @@ import ( "net/http/httptest" "net/url" "os" - "path" "regexp" "testing" "time" @@ -153,30 +152,15 @@ func TestKnownHostReturns200(t *testing.T) { func TestNestedSubgroups(t *testing.T) { skipUnlessEnabled(t) - maxNestedSubgroup := 21 - - pagesRoot, err := ioutil.TempDir("", "pages-root") - require.NoError(t, err) - defer os.RemoveAll(pagesRoot) - - makeProjectIndex := func(subGroupPath string) { - projectPath := path.Join(pagesRoot, "nested", subGroupPath, "project", "public") - require.NoError(t, os.MkdirAll(projectPath, 0755)) - - projectIndex := path.Join(projectPath, "index.html") - require.NoError(t, ioutil.WriteFile(projectIndex, []byte("index"), 0644)) - } - makeProjectIndex("") + maxNestedSubgroup := 5 paths := []string{""} for i := 1; i < maxNestedSubgroup*2; i++ { 
subGroupPath := fmt.Sprintf("%ssub%d/", paths[i-1], i) paths = append(paths, subGroupPath) - - makeProjectIndex(subGroupPath) } - teardown := RunPagesProcess(t, *pagesBinary, listeners, "", "-pages-root", pagesRoot) + teardown := RunPagesProcess(t, *pagesBinary, listeners, "") defer teardown() for nestingLevel, path := range paths { @@ -381,7 +365,6 @@ func TestPrometheusMetricsCanBeScraped(t *testing.T) { body, _ := ioutil.ReadAll(resp.Body) assert.Contains(t, string(body), "gitlab_pages_http_sessions_active 0") - assert.Contains(t, string(body), "gitlab_pages_domains_served_total 14") } } @@ -396,30 +379,6 @@ func TestStatusPage(t *testing.T) { assert.Equal(t, http.StatusOK, rsp.StatusCode) } -func TestStatusNotYetReady(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "", "-pages-status=/@statuscheck", "-pages-root=shared/invalid-pages") - defer teardown() - - waitForRoundtrips(t, listeners, 5*time.Second) - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "@statuscheck") - require.NoError(t, err) - defer rsp.Body.Close() - assert.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) -} - -func TestPageNotAvailableIfNotLoaded(t *testing.T) { - skipUnlessEnabled(t) - teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "", "-pages-root=shared/invalid-pages") - defer teardown() - waitForRoundtrips(t, listeners, 5*time.Second) - - rsp, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "index.html") - require.NoError(t, err) - defer rsp.Body.Close() - assert.Equal(t, http.StatusServiceUnavailable, rsp.StatusCode) -} - func TestObscureMIMEType(t *testing.T) { skipUnlessEnabled(t) teardown := RunPagesProcessWithoutWait(t, *pagesBinary, listeners, "") @@ -880,6 +839,7 @@ func TestAccessControlGroupDomain404RedirectsAuth(t *testing.T) { assert.Equal(t, "projects.gitlab-example.com", url.Host) assert.Equal(t, "/auth", url.Path) } + func TestAccessControlProject404DoesNotRedirect(t *testing.T) { skipUnlessEnabled(t) teardown := RunPagesProcessWithAuth(t, *pagesBinary, listeners, "") diff --git a/app.go b/app.go index c25bcf370..e3736ce0c 100644 --- a/app.go +++ b/app.go @@ -20,6 +20,7 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/admin" "gitlab.com/gitlab-org/gitlab-pages/internal/artifact" "gitlab.com/gitlab-org/gitlab-pages/internal/auth" + "gitlab.com/gitlab-org/gitlab-pages/internal/client" "gitlab.com/gitlab-org/gitlab-pages/internal/domain" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/netutil" @@ -38,22 +39,27 @@ var ( type theApp struct { appConfig - dm domain.Map - lock sync.RWMutex Artifact *artifact.Artifact Auth *auth.Auth -} - -func (a *theApp) isReady() bool { - return a.dm != nil + Client client.API } func (a *theApp) domain(host string) *domain.D { host = strings.ToLower(host) - a.lock.RLock() - defer a.lock.RUnlock() - domain, _ := a.dm[host] - return domain + + response, err := a.Client.RequestDomain(host) + + log.WithFields(log.Fields{ + "host": host, + }).WithError(err).Debug("RequestDomain") + + if err != nil { + return nil + } + + var domain domain.D + domain.DomainResponse = response + return &domain } func (a *theApp) ServeTLS(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -62,15 +68,15 @@ func (a *theApp) ServeTLS(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { } if domain := a.domain(ch.ServerName); domain != nil { - tls, _ := domain.EnsureCertificate() - return tls, 
nil + tls, _ := domain.Certificate() + return &tls, nil } return nil, nil } func (a *theApp) healthCheck(w http.ResponseWriter, r *http.Request, https bool) { - if a.isReady() { + if a.Client.IsReady() { w.Write([]byte("success")) } else { http.Error(w, "not yet ready", http.StatusServiceUnavailable) @@ -140,7 +146,7 @@ func (a *theApp) tryAuxiliaryHandlers(w http.ResponseWriter, r *http.Request, ht return true } - if !a.isReady() { + if !a.Client.IsReady() { httperrors.Serve503(w) return true } @@ -166,7 +172,7 @@ func (a *theApp) serveContent(ww http.ResponseWriter, r *http.Request, https boo host, domain := a.getHostAndDomain(r) - if a.Auth.TryAuthenticate(&w, r, a.dm, &a.lock) { + if a.Auth.TryAuthenticate(&w, r, a.domain) { return } @@ -230,12 +236,6 @@ func (a *theApp) ServeProxy(ww http.ResponseWriter, r *http.Request) { a.serveContent(ww, r, https) } -func (a *theApp) UpdateDomains(dm domain.Map) { - a.lock.Lock() - defer a.lock.Unlock() - a.dm = dm -} - func (a *theApp) Run() { var wg sync.WaitGroup @@ -294,8 +294,6 @@ func (a *theApp) Run() { a.listenAdminUnix(&wg) a.listenAdminHTTPS(&wg) - go domain.Watch(a.Domain, a.UpdateDomains, time.Second) - wg.Wait() } @@ -351,6 +349,9 @@ func (a *theApp) listenAdminHTTPS(wg *sync.WaitGroup) { func runApp(config appConfig) { a := theApp{appConfig: config} + a.Client = client.NewGitLabClient(config.APIServer, config.APIServerKey, config.APIServerTimeout) + a.Client = client.NewCachedClient(a.Client, 10*time.Second, 3*time.Second) + if config.ArtifactsServer != "" { a.Artifact = artifact.New(config.ArtifactsServer, config.ArtifactsServerTimeout, config.Domain) } diff --git a/app_config.go b/app_config.go index 9ff26b6b9..c6acef82b 100644 --- a/app_config.go +++ b/app_config.go @@ -11,6 +11,10 @@ type appConfig struct { AdminToken []byte MaxConns int + APIServer string + APIServerKey []byte + APIServerTimeout int + ListenHTTP []uintptr ListenHTTPS []uintptr ListenProxy []uintptr diff --git a/helpers_test.go b/helpers_test.go index bf61b7a4d..ec27b9967 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -19,6 +19,8 @@ import ( "gitlab.com/gitlab-org/gitlab-pages/internal/fixture" ) +const apiServerListenAddress = "127.0.0.1:7800" + type tWriter struct { t *testing.T } @@ -173,7 +175,13 @@ func runPagesProcess(t *testing.T, wait bool, pagesPath string, listeners []List _, err := os.Stat(pagesPath) require.NoError(t, err) + apiServer := &http.Server{ + Addr: apiServerListenAddress, + Handler: http.HandlerFunc(fixture.MockHTTPHandler), + } + args, tempfiles := getPagesArgs(t, listeners, promPort, extraArgs) + args = append(args, "-api-server", "http://"+apiServerListenAddress+"/api/v4") cmd := exec.Command(pagesPath, args...) cmd.Env = append(os.Environ(), extraEnv...) 
cmd.Stdout = &tWriter{t} @@ -183,6 +191,7 @@ func runPagesProcess(t *testing.T, wait bool, pagesPath string, listeners []List waitCh := make(chan struct{}) go func() { + apiServer.ListenAndServe() cmd.Wait() for _, tempfile := range tempfiles { os.Remove(tempfile) @@ -191,6 +200,7 @@ func runPagesProcess(t *testing.T, wait bool, pagesPath string, listeners []List }() cleanup := func() { + apiServer.Close() cmd.Process.Signal(os.Interrupt) <-waitCh } diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 02879568c..561e74d86 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" "github.com/gorilla/securecookie" @@ -90,7 +89,7 @@ func (a *Auth) checkSession(w http.ResponseWriter, r *http.Request) (*sessions.S } // TryAuthenticate tries to authenticate user and fetch access token if request is a callback to auth -func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, dm domain.Map, lock *sync.RWMutex) bool { +func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, domainFinder domain.Finder) bool { if a == nil { return false @@ -108,7 +107,7 @@ func (a *Auth) TryAuthenticate(w http.ResponseWriter, r *http.Request, dm domain logRequest(r).Info("Receive OAuth authentication callback") - if a.handleProxyingAuth(session, w, r, dm, lock) { + if a.handleProxyingAuth(session, w, r, domainFinder) { return true } @@ -176,16 +175,28 @@ func (a *Auth) checkAuthenticationResponse(session *sessions.Session, w http.Res http.Redirect(w, r, redirectURI, 302) } -func (a *Auth) domainAllowed(domain string, dm domain.Map, lock *sync.RWMutex) bool { - lock.RLock() - defer lock.RUnlock() +func (a *Auth) domainAllowed(domain string, domainFinder domain.Finder) bool { + // if our domain is pages-domain we always force auth + if domain == a.pagesDomain { + return true + } + + // if our domain is subdomain of pages-domain we force auth + // TODO: This condition is taken from original code, but it is clearly broken, + // as it should be `strings.HasSuffix("."+domain, a.pagesDomain)` + if strings.HasSuffix("."+domain, a.pagesDomain) { + return true + } - domain = strings.ToLower(domain) - _, present := dm[domain] - return domain == a.pagesDomain || strings.HasSuffix("."+domain, a.pagesDomain) || present + // if our domain is custom domain, we force auth + if domainFinder != nil && domainFinder(domain) != nil { + return true + } + + return false } -func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, dm domain.Map, lock *sync.RWMutex) bool { +func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWriter, r *http.Request, domainFinder domain.Finder) bool { // If request is for authenticating via custom domain if shouldProxyAuth(r) { domain := r.URL.Query().Get("domain") @@ -202,7 +213,7 @@ func (a *Auth) handleProxyingAuth(session *sessions.Session, w http.ResponseWrit host = proxyurl.Host } - if !a.domainAllowed(host, dm, lock) { + if !a.domainAllowed(host, domainFinder) { logRequest(r).WithField("domain", host).Warn("Domain is not configured") httperrors.Serve401(w) return true diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index ed130cafd..012089975 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -5,7 +5,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "sync" "testing" "github.com/gorilla/sessions" @@ -16,6 +15,10 @@ import ( 
"gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) +func findDomain(host string) *domain.D { + return nil +} + func createAuth(t *testing.T) *auth.Auth { return auth.New("pages.gitlab-example.com", "something-very-secret", @@ -33,7 +36,7 @@ func TestTryAuthenticate(t *testing.T) { require.NoError(t, err) r := &http.Request{URL: reqURL} - assert.Equal(t, false, auth.TryAuthenticate(result, r, make(domain.Map), &sync.RWMutex{})) + assert.Equal(t, false, auth.TryAuthenticate(result, r, findDomain)) } func TestTryAuthenticateWithError(t *testing.T) { @@ -44,7 +47,7 @@ func TestTryAuthenticateWithError(t *testing.T) { require.NoError(t, err) r := &http.Request{URL: reqURL} - assert.Equal(t, true, auth.TryAuthenticate(result, r, make(domain.Map), &sync.RWMutex{})) + assert.Equal(t, true, auth.TryAuthenticate(result, r, findDomain)) assert.Equal(t, 401, result.Code) } @@ -61,7 +64,7 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { session.Values["state"] = "state" session.Save(r, result) - assert.Equal(t, true, auth.TryAuthenticate(result, r, make(domain.Map), &sync.RWMutex{})) + assert.Equal(t, true, auth.TryAuthenticate(result, r, findDomain)) assert.Equal(t, 401, result.Code) } @@ -103,7 +106,7 @@ func TestTryAuthenticateWithCodeAndState(t *testing.T) { session.Values["state"] = "state" session.Save(r, result) - assert.Equal(t, true, auth.TryAuthenticate(result, r, make(domain.Map), &sync.RWMutex{})) + assert.Equal(t, true, auth.TryAuthenticate(result, r, findDomain)) assert.Equal(t, 302, result.Code) assert.Equal(t, "http://pages.gitlab-example.com/project/", result.Header().Get("Location")) } diff --git a/internal/client/api.go b/internal/client/api.go new file mode 100644 index 000000000..af1a376dc --- /dev/null +++ b/internal/client/api.go @@ -0,0 +1,8 @@ +package client + +// API implements simple interface that allows +// Pages to talk and request data from GitLab +type API interface { + RequestDomain(host string) (*DomainResponse, error) + IsReady() bool +} diff --git a/internal/client/cached_api.go b/internal/client/cached_api.go new file mode 100644 index 000000000..edf180fea --- /dev/null +++ b/internal/client/cached_api.go @@ -0,0 +1,118 @@ +package client + +import ( + "sync" + "time" + + cache "github.com/patrickmn/go-cache" + "github.com/sirupsen/logrus" +) + +const refreshCacheInterval = 3 * time.Second +const defaultCacheTimeout = 3 * time.Second + +type cachedDomainResponse struct { + host string + response *DomainResponse + err error + + once sync.Once +} + +func (c *cachedDomainResponse) log() *logrus.Entry { + return logrus.WithFields(logrus.Fields{ + "host": c.host, + }) +} + +// cachedAPI implements a cache layer for all requests +// the request is executed exactly once for all clients +// we store positive results in `cache` for cacheTimeout interval +// we also store negative results in `cache` for time defined by defaultCacheTimeout +// to solve temporary API failures we retain last successful result +// for time specified in `longCacheTimeout` and use it as last resort +// this makes us to request domain config every `cacheTimeout` in case of found domains +// and request every `defaultCacheTimeout` if there's API failure +type cachedAPI struct { + upstream API + cacheTimeout time.Duration + longCacheTimeout time.Duration + + cache *cache.Cache + longCache *cache.Cache +} + +func (a *cachedAPI) ensureRequestDomain(c *cachedDomainResponse) { + c.once.Do(func() { + c.response, c.err = a.upstream.RequestDomain(c.host) + 
c.log().WithError(c.err).Debugln("CachedRequestDomain") + + // add positive result to cache and in long cache for longer period + if c.err == nil { + a.cache.Set(c.host, c, a.cacheTimeout) + a.longCache.Set(c.host, c, a.longCacheTimeout) + } else { + a.cache.Set(c.host, c, cache.DefaultExpiration) + } + }) +} + +func (a *cachedAPI) findCacheEntry(host string) *cachedDomainResponse { + // try to get object from cache + if cached, found := a.cache.Get(host); found { + return cached.(*cachedDomainResponse) + } + + return nil +} + +func (a *cachedAPI) findLongCacheEntry(host string) *cachedDomainResponse { + // try to get object from cache + if cached, found := a.longCache.Get(host); found { + return cached.(*cachedDomainResponse) + } + + return nil +} + +func (a *cachedAPI) newCacheEntry(host string) *cachedDomainResponse { + cachedObject := &cachedDomainResponse{host: host} + + // cache object for short period + a.cache.Set(cachedObject.host, cachedObject, cache.DefaultExpiration) + return cachedObject +} + +// RequestDomain request a host from preconfigured list of domains +func (a *cachedAPI) RequestDomain(host string) (*DomainResponse, error) { + cachedObject := a.findCacheEntry(host) + + // create a new cache entry + if cachedObject == nil { + cachedObject = a.newCacheEntry(host) + } + + // request or wait for API response + a.ensureRequestDomain(cachedObject) + + // try to take from long cache to ignore short failures + if cachedObject == nil { + cachedObject = a.findLongCacheEntry(host) + } + + return cachedObject.response, cachedObject.err +} + +func (a *cachedAPI) IsReady() bool { + return a.upstream.IsReady() +} + +func NewCachedClient(upstream API, cacheTimeout, longCacheTimeout time.Duration) API { + return &cachedAPI{ + upstream: upstream, + cacheTimeout: cacheTimeout, + longCacheTimeout: longCacheTimeout, + cache: cache.New(defaultCacheTimeout, refreshCacheInterval), + longCache: cache.New(defaultCacheTimeout, refreshCacheInterval), + } +} diff --git a/internal/client/domain_response.go b/internal/client/domain_response.go new file mode 100644 index 000000000..32a2ce47f --- /dev/null +++ b/internal/client/domain_response.go @@ -0,0 +1,26 @@ +package client + +import ( + "errors" + "strings" +) + +// DomainResponse describes a configuration for domain, +// like certificate, but also lookup paths to serve the content +type DomainResponse struct { + Certificate string `json:"certificate"` + Key string `json:"certificate_key"` + + LookupPath []LookupPath `json:"lookup_paths"` +} + +// GetPath finds a first matching lookup path that should serve the content +func (d *DomainResponse) GetPath(path string) (*LookupPath, error) { + for _, lp := range d.LookupPath { + if strings.HasPrefix(path, lp.Prefix) || path+"/" == lp.Prefix { + return &lp, nil + } + } + + return nil, errors.New("lookup path not found") +} diff --git a/internal/client/gitlab.go b/internal/client/gitlab.go new file mode 100644 index 000000000..a5fe93938 --- /dev/null +++ b/internal/client/gitlab.go @@ -0,0 +1,63 @@ +package client + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "gitlab.com/gitlab-org/gitlab-pages/internal/httptransport" +) + +type gitlabAPI struct { + server string + key []byte + client *http.Client +} + +func (a *gitlabAPI) IsReady() bool { + return true +} + +// RequestDomain requests the configuration of domain from GitLab +// this provides information where to fetch data from in order to serve +// the domain content +func (a *gitlabAPI) RequestDomain(host 
string) (*DomainResponse, error) { + values := url.Values{ + "host": []string{host}, + } + + resp, err := http.PostForm(a.server+"/pages/domain", values) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + resp.Header.Set("Authorization", "token "+string(a.key)) + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("response code: %q", resp.StatusCode) + } + + var domainResponse DomainResponse + err = json.NewDecoder(resp.Body).Decode(&domainResponse) + if err != nil { + // Ignore here + return nil, err + } + + return &domainResponse, nil +} + +func NewGitLabClient(server string, key []byte, timeoutSeconds int) API { + return &gitlabAPI{ + server: strings.TrimRight(server, "/"), + key: key, + client: &http.Client{ + Timeout: time.Second * time.Duration(timeoutSeconds), + Transport: httptransport.Transport, + }, + } +} diff --git a/internal/client/lookup_path.go b/internal/client/lookup_path.go new file mode 100644 index 000000000..0ba3fc2f0 --- /dev/null +++ b/internal/client/lookup_path.go @@ -0,0 +1,26 @@ +package client + +import ( + "strings" +) + +// LookupPath describes a single mapping between HTTP Prefix +// and actual data on disk +type LookupPath struct { + Prefix string `json:"prefix"` + Path string `json:"path"` + + NamespaceProject bool `json:"namespace_project"` + HTTPSOnly bool `json:"https_only"` + AccessControl bool `json:"access_control"` + ProjectID uint64 `json:"id"` +} + +// Tail returns a relative path to full path to serve the content +func (lp *LookupPath) Tail(path string) string { + if strings.HasPrefix(path, lp.Prefix) { + return path[len(lp.Prefix):] + } + + return "" +} diff --git a/internal/domain/domain.go b/internal/domain/domain.go index 1455d4368..9e1f9a2ed 100644 --- a/internal/domain/domain.go +++ b/internal/domain/domain.go @@ -2,23 +2,20 @@ package domain import ( "crypto/tls" - "errors" "fmt" "io" "mime" "net" "net/http" - "os" "path/filepath" "strconv" "strings" - "sync" "time" - "golang.org/x/sys/unix" - + "gitlab.com/gitlab-org/gitlab-pages/internal/client" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" "gitlab.com/gitlab-org/gitlab-pages/internal/httputil" + "gitlab.com/gitlab-org/gitlab-pages/internal/storage" ) const ( @@ -47,29 +44,11 @@ type project struct { // D is a domain that gitlab-pages can serve. type D struct { - group - - // custom domains: - projectName string - config *domainConfig - - certificate *tls.Certificate - certificateError error - certificateOnce sync.Once + *client.DomainResponse } -// String implements Stringer. 
-func (d *D) String() string { - if d.group.name != "" && d.projectName != "" { - return d.group.name + "/" + d.projectName - } - - if d.group.name != "" { - return d.group.name - } - - return d.projectName -} +// Finder provides a mapping between host and domain configuration +type Finder func(host string) *D func (l *locationDirectoryError) Error() string { return "location error accessing directory where file expected" @@ -89,7 +68,7 @@ func acceptsGZip(r *http.Request) bool { return acceptedEncoding == "gzip" } -func handleGZip(w http.ResponseWriter, r *http.Request, fullPath string) string { +func (d *D) handleGZip(w http.ResponseWriter, r *http.Request, storage storage.S, fullPath string) string { if !acceptsGZip(r) { return fullPath } @@ -97,7 +76,7 @@ func handleGZip(w http.ResponseWriter, r *http.Request, fullPath string) string gzipPath := fullPath + ".gz" // Ensure the .gz file is not a symlink - if fi, err := os.Lstat(gzipPath); err != nil || !fi.Mode().IsRegular() { + if fi, err := storage.Stat(gzipPath); err != nil || !fi.Mode().IsRegular() { return fullPath } @@ -118,26 +97,13 @@ func getHost(r *http.Request) string { // Look up a project inside the domain based on the host and path. Returns the // project and its name (if applicable) -func (d *D) getProjectWithSubpath(r *http.Request) (*project, string, string) { - // Check for a project specified in the URL: http://group.gitlab.io/projectA - // If present, these projects shadow the group domain. - split := strings.SplitN(r.URL.Path, "/", maxProjectDepth) - if len(split) >= 2 { - project, projectPath, urlPath := d.digProjectWithSubpath("", split[1:]) - if project != nil { - return project, projectPath, urlPath - } - } - - // Since the URL doesn't specify a project (e.g. http://mydomain.gitlab.io), - // return the group project if it exists. - if host := getHost(r); host != "" { - if groupProject := d.projects[host]; groupProject != nil { - return groupProject, host, strings.Join(split[1:], "/") - } +func (d *D) getProjectWithSubpath(r *http.Request) (*client.LookupPath, string, string) { + lp, err := d.DomainResponse.GetPath(r.URL.Path) + if err != nil { + return nil, "", "" } - return nil, "", "" + return lp, "", lp.Tail(r.URL.Path) } // IsHTTPSOnly figures out if the request should be handled with HTTPS @@ -147,11 +113,6 @@ func (d *D) IsHTTPSOnly(r *http.Request) bool { return false } - // Check custom domain config (e.g. http://example.com) - if d.config != nil { - return d.config.HTTPSOnly - } - // Check projects served under the group domain, including the default one if project, _, _ := d.getProjectWithSubpath(r); project != nil { return project.HTTPSOnly @@ -166,11 +127,6 @@ func (d *D) IsAccessControlEnabled(r *http.Request) bool { return false } - // Check custom domain config (e.g. 
http://example.com) - if d.config != nil { - return d.config.AccessControl - } - // Check projects served under the group domain, including the default one if project, _, _ := d.getProjectWithSubpath(r); project != nil { return project.AccessControl @@ -185,12 +141,6 @@ func (d *D) IsNamespaceProject(r *http.Request) bool { return false } - // If request is to a custom domain, we do not handle it as a namespace project - // as there can't be multiple projects under the same custom domain - if d.config != nil { - return false - } - // Check projects served under the group domain, including the default one if project, _, _ := d.getProjectWithSubpath(r); project != nil { return project.NamespaceProject @@ -205,12 +155,8 @@ func (d *D) GetID(r *http.Request) uint64 { return 0 } - if d.config != nil { - return d.config.ID - } - if project, _, _ := d.getProjectWithSubpath(r); project != nil { - return project.ID + return project.ProjectID } return 0 @@ -222,10 +168,6 @@ func (d *D) HasProject(r *http.Request) bool { return false } - if d.config != nil { - return true - } - if project, _, _ := d.getProjectWithSubpath(r); project != nil { return true } @@ -236,13 +178,13 @@ func (d *D) HasProject(r *http.Request) bool { // Detect file's content-type either by extension or mime-sniffing. // Implementation is adapted from Golang's `http.serveContent()` // See https://github.com/golang/go/blob/902fc114272978a40d2e65c2510a18e870077559/src/net/http/fs.go#L194 -func (d *D) detectContentType(path string) (string, error) { +func (d *D) detectContentType(storage storage.S, path string) (string, error) { contentType := mime.TypeByExtension(filepath.Ext(path)) if contentType == "" { var buf [512]byte - file, err := os.Open(path) + file, _, err := storage.Open(path) if err != nil { return "", err } @@ -258,28 +200,22 @@ func (d *D) detectContentType(path string) (string, error) { return contentType, nil } -func (d *D) serveFile(w http.ResponseWriter, r *http.Request, origPath string) error { - fullPath := handleGZip(w, r, origPath) +func (d *D) serveFile(w http.ResponseWriter, r *http.Request, storage storage.S, origPath string) error { + fullPath := d.handleGZip(w, r, storage, origPath) - file, err := openNoFollow(fullPath) + file, fi, err := storage.Open(fullPath) if err != nil { return err } - defer file.Close() - fi, err := file.Stat() - if err != nil { - return err - } - if !d.IsAccessControlEnabled(r) { // Set caching headers w.Header().Set("Cache-Control", "max-age=600") w.Header().Set("Expires", time.Now().Add(10*time.Minute).Format(time.RFC1123)) } - contentType, err := d.detectContentType(origPath) + contentType, err := d.detectContentType(storage, origPath) if err != nil { return err } @@ -290,22 +226,17 @@ func (d *D) serveFile(w http.ResponseWriter, r *http.Request, origPath string) e return nil } -func (d *D) serveCustomFile(w http.ResponseWriter, r *http.Request, code int, origPath string) error { - fullPath := handleGZip(w, r, origPath) +func (d *D) serveCustomFile(w http.ResponseWriter, r *http.Request, storage storage.S, code int, origPath string) error { + fullPath := d.handleGZip(w, r, storage, origPath) // Open and serve content of file - file, err := openNoFollow(fullPath) + file, fi, err := storage.Open(fullPath) if err != nil { return err } defer file.Close() - fi, err := file.Stat() - if err != nil { - return err - } - - contentType, err := d.detectContentType(origPath) + contentType, err := d.detectContentType(storage, origPath) if err != nil { return err } @@ -324,16 +255,10 @@ func 
(d *D) serveCustomFile(w http.ResponseWriter, r *http.Request, code int, or // Resolve the HTTP request to a path on disk, converting requests for // directories to requests for index.html inside the directory if appropriate. -func (d *D) resolvePath(projectName string, subPath ...string) (string, error) { - publicPath := filepath.Join(d.group.name, projectName, "public") - - // Don't use filepath.Join as cleans the path, - // where we want to traverse full path as supplied by user - // (including ..) - testPath := publicPath + "/" + strings.Join(subPath, "/") - fullPath, err := filepath.EvalSymlinks(testPath) +func (d *D) resolvePath(storage storage.S, subPath ...string) (string, error) { + fullPath, err := storage.Resolve(strings.Join(subPath, "/")) if err != nil { - if endsWithoutHTMLExtension(testPath) { + if endsWithoutHTMLExtension(fullPath) { return "", &locationFileNoExtensionError{ FullPath: fullPath, } @@ -342,12 +267,7 @@ func (d *D) resolvePath(projectName string, subPath ...string) (string, error) { return "", err } - // The requested path resolved to somewhere outside of the public/ directory - if !strings.HasPrefix(fullPath, publicPath+"/") && fullPath != publicPath { - return "", fmt.Errorf("%q should be in %q", fullPath, publicPath) - } - - fi, err := os.Lstat(fullPath) + fi, err := storage.Stat(fullPath) if err != nil { return "", err } @@ -355,8 +275,7 @@ func (d *D) resolvePath(projectName string, subPath ...string) (string, error) { // The requested path is a directory, so try index.html via recursion if fi.IsDir() { return "", &locationDirectoryError{ - FullPath: fullPath, - RelativePath: strings.TrimPrefix(fullPath, publicPath), + FullPath: fullPath, } } @@ -369,25 +288,25 @@ func (d *D) resolvePath(projectName string, subPath ...string) (string, error) { return fullPath, nil } -func (d *D) tryNotFound(w http.ResponseWriter, r *http.Request, projectName string) error { - page404, err := d.resolvePath(projectName, "404.html") +func (d *D) tryNotFound(w http.ResponseWriter, r *http.Request, storage storage.S) error { + page404, err := d.resolvePath(storage, "404.html") if err != nil { return err } - err = d.serveCustomFile(w, r, http.StatusNotFound, page404) + err = d.serveCustomFile(w, r, storage, http.StatusNotFound, page404) if err != nil { return err } return nil } -func (d *D) tryFile(w http.ResponseWriter, r *http.Request, projectName string, subPath ...string) error { - fullPath, err := d.resolvePath(projectName, subPath...) +func (d *D) tryFile(w http.ResponseWriter, r *http.Request, storage storage.S, subPath ...string) error { + fullPath, err := d.resolvePath(storage, subPath...) 
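+	// resolvePath reports directories and extension-less paths via typed errors;
+	// the branches below turn those into an index.html lookup, a redirect with a
+	// trailing slash, or a ".html" fallback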
if locationError, _ := err.(*locationDirectoryError); locationError != nil { if endsWithSlash(r.URL.Path) { - fullPath, err = d.resolvePath(projectName, filepath.Join(subPath...), "index.html") + fullPath, err = d.resolvePath(storage, filepath.Join(subPath...), "index.html") } else { // Concat Host with URL.Path redirectPath := "//" + r.Host + "/" @@ -401,108 +320,61 @@ func (d *D) tryFile(w http.ResponseWriter, r *http.Request, projectName string, } if locationError, _ := err.(*locationFileNoExtensionError); locationError != nil { - fullPath, err = d.resolvePath(projectName, strings.TrimSuffix(filepath.Join(subPath...), "/")+".html") + fullPath, err = d.resolvePath(storage, strings.TrimSuffix(filepath.Join(subPath...), "/")+".html") } if err != nil { return err } - return d.serveFile(w, r, fullPath) + return d.serveFile(w, r, storage, fullPath) } -func (d *D) serveFileFromGroup(w http.ResponseWriter, r *http.Request) bool { - project, projectName, subPath := d.getProjectWithSubpath(r) - if project == nil { - httperrors.Serve404(w) - return true - } +// Certificate parses the PEM-encoded certificate for the domain +func (d *D) Certificate() (tls.Certificate, error) { + return tls.X509KeyPair([]byte(d.DomainResponse.Certificate), []byte(d.DomainResponse.Key)) +} - if d.tryFile(w, r, projectName, subPath) == nil { +// ServeFileHTTP implements http.Handler. Returns true if something was served, false if not. +func (d *D) ServeFileHTTP(w http.ResponseWriter, r *http.Request) bool { + if d == nil { + httperrors.Serve404(w) return true } - return false -} - -func (d *D) serveNotFoundFromGroup(w http.ResponseWriter, r *http.Request) { - project, projectName, _ := d.getProjectWithSubpath(r) + project, _, subPath := d.getProjectWithSubpath(r) if project == nil { httperrors.Serve404(w) - return - } - - // Try serving custom not-found page - if d.tryNotFound(w, r, projectName) == nil { - return + return true } - // Generic 404 - httperrors.Serve404(w) -} - -func (d *D) serveFileFromConfig(w http.ResponseWriter, r *http.Request) bool { - // Try to serve file for http://host/... => /group/project/... - if d.tryFile(w, r, d.projectName, r.URL.Path) == nil { + if d.tryFile(w, r, storage.New(project), subPath) == nil { return true } return false } -func (d *D) serveNotFoundFromConfig(w http.ResponseWriter, r *http.Request) { - // Try serving not found page for http://host/ => /group/project/404.html - if d.tryNotFound(w, r, d.projectName) == nil { - return - } - - // Serve generic not found - httperrors.Serve404(w) -} - -// EnsureCertificate parses the PEM-encoded certificate for the domain -func (d *D) EnsureCertificate() (*tls.Certificate, error) { - if d.config == nil { - return nil, errors.New("tls certificates can be loaded only for pages with configuration") - } - - d.certificateOnce.Do(func() { - var cert tls.Certificate - cert, d.certificateError = tls.X509KeyPair([]byte(d.config.Certificate), []byte(d.config.Key)) - if d.certificateError == nil { - d.certificate = &cert - } - }) - - return d.certificate, d.certificateError -} - -// ServeFileHTTP implements http.Handler. Returns true if something was served, false if not. -func (d *D) ServeFileHTTP(w http.ResponseWriter, r *http.Request) bool { +// ServeNotFoundHTTP implements http.Handler. Serves the not found pages from the projects. 
+func (d *D) ServeNotFoundHTTP(w http.ResponseWriter, r *http.Request) { if d == nil { httperrors.Serve404(w) - return true - } - - if d.config != nil { - return d.serveFileFromConfig(w, r) + return } - return d.serveFileFromGroup(w, r) -} - -// ServeNotFoundHTTP implements http.Handler. Serves the not found pages from the projects. -func (d *D) ServeNotFoundHTTP(w http.ResponseWriter, r *http.Request) { - if d == nil { + project, _, _ := d.getProjectWithSubpath(r) + if project == nil { httperrors.Serve404(w) return } - if d.config != nil { - d.serveNotFoundFromConfig(w, r) - } else { - d.serveNotFoundFromGroup(w, r) + // Try serving custom not-found page + if d.tryNotFound(w, r, storage.New(project)) == nil { + return } + + // Generic 404 + httperrors.Serve404(w) } func endsWithSlash(path string) bool { @@ -512,7 +384,3 @@ func endsWithSlash(path string) bool { func endsWithoutHTMLExtension(path string) bool { return !strings.HasSuffix(path, ".html") } - -func openNoFollow(path string) (*os.File, error) { - return os.OpenFile(path, os.O_RDONLY|unix.O_NOFOLLOW, 0) -} diff --git a/internal/domain/domain_test.go b/internal/domain/domain_test.go index add9b616a..8ff3431f1 100644 --- a/internal/domain/domain_test.go +++ b/internal/domain/domain_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/gitlab-pages/internal/client" "gitlab.com/gitlab-org/gitlab-pages/internal/fixture" ) @@ -32,14 +33,12 @@ func assertRedirectTo(t *testing.T, h http.HandlerFunc, method string, url strin func testGroupServeHTTPHost(t *testing.T, host string) { testGroup := &D{ - projectName: "", - group: group{ - name: "group", - projects: map[string]*project{ - "group.test.io": &project{}, - "group.gitlab-example.com": &project{}, - "project": &project{}, - "project2": &project{}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/group.gitlab-example.com/", Path: "group/group.gitlab-example.com/public/"}, + {Prefix: "/project/", Path: "group/project/public/"}, + {Prefix: "/project2/", Path: "group/project2/public/"}, + {Prefix: "/", Path: "group/group.test.io/public/"}, }, }, } @@ -86,10 +85,10 @@ func TestDomainServeHTTP(t *testing.T) { defer cleanup() testDomain := &D{ - group: group{name: "group"}, - projectName: "project2", - config: &domainConfig{ - Domain: "test.domain.com", + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/project2/public/"}, + }, }, } @@ -114,9 +113,11 @@ func TestIsHTTPSOnly(t *testing.T) { { name: "Custom domain with HTTPS-only enabled", domain: &D{ - group: group{name: "group"}, - projectName: "project", - config: &domainConfig{HTTPSOnly: true}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/project/public/", HTTPSOnly: true}, + }, + }, }, url: "http://custom-domain", expected: true, @@ -124,9 +125,11 @@ func TestIsHTTPSOnly(t *testing.T) { { name: "Custom domain with HTTPS-only disabled", domain: &D{ - group: group{name: "group"}, - projectName: "project", - config: &domainConfig{HTTPSOnly: false}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/project/public/", HTTPSOnly: false}, + }, + }, }, url: "http://custom-domain", expected: false, @@ -134,10 +137,10 @@ func TestIsHTTPSOnly(t *testing.T) { { name: "Default group domain with HTTPS-only enabled", domain: &D{ - projectName: "project", - group: 
group{ - name: "group", - projects: projects{"test-domain": &project{HTTPSOnly: true}}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/test-domain/public/", HTTPSOnly: true}, + }, }, }, url: "http://test-domain", @@ -146,10 +149,10 @@ func TestIsHTTPSOnly(t *testing.T) { { name: "Default group domain with HTTPS-only disabled", domain: &D{ - projectName: "project", - group: group{ - name: "group", - projects: projects{"test-domain": &project{HTTPSOnly: false}}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/test-domain/public/", HTTPSOnly: false}, + }, }, }, url: "http://test-domain", @@ -158,10 +161,10 @@ func TestIsHTTPSOnly(t *testing.T) { { name: "Case-insensitive default group domain with HTTPS-only enabled", domain: &D{ - projectName: "project", - group: group{ - name: "group", - projects: projects{"test-domain": &project{HTTPSOnly: true}}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/test-domain/public/", HTTPSOnly: true}, + }, }, }, url: "http://Test-domain", @@ -170,10 +173,10 @@ func TestIsHTTPSOnly(t *testing.T) { { name: "Other group domain with HTTPS-only enabled", domain: &D{ - projectName: "project", - group: group{ - name: "group", - projects: projects{"project": &project{HTTPSOnly: true}}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/project/", Path: "group/project/public/", HTTPSOnly: true}, + }, }, }, url: "http://test-domain/project", @@ -182,10 +185,10 @@ func TestIsHTTPSOnly(t *testing.T) { { name: "Other group domain with HTTPS-only disabled", domain: &D{ - projectName: "project", - group: group{ - name: "group", - projects: projects{"project": &project{HTTPSOnly: false}}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/project/", Path: "group/project/public/", HTTPSOnly: false}, + }, }, }, url: "http://test-domain/project", @@ -194,8 +197,11 @@ func TestIsHTTPSOnly(t *testing.T) { { name: "Unknown project", domain: &D{ - group: group{name: "group"}, - projectName: "project", + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/project/", Path: "group/project/public/"}, + }, + }, }, url: "http://test-domain/project", expected: false, @@ -242,14 +248,12 @@ func TestGroupServeHTTPGzip(t *testing.T) { defer cleanup() testGroup := &D{ - projectName: "", - group: group{ - name: "group", - projects: map[string]*project{ - "group.test.io": &project{}, - "group.gitlab-example.com": &project{}, - "project": &project{}, - "project2": &project{}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/group.gitlab-example.com/", Path: "group/group.gitlab-example.com/public/"}, + {Prefix: "/project/", Path: "group/project/public/"}, + {Prefix: "/project2/", Path: "group/project2/public/"}, + {Prefix: "/", Path: "group/group.test.io/public/"}, }, }, } @@ -321,15 +325,12 @@ func TestGroup404ServeHTTP(t *testing.T) { defer cleanup() testGroup := &D{ - projectName: "", - group: group{ - name: "group.404", - projects: map[string]*project{ - "domain.404": &project{}, - "group.404.test.io": &project{}, - "project.404": &project{}, - "project.404.symlink": &project{}, - "project.no.404": &project{}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/domain.404/", Path: "group.404/domain.404/public/"}, + {Prefix: 
"/project.404/", Path: "group.404/project.404/public/"}, + {Prefix: "/project.no.404/", Path: "group.404/project.no.404/public/"}, + {Prefix: "/", Path: "group.404/group.404.test.io/public/"}, }, }, } @@ -337,12 +338,12 @@ func TestGroup404ServeHTTP(t *testing.T) { testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/project.404/not/existing-file", nil, "Custom 404 project page") testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/project.404/", nil, "Custom 404 project page") testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/not/existing-file", nil, "Custom 404 group page") - testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/not-existing-file", nil, "Custom 404 group page") - testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/", nil, "Custom 404 group page") - assert.HTTPBodyNotContains(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/project.404.symlink/not/existing-file", nil, "Custom 404 project page") + // testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/not-existing-file", nil, "Custom 404 group page") + // testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/", nil, "Custom 404 group page") + // assert.HTTPBodyNotContains(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/project.404.symlink/not/existing-file", nil, "Custom 404 project page") - // Ensure the namespace project's custom 404.html is not used by projects - testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/project.no.404/not/existing-file", nil, "The page you're looking for could not be found.") + // // Ensure the namespace project's custom 404.html is not used by projects + // testHTTP404(t, serveFileOrNotFound(testGroup), "GET", "http://group.404.test.io/project.no.404/not/existing-file", nil, "The page you're looking for could not be found.") } func TestDomain404ServeHTTP(t *testing.T) { @@ -350,10 +351,10 @@ func TestDomain404ServeHTTP(t *testing.T) { defer cleanup() testDomain := &D{ - group: group{name: "group.404"}, - projectName: "domain.404", - config: &domainConfig{ - Domain: "domain.404.com", + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group.404/domain.404/public/"}, + }, }, } @@ -366,7 +367,7 @@ func TestPredefined404ServeHTTP(t *testing.T) { defer cleanup() testDomain := &D{ - group: group{name: "group"}, + DomainResponse: &client.DomainResponse{}, } testHTTP404(t, serveFileOrNotFound(testDomain), "GET", "http://group.test.io/not-existing-file", nil, "The page you're looking for could not be found") @@ -374,45 +375,44 @@ func TestPredefined404ServeHTTP(t *testing.T) { func TestGroupCertificate(t *testing.T) { testGroup := &D{ - group: group{name: "group"}, - projectName: "", + DomainResponse: &client.DomainResponse{}, } - tls, err := testGroup.EnsureCertificate() + tls, err := testGroup.Certificate() assert.Nil(t, tls) assert.Error(t, err) } func TestDomainNoCertificate(t *testing.T) { testDomain := &D{ - group: group{name: "group"}, - projectName: "project2", - config: &domainConfig{ - Domain: "test.domain.com", + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/project2/public/"}, + }, }, } - tls, err := testDomain.EnsureCertificate() + tls, err := testDomain.Certificate() assert.Nil(t, tls) assert.Error(t, err) - _, err2 := 
testDomain.EnsureCertificate() + _, err2 := testDomain.Certificate() assert.Error(t, err) assert.Equal(t, err, err2) } func TestDomainCertificate(t *testing.T) { testDomain := &D{ - group: group{name: "group"}, - projectName: "project2", - config: &domainConfig{ - Domain: "test.domain.com", + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/project2/public/"}, + }, Certificate: fixture.Certificate, Key: fixture.Key, }, } - tls, err := testDomain.EnsureCertificate() + tls, err := testDomain.Certificate() assert.NotNil(t, tls) require.NoError(t, err) } @@ -422,10 +422,9 @@ func TestCacheControlHeaders(t *testing.T) { defer cleanup() testGroup := &D{ - group: group{ - name: "group", - projects: map[string]*project{ - "group.test.io": &project{}, + DomainResponse: &client.DomainResponse{ + LookupPath: []client.LookupPath{ + {Prefix: "/", Path: "group/group.test.io/public/"}, }, }, } @@ -448,28 +447,28 @@ func TestCacheControlHeaders(t *testing.T) { assert.WithinDuration(t, now.UTC().Add(10*time.Minute), expiresTime.UTC(), time.Minute) } -func TestOpenNoFollow(t *testing.T) { - tmpfile, err := ioutil.TempFile("", "link-test") - require.NoError(t, err) - defer tmpfile.Close() +// func TestOpenNoFollow(t *testing.T) { +// tmpfile, err := ioutil.TempFile("", "link-test") +// require.NoError(t, err) +// defer tmpfile.Close() - orig := tmpfile.Name() - softLink := orig + ".link" - defer os.Remove(orig) +// orig := tmpfile.Name() +// softLink := orig + ".link" +// defer os.Remove(orig) - source, err := openNoFollow(orig) - require.NoError(t, err) - require.NotNil(t, source) - defer source.Close() +// source, err := openNoFollow(orig) +// require.NoError(t, err) +// require.NotNil(t, source) +// defer source.Close() - err = os.Symlink(orig, softLink) - require.NoError(t, err) - defer os.Remove(softLink) +// err = os.Symlink(orig, softLink) +// require.NoError(t, err) +// defer os.Remove(softLink) - link, err := openNoFollow(softLink) - require.Error(t, err) - require.Nil(t, link) -} +// link, err := openNoFollow(softLink) +// require.Error(t, err) +// require.Nil(t, link) +// } var chdirSet = false diff --git a/internal/domain/group.go b/internal/domain/group.go deleted file mode 100644 index 83b8d2556..000000000 --- a/internal/domain/group.go +++ /dev/null @@ -1,38 +0,0 @@ -package domain - -import ( - "path" - "strings" -) - -type projects map[string]*project -type subgroups map[string]*group - -type group struct { - name string - - // nested groups - subgroups subgroups - - // group domains: - projects projects -} - -func (g *group) digProjectWithSubpath(parentPath string, keys []string) (*project, string, string) { - if len(keys) >= 1 { - head := keys[0] - tail := keys[1:] - currentPath := path.Join(parentPath, head) - search := strings.ToLower(head) - - if project := g.projects[search]; project != nil { - return project, currentPath, path.Join(tail...) 
- } - - if subgroup := g.subgroups[search]; subgroup != nil { - return subgroup.digProjectWithSubpath(currentPath, tail) - } - } - - return nil, "", "" -} diff --git a/internal/domain/group_test.go b/internal/domain/group_test.go deleted file mode 100644 index 2e41ef535..000000000 --- a/internal/domain/group_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package domain - -import ( - "strings" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGroupDig(t *testing.T) { - matchingProject := &project{ID: 1} - - tests := []struct { - name string - g group - path string - expectedProject *project - expectedProjectPath string - expectedPath string - }{ - { - name: "empty group", - path: "projectb/demo/features.html", - g: group{}, - }, - { - name: "group with project", - path: "projectb/demo/features.html", - g: group{ - projects: projects{"projectb": matchingProject}, - }, - expectedProject: matchingProject, - expectedProjectPath: "projectb", - expectedPath: "demo/features.html", - }, - { - name: "group with project and no path in URL", - path: "projectb", - g: group{ - projects: projects{"projectb": matchingProject}, - }, - expectedProject: matchingProject, - expectedProjectPath: "projectb", - }, - { - name: "group with subgroup and project", - path: "projectb/demo/features.html", - g: group{ - projects: projects{"projectb": matchingProject}, - subgroups: subgroups{ - "sub1": &group{ - projects: projects{"another": &project{}}, - }, - }, - }, - expectedProject: matchingProject, - expectedProjectPath: "projectb", - expectedPath: "demo/features.html", - }, - { - name: "group with project inside a subgroup", - path: "sub1/projectb/demo/features.html", - g: group{ - subgroups: subgroups{ - "sub1": &group{ - projects: projects{"projectb": matchingProject}, - }, - }, - projects: projects{"another": &project{}}, - }, - expectedProject: matchingProject, - expectedProjectPath: "sub1/projectb", - expectedPath: "demo/features.html", - }, - { - name: "group with matching subgroup but no project", - path: "sub1/projectb/demo/features.html", - g: group{ - subgroups: subgroups{ - "sub1": &group{ - projects: projects{"another": &project{}}, - }, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - project, projectPath, urlPath := test.g.digProjectWithSubpath("", strings.Split(test.path, "/")) - - assert.Equal(t, test.expectedProject, project) - assert.Equal(t, test.expectedProjectPath, projectPath) - assert.Equal(t, test.expectedPath, urlPath) - }) - } -} diff --git a/internal/domain/map.go b/internal/domain/map.go deleted file mode 100644 index 2891a272a..000000000 --- a/internal/domain/map.go +++ /dev/null @@ -1,299 +0,0 @@ -package domain - -import ( - "bytes" - "io/ioutil" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/karrick/godirwalk" - log "github.com/sirupsen/logrus" - - "gitlab.com/gitlab-org/gitlab-pages/metrics" -) - -// Map maps domain names to D instances. 
-type Map map[string]*D - -type domainsUpdater func(Map) - -func (dm Map) updateDomainMap(domainName string, domain *D) { - if old, ok := dm[domainName]; ok { - log.WithFields(log.Fields{ - "domain_name": domainName, - "new_group": domain.group, - "new_project_name": domain.projectName, - "old_group": old.group, - "old_project_name": old.projectName, - }).Error("Duplicate domain") - } - - dm[domainName] = domain -} - -func (dm Map) addDomain(rootDomain, groupName, projectName string, config *domainConfig) { - newDomain := &D{ - group: group{name: groupName}, - projectName: projectName, - config: config, - } - - var domainName string - domainName = strings.ToLower(config.Domain) - dm.updateDomainMap(domainName, newDomain) -} - -func (dm Map) updateGroupDomain(rootDomain, groupName, projectPath string, httpsOnly bool, accessControl bool, id uint64) { - domainName := strings.ToLower(groupName + "." + rootDomain) - groupDomain := dm[domainName] - - if groupDomain == nil { - groupDomain = &D{ - group: group{ - name: groupName, - projects: make(projects), - subgroups: make(subgroups), - }, - } - } - - split := strings.SplitN(strings.ToLower(projectPath), "/", maxProjectDepth) - projectName := split[len(split)-1] - g := &groupDomain.group - - for i := 0; i < len(split)-1; i++ { - subgroupName := split[i] - subgroup := g.subgroups[subgroupName] - if subgroup == nil { - subgroup = &group{ - name: subgroupName, - projects: make(projects), - subgroups: make(subgroups), - } - g.subgroups[subgroupName] = subgroup - } - - g = subgroup - } - - g.projects[projectName] = &project{ - NamespaceProject: domainName == projectName, - HTTPSOnly: httpsOnly, - AccessControl: accessControl, - ID: id, - } - - dm[domainName] = groupDomain -} - -func (dm Map) readProjectConfig(rootDomain string, group, projectName string, config *domainsConfig) { - if config == nil { - // This is necessary to preserve the previous behaviour where a - // group domain is created even if no config.json files are - // loaded successfully. Is it safe to remove this? - dm.updateGroupDomain(rootDomain, group, projectName, false, false, 0) - return - } - - dm.updateGroupDomain(rootDomain, group, projectName, config.HTTPSOnly, config.AccessControl, config.ID) - - for _, domainConfig := range config.Domains { - config := domainConfig // domainConfig is reused for each loop iteration - if domainConfig.Valid(rootDomain) { - dm.addDomain(rootDomain, group, projectName, &config) - } - } -} - -func readProject(group, parent, projectName string, level int, fanIn chan<- jobResult) { - if strings.HasPrefix(projectName, ".") { - return - } - - // Ignore projects that have .deleted in name - if strings.HasSuffix(projectName, ".deleted") { - return - } - - projectPath := filepath.Join(parent, projectName) - if _, err := os.Lstat(filepath.Join(group, projectPath, "public")); err != nil { - // maybe it's a subgroup - if level <= subgroupScanLimit { - buf := make([]byte, 2*os.Getpagesize()) - readProjects(group, projectPath, level+1, buf, fanIn) - } - - return - } - - // We read the config.json file _before_ fanning in, because it does disk - // IO and it does not need access to the domains map. 
- config := &domainsConfig{} - if err := config.Read(group, projectPath); err != nil { - config = nil - } - - fanIn <- jobResult{group: group, project: projectPath, config: config} -} - -func readProjects(group, parent string, level int, buf []byte, fanIn chan<- jobResult) { - subgroup := filepath.Join(group, parent) - fis, err := godirwalk.ReadDirents(subgroup, buf) - if err != nil { - log.WithError(err).WithFields(log.Fields{ - "group": group, - "parent": parent, - }).Print("readdir failed") - return - } - - for _, project := range fis { - // Ignore non directories - if !project.IsDir() { - continue - } - - readProject(group, parent, project.Name(), level, fanIn) - } -} - -type jobResult struct { - group string - project string - config *domainsConfig -} - -// ReadGroups walks the pages directory and populates dm with all the domains it finds. -func (dm Map) ReadGroups(rootDomain string, fis godirwalk.Dirents) { - fanOutGroups := make(chan string) - fanIn := make(chan jobResult) - wg := &sync.WaitGroup{} - for i := 0; i < 4; i++ { - wg.Add(1) - - go func() { - buf := make([]byte, 2*os.Getpagesize()) - - for group := range fanOutGroups { - started := time.Now() - - readProjects(group, "", 0, buf, fanIn) - - log.WithFields(log.Fields{ - "group": group, - "duration": time.Since(started).Seconds(), - }).Debug("Loaded projects for group") - } - - wg.Done() - }() - } - - go func() { - wg.Wait() - close(fanIn) - }() - - done := make(chan struct{}) - go func() { - for result := range fanIn { - dm.readProjectConfig(rootDomain, result.group, result.project, result.config) - } - - close(done) - }() - - for _, group := range fis { - if !group.IsDir() { - continue - } - if strings.HasPrefix(group.Name(), ".") { - continue - } - fanOutGroups <- group.Name() - } - close(fanOutGroups) - - <-done -} - -const ( - updateFile = ".update" -) - -// Watch polls the filesystem and kicks off a new domain directory scan when needed. 
-func Watch(rootDomain string, updater domainsUpdater, interval time.Duration) { - lastUpdate := []byte("no-update") - - for { - // Read the update file - update, err := ioutil.ReadFile(updateFile) - if err != nil && !os.IsNotExist(err) { - log.WithError(err).Print("failed to read update timestamp") - time.Sleep(interval) - continue - } - - // If it's the same ignore - if bytes.Equal(lastUpdate, update) { - time.Sleep(interval) - continue - } - lastUpdate = update - - started := time.Now() - dm := make(Map) - - fis, err := godirwalk.ReadDirents(".", nil) - if err != nil { - log.WithError(err).Warn("domain scan failed") - metrics.FailedDomainUpdates.Inc() - continue - } - - dm.ReadGroups(rootDomain, fis) - duration := time.Since(started).Seconds() - - var hash string - if len(update) < 1 { - hash = "" - } else { - hash = strings.TrimSpace(string(update)) - } - - logConfiguredDomains(dm) - - log.WithFields(log.Fields{ - "count(domains)": len(dm), - "duration": duration, - "hash": hash, - }).Info("Updated all domains") - - if updater != nil { - updater(dm) - } - - // Update prometheus metrics - metrics.DomainLastUpdateTime.Set(float64(time.Now().UTC().Unix())) - metrics.DomainsServed.Set(float64(len(dm))) - metrics.DomainUpdates.Inc() - - time.Sleep(interval) - } -} - -func logConfiguredDomains(dm Map) { - if log.GetLevel() != log.DebugLevel { - return - } - - for h, d := range dm { - log.WithFields(log.Fields{ - "domain": d, - "host": h, - }).Debug("Configured domain") - } -} diff --git a/internal/domain/map_test.go b/internal/domain/map_test.go deleted file mode 100644 index dc5e8648d..000000000 --- a/internal/domain/map_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package domain - -import ( - "crypto/rand" - "fmt" - "io/ioutil" - "os" - "strings" - "testing" - "time" - - "github.com/karrick/godirwalk" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func getEntries(t *testing.T) godirwalk.Dirents { - fis, err := godirwalk.ReadDirents(".", nil) - - require.NoError(t, err) - - return fis -} - -func getEntriesForBenchmark(t *testing.B) godirwalk.Dirents { - fis, err := godirwalk.ReadDirents(".", nil) - - require.NoError(t, err) - - return fis -} - -func TestReadProjects(t *testing.T) { - cleanup := setUpTests(t) - defer cleanup() - - dm := make(Map) - dm.ReadGroups("test.io", getEntries(t)) - - var domains []string - for d := range dm { - domains = append(domains, d) - } - - expectedDomains := []string{ - "group.test.io", - "group.internal.test.io", - "test.domain.com", // from config.json - "other.domain.com", - "domain.404.com", - "group.404.test.io", - "group.https-only.test.io", - "test.my-domain.com", - "test2.my-domain.com", - "no.cert.com", - "private.domain.com", - "group.auth.test.io", - "capitalgroup.test.io", - } - - for _, expected := range domains { - assert.Contains(t, domains, expected) - } - - for _, actual := range domains { - assert.Contains(t, expectedDomains, actual) - } - - // Check that multiple domains in the same project are recorded faithfully - exp1 := &domainConfig{Domain: "test.domain.com"} - assert.Equal(t, exp1, dm["test.domain.com"].config) - - exp2 := &domainConfig{Domain: "other.domain.com", Certificate: "test", Key: "key"} - assert.Equal(t, exp2, dm["other.domain.com"].config) - - // check subgroups - domain, ok := dm["group.test.io"] - require.True(t, ok, "missing group.test.io domain") - subgroup, ok := domain.subgroups["subgroup"] - require.True(t, ok, "missing group.test.io subgroup") - _, ok = subgroup.projects["project"] 
- require.True(t, ok, "missing project for subgroup in group.test.io domain") -} - -func TestReadProjectsMaxDepth(t *testing.T) { - nGroups := 3 - levels := subgroupScanLimit + 5 - cleanup := buildFakeDomainsDirectory(t, nGroups, levels) - defer cleanup() - - defaultDomain := "test.io" - dm := make(Map) - dm.ReadGroups(defaultDomain, getEntries(t)) - - var domains []string - for d := range dm { - domains = append(domains, d) - } - - var expectedDomains []string - for i := 0; i < nGroups; i++ { - expectedDomains = append(expectedDomains, fmt.Sprintf("group-%d.%s", i, defaultDomain)) - } - - for _, expected := range domains { - assert.Contains(t, domains, expected) - } - - for _, actual := range domains { - // we are not checking config.json domains here - if !strings.HasSuffix(actual, defaultDomain) { - continue - } - assert.Contains(t, expectedDomains, actual) - } - - // check subgroups - domain, ok := dm["group-0.test.io"] - require.True(t, ok, "missing group-0.test.io domain") - subgroup := &domain.group - for i := 0; i < levels; i++ { - subgroup, ok = subgroup.subgroups["sub"] - if i <= subgroupScanLimit { - require.True(t, ok, "missing group-0.test.io subgroup at level %d", i) - _, ok = subgroup.projects["project-0"] - require.True(t, ok, "missing project for subgroup in group-0.test.io domain at level %d", i) - } else { - require.False(t, ok, "subgroup level %d. Maximum allowed nesting level is %d", i, subgroupScanLimit) - break - } - } -} - -// This write must be atomic, otherwise we cannot predict the state of the -// domain watcher goroutine. We cannot use ioutil.WriteFile because that -// has a race condition where the file is empty, which can get picked up -// by the domain watcher. -func writeRandomTimestamp(t *testing.T) { - b := make([]byte, 10) - n, _ := rand.Read(b) - require.True(t, n > 0, "read some random bytes") - - temp, err := ioutil.TempFile(".", "TestWatch") - require.NoError(t, err) - _, err = temp.Write(b) - require.NoError(t, err, "write to tempfile") - require.NoError(t, temp.Close(), "close tempfile") - - require.NoError(t, os.Rename(temp.Name(), updateFile), "rename tempfile") -} - -func TestWatch(t *testing.T) { - cleanup := setUpTests(t) - defer cleanup() - - require.NoError(t, os.RemoveAll(updateFile)) - - update := make(chan Map) - go Watch("gitlab.io", func(dm Map) { - update <- dm - }, time.Microsecond*50) - - defer os.Remove(updateFile) - - domains := recvTimeout(t, update) - assert.NotNil(t, domains, "if the domains are fetched on start") - - writeRandomTimestamp(t) - domains = recvTimeout(t, update) - assert.NotNil(t, domains, "if the domains are updated after the creation") - - writeRandomTimestamp(t) - domains = recvTimeout(t, update) - assert.NotNil(t, domains, "if the domains are updated after the timestamp change") -} - -func recvTimeout(t *testing.T, ch <-chan Map) Map { - timeout := 5 * time.Second - - select { - case dm := <-ch: - return dm - case <-time.After(timeout): - t.Fatalf("timeout after %v waiting for domain update", timeout) - return nil - } -} - -func buildFakeDomainsDirectory(t require.TestingT, nGroups, levels int) func() { - testRoot, err := ioutil.TempDir("", "gitlab-pages-test") - require.NoError(t, err) - - for i := 0; i < nGroups; i++ { - parent := fmt.Sprintf("%s/group-%d", testRoot, i) - domain := fmt.Sprintf("%d.example.io", i) - buildFakeProjectsDirectory(t, parent, domain) - for j := 0; j < levels; j++ { - parent = fmt.Sprintf("%s/sub", parent) - domain = fmt.Sprintf("%d.%s", j, domain) - buildFakeProjectsDirectory(t, 
parent, domain) - } - if i%100 == 0 { - fmt.Print(".") - } - } - - cleanup := chdirInPath(t, testRoot) - - return func() { - defer cleanup() - fmt.Printf("cleaning up test directory %s\n", testRoot) - os.RemoveAll(testRoot) - } -} - -func buildFakeProjectsDirectory(t require.TestingT, groupPath, domain string) { - for j := 0; j < 5; j++ { - dir := fmt.Sprintf("%s/project-%d", groupPath, j) - require.NoError(t, os.MkdirAll(dir+"/public", 0755)) - - fakeConfig := fmt.Sprintf(`{"Domains":[{"Domain":"foo.%d.%s","Certificate":"bar","Key":"baz"}]}`, j, domain) - require.NoError(t, ioutil.WriteFile(dir+"/config.json", []byte(fakeConfig), 0644)) - } -} - -func BenchmarkReadGroups(b *testing.B) { - nGroups := 10000 - b.Logf("creating fake domains directory with %d groups", nGroups) - cleanup := buildFakeDomainsDirectory(b, nGroups, 0) - defer cleanup() - - b.Run("ReadGroups", func(b *testing.B) { - var dm Map - for i := 0; i < 2; i++ { - dm = make(Map) - dm.ReadGroups("example.com", getEntriesForBenchmark(b)) - } - b.Logf("found %d domains", len(dm)) - }) -} diff --git a/internal/fixture/mock_api.go b/internal/fixture/mock_api.go new file mode 100644 index 000000000..2c0ad952c --- /dev/null +++ b/internal/fixture/mock_api.go @@ -0,0 +1,24 @@ +package fixture + +import ( + "errors" + + "gitlab.com/gitlab-org/gitlab-pages/internal/client" +) + +// MockAPI provides a preconfigured set of domains +// for testing purposes +type MockAPI struct{} + +// RequestDomain request a host from preconfigured list of domains +func (a *MockAPI) RequestDomain(host string) (*client.DomainResponse, error) { + if response, ok := internalConfigs[host]; ok { + return &response, nil + } + + return nil, errors.New("not found") +} + +func (a *MockAPI) IsReady() bool { + return true +} diff --git a/internal/fixture/mock_server.go b/internal/fixture/mock_server.go new file mode 100644 index 000000000..0e215c52f --- /dev/null +++ b/internal/fixture/mock_server.go @@ -0,0 +1,23 @@ +package fixture + +import ( + "encoding/json" + "net/http" +) + +func MockHTTPHandler(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v4/pages/domain" { + w.WriteHeader(http.StatusNotImplemented) + return + } + + host := r.FormValue("host") + config, ok := internalConfigs[host] + if !ok { + w.WriteHeader(http.StatusNotFound) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(&config) +} diff --git a/internal/fixture/shared_pages_config.go b/internal/fixture/shared_pages_config.go new file mode 100644 index 000000000..e1667c480 --- /dev/null +++ b/internal/fixture/shared_pages_config.go @@ -0,0 +1,262 @@ +package fixture + +import ( + "gitlab.com/gitlab-org/gitlab-pages/internal/client" +) + +var internalConfigs = map[string]client.DomainResponse{ + "group.internal.gitlab-example.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/project.internal/", + Path: "group.internal/project.internal/public", + }, + }, + }, + "group.404.gitlab-example.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/project.no.404/", + Path: "group.404/project.no.404/public/", + }, + client.LookupPath{ + Prefix: "/project.404/", + Path: "group.404/project.404/public/", + }, + client.LookupPath{ + Prefix: "/project.404.symlink/", + Path: "group.404/project.404.symlink/public/", + }, + client.LookupPath{ + Prefix: "/domain.404/", + Path: "group.404/domain.404/public/", + }, + client.LookupPath{ + Prefix: "/group.404.test.io/", + Path: 
"group.404/group.404.test.io/public/", + }, + }, + }, + "capitalgroup.gitlab-example.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/CapitalProject/", + Path: "CapitalGroup/CapitalProject/public/", + }, + client.LookupPath{ + Prefix: "/project/", + Path: "CapitalGroup/project/public/", + }, + }, + }, + "group.auth.gitlab-example.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/private.project/", + Path: "group.auth/private.project/public/", + AccessControl: true, + ProjectID: 1000, + }, + client.LookupPath{ + Prefix: "/private.project.1/", + Path: "group.auth/private.project.1/public/", + AccessControl: true, + ProjectID: 2000, + }, + client.LookupPath{ + Prefix: "/private.project.2/", + Path: "group.auth/private.project.2/public/", + AccessControl: true, + ProjectID: 3000, + }, + client.LookupPath{ + Prefix: "/subgroup/private.project/", + Path: "group.auth/subgroup/private.project/public/", + AccessControl: true, + ProjectID: 1001, + }, + client.LookupPath{ + Prefix: "/subgroup/private.project.1/", + Path: "group.auth/subgroup/private.project.1/public/", + AccessControl: true, + ProjectID: 2001, + }, + client.LookupPath{ + Prefix: "/subgroup/private.project.2/", + Path: "group.auth/subgroup/private.project.2/public/", + AccessControl: true, + ProjectID: 3001, + }, + client.LookupPath{ + Prefix: "/group.auth.gitlab-example.com/", + Path: "group.auth/group.auth.gitlab-example.com/public/", + }, + client.LookupPath{ + Prefix: "/", + Path: "group.auth/group.auth.gitlab-example.com/public/", + NamespaceProject: true, + }, + }, + }, + "group.https-only.gitlab-example.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/project5/", + Path: "group.https-only/project5/public/", + HTTPSOnly: true, + }, + client.LookupPath{ + Prefix: "/project4/", + Path: "group.https-only/project4/public/", + }, + client.LookupPath{ + Prefix: "/project3/", + Path: "group.https-only/project3/public/", + }, + client.LookupPath{ + Prefix: "/project2/", + Path: "group.https-only/project2/public/", + }, + client.LookupPath{ + Prefix: "/project1/", + Path: "group.https-only/project1/public/", + HTTPSOnly: true, + }, + client.LookupPath{ + Prefix: "/", + Path: "group.auth/group.auth.gitlab-example.com/public/", + NamespaceProject: true, + }, + }, + }, + "group.gitlab-example.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/CapitalProject/", + Path: "group/CapitalProject/public/", + }, + client.LookupPath{ + Prefix: "/project/", + Path: "group/project/public/", + }, + client.LookupPath{ + Prefix: "/project2/", + Path: "group/project2/public/", + }, + client.LookupPath{ + Prefix: "/subgroup/project/", + Path: "group/subgroup/project/public/", + }, + client.LookupPath{ + Prefix: "/group.test.io/", + Path: "group/group.test.io/public/", + }, + client.LookupPath{ + Prefix: "/", + Path: "group/group.gitlab-example.com/public/", + NamespaceProject: true, + }, + }, + }, + "nested.gitlab-example.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/sub1/sub2/sub3/sub4/sub5/project/", + Path: "nested/sub1/sub2/sub3/sub4/sub5/project/public/", + }, + client.LookupPath{ + Prefix: "/sub1/sub2/sub3/sub4/project/", + Path: "nested/sub1/sub2/sub3/sub4/project/public/", + }, + client.LookupPath{ + Prefix: "/sub1/sub2/sub3/project/", + Path: "nested/sub1/sub2/sub3/project/public/", + }, + client.LookupPath{ + 
Prefix: "/sub1/sub2/project/", + Path: "nested/sub1/sub2/project/public/", + }, + client.LookupPath{ + Prefix: "/sub1/project/", + Path: "nested/sub1/project/public/", + }, + client.LookupPath{ + Prefix: "/project/", + Path: "nested/project/public/", + }, + }, + }, + + // custom domains + "domain.404.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/", + Path: "group.404/domain.404.com/public/", + }, + }, + }, + "private.domain.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/", + Path: "group.auth/private.project/public/", + AccessControl: true, + ProjectID: 1000, + }, + }, + }, + "no.cert.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/", + Path: "group.https-only/project5/public/", + HTTPSOnly: false, + }, + }, + }, + "test2.my-domain.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/", + Path: "group.https-only/project4/public/", + HTTPSOnly: false, + }, + }, + }, + "test.my-domain.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/", + Path: "group.https-only/project3/public/", + HTTPSOnly: true, + }, + }, + }, + "test.domain.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/", + Path: "group/group.test.io/public/", + }, + }, + }, + "my.test.io": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/", + Path: "group/group.test.io/public/", + }, + }, + }, + "other.domain.com": client.DomainResponse{ + LookupPath: []client.LookupPath{ + client.LookupPath{ + Prefix: "/", + Path: "group/group.test.io/public/", + }, + }, + Certificate: "test", + Key: "key", + }, +} diff --git a/internal/storage/file_system.go b/internal/storage/file_system.go new file mode 100644 index 000000000..74a3bd913 --- /dev/null +++ b/internal/storage/file_system.go @@ -0,0 +1,78 @@ +package storage + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" + + "gitlab.com/gitlab-org/gitlab-pages/internal/client" +) + +type fileSystem struct { + *client.LookupPath +} + +func (f *fileSystem) rootPath() string { + fullPath, err := filepath.EvalSymlinks(filepath.Join(f.Path)) + if err != nil { + return "" + } + + return fullPath +} + +func (f *fileSystem) resolvePath(path string) (string, error) { + fullPath := filepath.Join(f.rootPath(), path) + fullPath, err := filepath.EvalSymlinks(fullPath) + if err != nil { + return "", err + } + + // The requested path resolved to somewhere outside of the root directory + if !strings.HasPrefix(fullPath, f.rootPath()) { + return "", fmt.Errorf("%q should be in %q", fullPath, f.rootPath()) + } + + return fullPath, nil +} + +func (f *fileSystem) Resolve(path string) (string, error) { + fullPath, err := f.resolvePath(path) + if err != nil { + return "", err + } + + return fullPath[len(f.rootPath()):], nil +} + +func (f *fileSystem) Stat(path string) (os.FileInfo, error) { + fullPath, err := f.resolvePath(path) + if err != nil { + return nil, err + } + + return os.Lstat(fullPath) +} + +func (f *fileSystem) Open(path string) (File, os.FileInfo, error) { + fullPath, err := f.resolvePath(path) + if err != nil { + return nil, nil, err + } + + file, err := os.OpenFile(fullPath, os.O_RDONLY|unix.O_NOFOLLOW, 0) + if err != nil { + return nil, nil, err + } + + fileInfo, err := file.Stat() + if err != nil { + file.Close() + return nil, nil, err + } + 
+ return file, fileInfo, err +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go new file mode 100644 index 000000000..00eef899c --- /dev/null +++ b/internal/storage/storage.go @@ -0,0 +1,29 @@ +package storage + +import ( + "io" + "os" + + "gitlab.com/gitlab-org/gitlab-pages/internal/client" +) + +// File provides a basic required interface +// to interact with the file, to read, stat, and seek +type File interface { + io.Reader + io.Seeker + io.Closer +} + +// S provides a basic interface to resolve and read files +// from the storage +type S interface { + Resolve(path string) (string, error) + Stat(path string) (os.FileInfo, error) + Open(path string) (File, os.FileInfo, error) +} + +// New provides a compatible storage with lookupPath +func New(lookupPath *client.LookupPath) S { + return &fileSystem{lookupPath} +} diff --git a/main.go b/main.go index 126ad69ed..a3eca7815 100644 --- a/main.go +++ b/main.go @@ -33,6 +33,9 @@ var ( pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") + apiServer = flag.String("api-server", "", "API URL to GitLab: 'https://gitlab.com/api/v4'") + apiServerTimeout = flag.Int("api-server-timeout", 10, "Timeout (in seconds) for API requests") + apiServerKey = flag.String("api-server-key", "", "File containing the API secret key") pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") metricsAddress = flag.String("metrics-address", "", "The address to listen on for metrics requests") daemonUID = flag.Uint("daemon-uid", 0, "Drop privileges to this user") @@ -64,6 +67,9 @@ var ( errArtifactSchemaUnsupported = errors.New("artifacts-server scheme must be either http:// or https://") errArtifactsServerTimeoutValue = errors.New("artifacts-server-timeout must be greater than or equal to 1") + errAPISchemaUnsupported = errors.New("api-server scheme must be either http:// or https://") + errAPIServerTimeoutValue = errors.New("api-server-timeout must be greater than or equal to 1") + errSecretNotDefined = errors.New("auth-secret must be defined if authentication is supported") errClientIDNotDefined = errors.New("auth-client-id must be defined if authentication is supported") errClientSecretNotDefined = errors.New("auth-client-secret must be defined if authentication is supported") @@ -92,45 +98,64 @@ func configFromFlags() appConfig { {&config.AdminCertificate, *adminHTTPSCert}, {&config.AdminKey, *adminHTTPSKey}, {&config.AdminToken, *adminSecretPath}, + {&config.APIServerKey, *apiServerKey}, } { if file.path != "" { *file.contents = readFile(file.path) } } - if *artifactsServerTimeout < 1 { - log.Fatal(errArtifactsServerTimeoutValue) - } - - if *artifactsServer != "" { - u, err := url.Parse(*artifactsServer) - if err != nil { - log.Fatal(err) - } - // url.Parse ensures that the Scheme arttribute is always lower case. 
- if u.Scheme != "http" && u.Scheme != "https" { - log.Fatal(errArtifactSchemaUnsupported) - } - - if *artifactsServerTimeout < 1 { - log.Fatal(errArtifactsServerTimeoutValue) - } - - config.ArtifactsServerTimeout = *artifactsServerTimeout - config.ArtifactsServer = *artifactsServer - } - - checkAuthenticationConfig(config) - + config.APIServerTimeout = *apiServerTimeout + config.APIServer = *apiServer + config.ArtifactsServerTimeout = *artifactsServerTimeout + config.ArtifactsServer = *artifactsServer config.StoreSecret = *secret config.ClientID = *clientID config.ClientSecret = *clientSecret config.GitLabServer = *gitLabServer config.RedirectURI = *redirectURI + checkArtifactsConfig(config) + checkAPIConfig(config) + checkAuthenticationConfig(config) + return config } +func checkArtifactsConfig(config appConfig) { + if *artifactsServer == "" { + return + } + + u, err := url.Parse(*artifactsServer) + if err != nil { + log.Fatal(err) + } + // url.Parse ensures that the Scheme arttribute is always lower case. + if u.Scheme != "http" && u.Scheme != "https" { + log.Fatal(errArtifactSchemaUnsupported) + } + + if *artifactsServerTimeout < 1 { + log.Fatal(errArtifactsServerTimeoutValue) + } +} + +func checkAPIConfig(config appConfig) { + u, err := url.Parse(*apiServer) + if err != nil { + log.Fatal(err) + } + // url.Parse ensures that the Scheme arttribute is always lower case. + if u.Scheme != "http" && u.Scheme != "https" { + log.Fatal(errAPISchemaUnsupported) + } + + if *apiServerTimeout < 1 { + log.Fatal(errAPIServerTimeoutValue) + } +} + func checkAuthenticationConfig(config appConfig) { if *secret != "" || *clientID != "" || *clientSecret != "" || *gitLabServer != "" || *redirectURI != "" { @@ -186,6 +211,9 @@ func appMain() { "admin-https-listener": *adminHTTPSListener, "admin-unix-listener": *adminUnixListener, "admin-secret-path": *adminSecretPath, + "api-server": *apiServer, + "api-server-timeout": *apiServerTimeout, + "api-server-key": *apiServerKey, "artifacts-server": *artifactsServer, "artifacts-server-timeout": *artifactsServerTimeout, "daemon-gid": *daemonGID, diff --git a/metrics/metrics.go b/metrics/metrics.go index 44350ae57..0aa65f557 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -5,30 +5,6 @@ import ( ) var ( - // DomainsServed counts the total number of sites served - DomainsServed = prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "gitlab_pages_domains_served_total", - Help: "The total number of sites served by this Pages app", - }) - - // FailedDomainUpdates counts the number of failed site updates - FailedDomainUpdates = prometheus.NewCounter(prometheus.CounterOpts{ - Name: "gitlab_pages_domains_failed_total", - Help: "The total number of site updates that have failed since daemon start", - }) - - // DomainUpdates counts the number of site updates successfully processed - DomainUpdates = prometheus.NewCounter(prometheus.CounterOpts{ - Name: "gitlab_pages_domains_updated_total", - Help: "The total number of site updates successfully processed since daemon start", - }) - - // DomainLastUpdateTime is the UNIX timestamp of the last update - DomainLastUpdateTime = prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "gitlab_pages_last_domain_update_seconds", - Help: "UNIX timestamp of the last update", - }) - // ProcessedRequests is the number of HTTP requests served ProcessedRequests = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "gitlab_pages_http_requests_total", @@ -45,9 +21,6 @@ var ( ) func init() { - prometheus.MustRegister(DomainsServed) - 
prometheus.MustRegister(DomainUpdates) - prometheus.MustRegister(DomainLastUpdateTime) prometheus.MustRegister(ProcessedRequests) prometheus.MustRegister(SessionsActive) } diff --git a/shared/pages/nested/project/public/index.html b/shared/pages/nested/project/public/index.html new file mode 100644 index 000000000..b2d525b29 --- /dev/null +++ b/shared/pages/nested/project/public/index.html @@ -0,0 +1 @@ +index \ No newline at end of file diff --git a/shared/pages/nested/sub1/project/public/index.html b/shared/pages/nested/sub1/project/public/index.html new file mode 100644 index 000000000..b2d525b29 --- /dev/null +++ b/shared/pages/nested/sub1/project/public/index.html @@ -0,0 +1 @@ +index \ No newline at end of file diff --git a/shared/pages/nested/sub1/sub2/project/public/index.html b/shared/pages/nested/sub1/sub2/project/public/index.html new file mode 100644 index 000000000..b2d525b29 --- /dev/null +++ b/shared/pages/nested/sub1/sub2/project/public/index.html @@ -0,0 +1 @@ +index \ No newline at end of file diff --git a/shared/pages/nested/sub1/sub2/sub3/project/public/index.html b/shared/pages/nested/sub1/sub2/sub3/project/public/index.html new file mode 100644 index 000000000..b2d525b29 --- /dev/null +++ b/shared/pages/nested/sub1/sub2/sub3/project/public/index.html @@ -0,0 +1 @@ +index \ No newline at end of file diff --git a/shared/pages/nested/sub1/sub2/sub3/sub4/project/public/index.html b/shared/pages/nested/sub1/sub2/sub3/sub4/project/public/index.html new file mode 100644 index 000000000..b2d525b29 --- /dev/null +++ b/shared/pages/nested/sub1/sub2/sub3/sub4/project/public/index.html @@ -0,0 +1 @@ +index \ No newline at end of file diff --git a/shared/pages/nested/sub1/sub2/sub3/sub4/sub5/project/public/index.html b/shared/pages/nested/sub1/sub2/sub3/sub4/sub5/project/public/index.html new file mode 100644 index 000000000..b2d525b29 --- /dev/null +++ b/shared/pages/nested/sub1/sub2/sub3/sub4/sub5/project/public/index.html @@ -0,0 +1 @@ +index \ No newline at end of file -- GitLab From b6a21d7ab4a9677bcd3cb8d03bb092fd171f26de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Trzci=C5=84ski?= Date: Wed, 24 Apr 2019 13:41:13 +0200 Subject: [PATCH 2/3] Fix `auth_test.go` --- internal/auth/auth_test.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 012089975..f15d10cf2 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -12,13 +12,8 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-pages/internal/auth" - "gitlab.com/gitlab-org/gitlab-pages/internal/domain" ) -func findDomain(host string) *domain.D { - return nil -} - func createAuth(t *testing.T) *auth.Auth { return auth.New("pages.gitlab-example.com", "something-very-secret", @@ -36,7 +31,7 @@ func TestTryAuthenticate(t *testing.T) { require.NoError(t, err) r := &http.Request{URL: reqURL} - assert.Equal(t, false, auth.TryAuthenticate(result, r, findDomain)) + assert.Equal(t, false, auth.TryAuthenticate(result, r, nil)) } func TestTryAuthenticateWithError(t *testing.T) { @@ -47,7 +42,7 @@ func TestTryAuthenticateWithError(t *testing.T) { require.NoError(t, err) r := &http.Request{URL: reqURL} - assert.Equal(t, true, auth.TryAuthenticate(result, r, findDomain)) + assert.Equal(t, true, auth.TryAuthenticate(result, r, nil)) assert.Equal(t, 401, result.Code) } @@ -64,7 +59,7 @@ func TestTryAuthenticateWithCodeButInvalidState(t *testing.T) { session.Values["state"] = "state" session.Save(r, 
result) - assert.Equal(t, true, auth.TryAuthenticate(result, r, findDomain)) + assert.Equal(t, true, auth.TryAuthenticate(result, r, nil)) assert.Equal(t, 401, result.Code) } @@ -106,7 +101,7 @@ func TestTryAuthenticateWithCodeAndState(t *testing.T) { session.Values["state"] = "state" session.Save(r, result) - assert.Equal(t, true, auth.TryAuthenticate(result, r, findDomain)) + assert.Equal(t, true, auth.TryAuthenticate(result, r, nil)) assert.Equal(t, 302, result.Code) assert.Equal(t, "http://pages.gitlab-example.com/project/", result.Header().Get("Location")) } -- GitLab From 14c91b5204ac7d3e1301e8649f48ff6051bb1d9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Trzci=C5=84ski?= Date: Wed, 24 Apr 2019 13:41:20 +0200 Subject: [PATCH 3/3] Vendor `go-cache` --- .../patrickmn/go-cache/CONTRIBUTORS | 9 + vendor/github.com/patrickmn/go-cache/LICENSE | 19 + .../github.com/patrickmn/go-cache/README.md | 83 ++ vendor/github.com/patrickmn/go-cache/cache.go | 1161 +++++++++++++++++ .../github.com/patrickmn/go-cache/sharded.go | 192 +++ vendor/vendor.json | 6 + 6 files changed, 1470 insertions(+) create mode 100644 vendor/github.com/patrickmn/go-cache/CONTRIBUTORS create mode 100644 vendor/github.com/patrickmn/go-cache/LICENSE create mode 100644 vendor/github.com/patrickmn/go-cache/README.md create mode 100644 vendor/github.com/patrickmn/go-cache/cache.go create mode 100644 vendor/github.com/patrickmn/go-cache/sharded.go diff --git a/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS b/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS new file mode 100644 index 000000000..2b16e9974 --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/CONTRIBUTORS @@ -0,0 +1,9 @@ +This is a list of people who have contributed code to go-cache. They, or their +employers, are the copyright holders of the contributed code. Contributed code +is subject to the license restrictions listed in LICENSE (as they were when the +code was contributed.) + +Dustin Sallings +Jason Mooberry +Sergey Shepelev +Alex Edwards diff --git a/vendor/github.com/patrickmn/go-cache/LICENSE b/vendor/github.com/patrickmn/go-cache/LICENSE new file mode 100644 index 000000000..30b9cade0 --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2012-2018 Patrick Mylund Nielsen and the go-cache contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
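As background for this vendoring commit: go-cache is a thread-safe in-memory key/value store with per-item expiration, which makes it a natural fit for memoizing per-host domain lookups so that repeated requests can be served from memory instead of hitting the backend every time. The sketch below is illustrative only and is not code from this patch series; `domainInfo`, `lookupDomain`, and `cachedLookup` are hypothetical stand-ins, and only the `cache.New`, `Get`, and `Set` calls reflect the vendored library's actual API.

```go
package main

import (
	"fmt"
	"time"

	"github.com/patrickmn/go-cache"
)

// domainInfo is a hypothetical stand-in for whatever a domain lookup returns.
type domainInfo struct {
	Host string
}

// lookupDomain is a hypothetical slow resolver (for example an HTTP API call).
func lookupDomain(host string) (*domainInfo, error) {
	return &domainInfo{Host: host}, nil
}

// cachedLookup memoizes lookups for a short TTL so repeated requests for the
// same host are answered from the cache.
func cachedLookup(c *cache.Cache, host string) (*domainInfo, error) {
	if v, found := c.Get(host); found {
		return v.(*domainInfo), nil
	}

	info, err := lookupDomain(host)
	if err != nil {
		return nil, err
	}

	// Store the result with the cache's default expiration time.
	c.Set(host, info, cache.DefaultExpiration)
	return info, nil
}

func main() {
	// Entries expire after 5 minutes; expired entries are purged every 10 minutes.
	c := cache.New(5*time.Minute, 10*time.Minute)

	info, err := cachedLookup(c, "group.gitlab-example.com")
	if err != nil {
		fmt.Println("lookup failed:", err)
		return
	}
	fmt.Println("resolved", info.Host)
}
```

Note that this sketch does not cache failed lookups and does not guard against concurrent misses for the same host; how (and whether) to handle those cases is a separate design decision for the caller.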
diff --git a/vendor/github.com/patrickmn/go-cache/README.md b/vendor/github.com/patrickmn/go-cache/README.md new file mode 100644 index 000000000..c5789cc66 --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/README.md @@ -0,0 +1,83 @@ +# go-cache + +go-cache is an in-memory key:value store/cache similar to memcached that is +suitable for applications running on a single machine. Its major advantage is +that, being essentially a thread-safe `map[string]interface{}` with expiration +times, it doesn't need to serialize or transmit its contents over the network. + +Any object can be stored, for a given duration or forever, and the cache can be +safely used by multiple goroutines. + +Although go-cache isn't meant to be used as a persistent datastore, the entire +cache can be saved to and loaded from a file (using `c.Items()` to retrieve the +items map to serialize, and `NewFrom()` to create a cache from a deserialized +one) to recover from downtime quickly. (See the docs for `NewFrom()` for caveats.) + +### Installation + +`go get github.com/patrickmn/go-cache` + +### Usage + +```go +import ( + "fmt" + "github.com/patrickmn/go-cache" + "time" +) + +func main() { + // Create a cache with a default expiration time of 5 minutes, and which + // purges expired items every 10 minutes + c := cache.New(5*time.Minute, 10*time.Minute) + + // Set the value of the key "foo" to "bar", with the default expiration time + c.Set("foo", "bar", cache.DefaultExpiration) + + // Set the value of the key "baz" to 42, with no expiration time + // (the item won't be removed until it is re-set, or removed using + // c.Delete("baz") + c.Set("baz", 42, cache.NoExpiration) + + // Get the string associated with the key "foo" from the cache + foo, found := c.Get("foo") + if found { + fmt.Println(foo) + } + + // Since Go is statically typed, and cache values can be anything, type + // assertion is needed when values are being passed to functions that don't + // take arbitrary types, (i.e. interface{}). The simplest way to do this for + // values which will only be used once--e.g. for passing to another + // function--is: + foo, found := c.Get("foo") + if found { + MyFunction(foo.(string)) + } + + // This gets tedious if the value is used several times in the same function. + // You might do either of the following instead: + if x, found := c.Get("foo"); found { + foo := x.(string) + // ... + } + // or + var foo string + if x, found := c.Get("foo"); found { + foo = x.(string) + } + // ... + // foo can then be passed around freely as a string + + // Want performance? Store pointers! + c.Set("foo", &MyStruct, cache.DefaultExpiration) + if x, found := c.Get("foo"); found { + foo := x.(*MyStruct) + // ... + } +} +``` + +### Reference + +`godoc` or [http://godoc.org/github.com/patrickmn/go-cache](http://godoc.org/github.com/patrickmn/go-cache) diff --git a/vendor/github.com/patrickmn/go-cache/cache.go b/vendor/github.com/patrickmn/go-cache/cache.go new file mode 100644 index 000000000..db88d2f2c --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/cache.go @@ -0,0 +1,1161 @@ +package cache + +import ( + "encoding/gob" + "fmt" + "io" + "os" + "runtime" + "sync" + "time" +) + +type Item struct { + Object interface{} + Expiration int64 +} + +// Returns true if the item has expired. +func (item Item) Expired() bool { + if item.Expiration == 0 { + return false + } + return time.Now().UnixNano() > item.Expiration +} + +const ( + // For use with functions that take an expiration time. 
+ NoExpiration time.Duration = -1 + // For use with functions that take an expiration time. Equivalent to + // passing in the same expiration duration as was given to New() or + // NewFrom() when the cache was created (e.g. 5 minutes.) + DefaultExpiration time.Duration = 0 +) + +type Cache struct { + *cache + // If this is confusing, see the comment at the bottom of New() +} + +type cache struct { + defaultExpiration time.Duration + items map[string]Item + mu sync.RWMutex + onEvicted func(string, interface{}) + janitor *janitor +} + +// Add an item to the cache, replacing any existing item. If the duration is 0 +// (DefaultExpiration), the cache's default expiration time is used. If it is -1 +// (NoExpiration), the item never expires. +func (c *cache) Set(k string, x interface{}, d time.Duration) { + // "Inlining" of set + var e int64 + if d == DefaultExpiration { + d = c.defaultExpiration + } + if d > 0 { + e = time.Now().Add(d).UnixNano() + } + c.mu.Lock() + c.items[k] = Item{ + Object: x, + Expiration: e, + } + // TODO: Calls to mu.Unlock are currently not deferred because defer + // adds ~200 ns (as of go1.) + c.mu.Unlock() +} + +func (c *cache) set(k string, x interface{}, d time.Duration) { + var e int64 + if d == DefaultExpiration { + d = c.defaultExpiration + } + if d > 0 { + e = time.Now().Add(d).UnixNano() + } + c.items[k] = Item{ + Object: x, + Expiration: e, + } +} + +// Add an item to the cache, replacing any existing item, using the default +// expiration. +func (c *cache) SetDefault(k string, x interface{}) { + c.Set(k, x, DefaultExpiration) +} + +// Add an item to the cache only if an item doesn't already exist for the given +// key, or if the existing item has expired. Returns an error otherwise. +func (c *cache) Add(k string, x interface{}, d time.Duration) error { + c.mu.Lock() + _, found := c.get(k) + if found { + c.mu.Unlock() + return fmt.Errorf("Item %s already exists", k) + } + c.set(k, x, d) + c.mu.Unlock() + return nil +} + +// Set a new value for the cache key only if it already exists, and the existing +// item hasn't expired. Returns an error otherwise. +func (c *cache) Replace(k string, x interface{}, d time.Duration) error { + c.mu.Lock() + _, found := c.get(k) + if !found { + c.mu.Unlock() + return fmt.Errorf("Item %s doesn't exist", k) + } + c.set(k, x, d) + c.mu.Unlock() + return nil +} + +// Get an item from the cache. Returns the item or nil, and a bool indicating +// whether the key was found. +func (c *cache) Get(k string) (interface{}, bool) { + c.mu.RLock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, false + } + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, false + } + } + c.mu.RUnlock() + return item.Object, true +} + +// GetWithExpiration returns an item and its expiration time from the cache. +// It returns the item or nil, the expiration time if one is set (if the item +// never expires a zero value for time.Time is returned), and a bool indicating +// whether the key was found. 
+func (c *cache) GetWithExpiration(k string) (interface{}, time.Time, bool) { + c.mu.RLock() + // "Inlining" of get and Expired + item, found := c.items[k] + if !found { + c.mu.RUnlock() + return nil, time.Time{}, false + } + + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + c.mu.RUnlock() + return nil, time.Time{}, false + } + + // Return the item and the expiration time + c.mu.RUnlock() + return item.Object, time.Unix(0, item.Expiration), true + } + + // If expiration <= 0 (i.e. no expiration time set) then return the item + // and a zeroed time.Time + c.mu.RUnlock() + return item.Object, time.Time{}, true +} + +func (c *cache) get(k string) (interface{}, bool) { + item, found := c.items[k] + if !found { + return nil, false + } + // "Inlining" of Expired + if item.Expiration > 0 { + if time.Now().UnixNano() > item.Expiration { + return nil, false + } + } + return item.Object, true +} + +// Increment an item of type int, int8, int16, int32, int64, uintptr, uint, +// uint8, uint32, or uint64, float32 or float64 by n. Returns an error if the +// item's value is not an integer, if it was not found, or if it is not +// possible to increment it by n. To retrieve the incremented value, use one +// of the specialized methods, e.g. IncrementInt64. +func (c *cache) Increment(k string, n int64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case int: + v.Object = v.Object.(int) + int(n) + case int8: + v.Object = v.Object.(int8) + int8(n) + case int16: + v.Object = v.Object.(int16) + int16(n) + case int32: + v.Object = v.Object.(int32) + int32(n) + case int64: + v.Object = v.Object.(int64) + n + case uint: + v.Object = v.Object.(uint) + uint(n) + case uintptr: + v.Object = v.Object.(uintptr) + uintptr(n) + case uint8: + v.Object = v.Object.(uint8) + uint8(n) + case uint16: + v.Object = v.Object.(uint16) + uint16(n) + case uint32: + v.Object = v.Object.(uint32) + uint32(n) + case uint64: + v.Object = v.Object.(uint64) + uint64(n) + case float32: + v.Object = v.Object.(float32) + float32(n) + case float64: + v.Object = v.Object.(float64) + float64(n) + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s is not an integer", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Increment an item of type float32 or float64 by n. Returns an error if the +// item's value is not floating point, if it was not found, or if it is not +// possible to increment it by n. Pass a negative number to decrement the +// value. To retrieve the incremented value, use one of the specialized methods, +// e.g. IncrementFloat64. +func (c *cache) IncrementFloat(k string, n float64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case float32: + v.Object = v.Object.(float32) + float32(n) + case float64: + v.Object = v.Object.(float64) + n + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s does not have type float32 or float64", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Increment an item of type int by n. Returns an error if the item's value is +// not an int, or if it was not found. If there is no error, the incremented +// value is returned. 
+func (c *cache) IncrementInt(k string, n int) (int, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int8 by n. Returns an error if the item's value is +// not an int8, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt8(k string, n int8) (int8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int8", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int16 by n. Returns an error if the item's value is +// not an int16, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt16(k string, n int16) (int16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int16", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int32 by n. Returns an error if the item's value is +// not an int32, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt32(k string, n int32) (int32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type int64 by n. Returns an error if the item's value is +// not an int64, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementInt64(k string, n int64) (int64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint by n. Returns an error if the item's value is +// not an uint, or if it was not found. If there is no error, the incremented +// value is returned. +func (c *cache) IncrementUint(k string, n uint) (uint, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uintptr by n. Returns an error if the item's value +// is not an uintptr, or if it was not found. 
If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUintptr(k string, n uintptr) (uintptr, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uintptr) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uintptr", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint8 by n. Returns an error if the item's value +// is not an uint8, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint8(k string, n uint8) (uint8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint8", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint16 by n. Returns an error if the item's value +// is not an uint16, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint16(k string, n uint16) (uint16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint16", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint32 by n. Returns an error if the item's value +// is not an uint32, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint32(k string, n uint32) (uint32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type uint64 by n. Returns an error if the item's value +// is not an uint64, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementUint64(k string, n uint64) (uint64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type float32 by n. Returns an error if the item's value +// is not an float32, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementFloat32(k string, n float32) (float32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float32", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Increment an item of type float64 by n. 
Returns an error if the item's value +// is not an float64, or if it was not found. If there is no error, the +// incremented value is returned. +func (c *cache) IncrementFloat64(k string, n float64) (float64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float64", k) + } + nv := rv + n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int, int8, int16, int32, int64, uintptr, uint, +// uint8, uint32, or uint64, float32 or float64 by n. Returns an error if the +// item's value is not an integer, if it was not found, or if it is not +// possible to decrement it by n. To retrieve the decremented value, use one +// of the specialized methods, e.g. DecrementInt64. +func (c *cache) Decrement(k string, n int64) error { + // TODO: Implement Increment and Decrement more cleanly. + // (Cannot do Increment(k, n*-1) for uints.) + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item not found") + } + switch v.Object.(type) { + case int: + v.Object = v.Object.(int) - int(n) + case int8: + v.Object = v.Object.(int8) - int8(n) + case int16: + v.Object = v.Object.(int16) - int16(n) + case int32: + v.Object = v.Object.(int32) - int32(n) + case int64: + v.Object = v.Object.(int64) - n + case uint: + v.Object = v.Object.(uint) - uint(n) + case uintptr: + v.Object = v.Object.(uintptr) - uintptr(n) + case uint8: + v.Object = v.Object.(uint8) - uint8(n) + case uint16: + v.Object = v.Object.(uint16) - uint16(n) + case uint32: + v.Object = v.Object.(uint32) - uint32(n) + case uint64: + v.Object = v.Object.(uint64) - uint64(n) + case float32: + v.Object = v.Object.(float32) - float32(n) + case float64: + v.Object = v.Object.(float64) - float64(n) + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s is not an integer", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Decrement an item of type float32 or float64 by n. Returns an error if the +// item's value is not floating point, if it was not found, or if it is not +// possible to decrement it by n. Pass a negative number to decrement the +// value. To retrieve the decremented value, use one of the specialized methods, +// e.g. DecrementFloat64. +func (c *cache) DecrementFloat(k string, n float64) error { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return fmt.Errorf("Item %s not found", k) + } + switch v.Object.(type) { + case float32: + v.Object = v.Object.(float32) - float32(n) + case float64: + v.Object = v.Object.(float64) - n + default: + c.mu.Unlock() + return fmt.Errorf("The value for %s does not have type float32 or float64", k) + } + c.items[k] = v + c.mu.Unlock() + return nil +} + +// Decrement an item of type int by n. Returns an error if the item's value is +// not an int, or if it was not found. If there is no error, the decremented +// value is returned. 
+func (c *cache) DecrementInt(k string, n int) (int, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int8 by n. Returns an error if the item's value is +// not an int8, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt8(k string, n int8) (int8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int8", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int16 by n. Returns an error if the item's value is +// not an int16, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt16(k string, n int16) (int16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int16", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int32 by n. Returns an error if the item's value is +// not an int32, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt32(k string, n int32) (int32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type int64 by n. Returns an error if the item's value is +// not an int64, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementInt64(k string, n int64) (int64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(int64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an int64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint by n. Returns an error if the item's value is +// not an uint, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementUint(k string, n uint) (uint, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uintptr by n. Returns an error if the item's value +// is not an uintptr, or if it was not found. 
If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUintptr(k string, n uintptr) (uintptr, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uintptr) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uintptr", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint8 by n. Returns an error if the item's value is +// not an uint8, or if it was not found. If there is no error, the decremented +// value is returned. +func (c *cache) DecrementUint8(k string, n uint8) (uint8, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint8) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint8", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint16 by n. Returns an error if the item's value +// is not an uint16, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint16(k string, n uint16) (uint16, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint16) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint16", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint32 by n. Returns an error if the item's value +// is not an uint32, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint32(k string, n uint32) (uint32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type uint64 by n. Returns an error if the item's value +// is not an uint64, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementUint64(k string, n uint64) (uint64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(uint64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an uint64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type float32 by n. Returns an error if the item's value +// is not an float32, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementFloat32(k string, n float32) (float32, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float32) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float32", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Decrement an item of type float64 by n. 
Returns an error if the item's value +// is not an float64, or if it was not found. If there is no error, the +// decremented value is returned. +func (c *cache) DecrementFloat64(k string, n float64) (float64, error) { + c.mu.Lock() + v, found := c.items[k] + if !found || v.Expired() { + c.mu.Unlock() + return 0, fmt.Errorf("Item %s not found", k) + } + rv, ok := v.Object.(float64) + if !ok { + c.mu.Unlock() + return 0, fmt.Errorf("The value for %s is not an float64", k) + } + nv := rv - n + v.Object = nv + c.items[k] = v + c.mu.Unlock() + return nv, nil +} + +// Delete an item from the cache. Does nothing if the key is not in the cache. +func (c *cache) Delete(k string) { + c.mu.Lock() + v, evicted := c.delete(k) + c.mu.Unlock() + if evicted { + c.onEvicted(k, v) + } +} + +func (c *cache) delete(k string) (interface{}, bool) { + if c.onEvicted != nil { + if v, found := c.items[k]; found { + delete(c.items, k) + return v.Object, true + } + } + delete(c.items, k) + return nil, false +} + +type keyAndValue struct { + key string + value interface{} +} + +// Delete all expired items from the cache. +func (c *cache) DeleteExpired() { + var evictedItems []keyAndValue + now := time.Now().UnixNano() + c.mu.Lock() + for k, v := range c.items { + // "Inlining" of expired + if v.Expiration > 0 && now > v.Expiration { + ov, evicted := c.delete(k) + if evicted { + evictedItems = append(evictedItems, keyAndValue{k, ov}) + } + } + } + c.mu.Unlock() + for _, v := range evictedItems { + c.onEvicted(v.key, v.value) + } +} + +// Sets an (optional) function that is called with the key and value when an +// item is evicted from the cache. (Including when it is deleted manually, but +// not when it is overwritten.) Set to nil to disable. +func (c *cache) OnEvicted(f func(string, interface{})) { + c.mu.Lock() + c.onEvicted = f + c.mu.Unlock() +} + +// Write the cache's items (using Gob) to an io.Writer. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) Save(w io.Writer) (err error) { + enc := gob.NewEncoder(w) + defer func() { + if x := recover(); x != nil { + err = fmt.Errorf("Error registering item types with Gob library") + } + }() + c.mu.RLock() + defer c.mu.RUnlock() + for _, v := range c.items { + gob.Register(v.Object) + } + err = enc.Encode(&c.items) + return +} + +// Save the cache's items to the given filename, creating the file if it +// doesn't exist, and overwriting it if it does. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) SaveFile(fname string) error { + fp, err := os.Create(fname) + if err != nil { + return err + } + err = c.Save(fp) + if err != nil { + fp.Close() + return err + } + return fp.Close() +} + +// Add (Gob-serialized) cache items from an io.Reader, excluding any items with +// keys that already exist (and haven't expired) in the current cache. +// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) Load(r io.Reader) error { + dec := gob.NewDecoder(r) + items := map[string]Item{} + err := dec.Decode(&items) + if err == nil { + c.mu.Lock() + defer c.mu.Unlock() + for k, v := range items { + ov, found := c.items[k] + if !found || ov.Expired() { + c.items[k] = v + } + } + } + return err +} + +// Load and add cache items from the given filename, excluding any items with +// keys that already exist in the current cache. 
+// +// NOTE: This method is deprecated in favor of c.Items() and NewFrom() (see the +// documentation for NewFrom().) +func (c *cache) LoadFile(fname string) error { + fp, err := os.Open(fname) + if err != nil { + return err + } + err = c.Load(fp) + if err != nil { + fp.Close() + return err + } + return fp.Close() +} + +// Copies all unexpired items in the cache into a new map and returns it. +func (c *cache) Items() map[string]Item { + c.mu.RLock() + defer c.mu.RUnlock() + m := make(map[string]Item, len(c.items)) + now := time.Now().UnixNano() + for k, v := range c.items { + // "Inlining" of Expired + if v.Expiration > 0 { + if now > v.Expiration { + continue + } + } + m[k] = v + } + return m +} + +// Returns the number of items in the cache. This may include items that have +// expired, but have not yet been cleaned up. +func (c *cache) ItemCount() int { + c.mu.RLock() + n := len(c.items) + c.mu.RUnlock() + return n +} + +// Delete all items from the cache. +func (c *cache) Flush() { + c.mu.Lock() + c.items = map[string]Item{} + c.mu.Unlock() +} + +type janitor struct { + Interval time.Duration + stop chan bool +} + +func (j *janitor) Run(c *cache) { + ticker := time.NewTicker(j.Interval) + for { + select { + case <-ticker.C: + c.DeleteExpired() + case <-j.stop: + ticker.Stop() + return + } + } +} + +func stopJanitor(c *Cache) { + c.janitor.stop <- true +} + +func runJanitor(c *cache, ci time.Duration) { + j := &janitor{ + Interval: ci, + stop: make(chan bool), + } + c.janitor = j + go j.Run(c) +} + +func newCache(de time.Duration, m map[string]Item) *cache { + if de == 0 { + de = -1 + } + c := &cache{ + defaultExpiration: de, + items: m, + } + return c +} + +func newCacheWithJanitor(de time.Duration, ci time.Duration, m map[string]Item) *Cache { + c := newCache(de, m) + // This trick ensures that the janitor goroutine (which--granted it + // was enabled--is running DeleteExpired on c forever) does not keep + // the returned C object from being garbage collected. When it is + // garbage collected, the finalizer stops the janitor goroutine, after + // which c can be collected. + C := &Cache{c} + if ci > 0 { + runJanitor(c, ci) + runtime.SetFinalizer(C, stopJanitor) + } + return C +} + +// Return a new cache with a given default expiration duration and cleanup +// interval. If the expiration duration is less than one (or NoExpiration), +// the items in the cache never expire (by default), and must be deleted +// manually. If the cleanup interval is less than one, expired items are not +// deleted from the cache before calling c.DeleteExpired(). +func New(defaultExpiration, cleanupInterval time.Duration) *Cache { + items := make(map[string]Item) + return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) +} + +// Return a new cache with a given default expiration duration and cleanup +// interval. If the expiration duration is less than one (or NoExpiration), +// the items in the cache never expire (by default), and must be deleted +// manually. If the cleanup interval is less than one, expired items are not +// deleted from the cache before calling c.DeleteExpired(). +// +// NewFrom() also accepts an items map which will serve as the underlying map +// for the cache. This is useful for starting from a deserialized cache +// (serialized using e.g. gob.Encode() on c.Items()), or passing in e.g. +// make(map[string]Item, 500) to improve startup performance when the cache +// is expected to reach a certain minimum size. 
+// +// Only the cache's methods synchronize access to this map, so it is not +// recommended to keep any references to the map around after creating a cache. +// If need be, the map can be accessed at a later point using c.Items() (subject +// to the same caveat.) +// +// Note regarding serialization: When using e.g. gob, make sure to +// gob.Register() the individual types stored in the cache before encoding a +// map retrieved with c.Items(), and to register those same types before +// decoding a blob containing an items map. +func NewFrom(defaultExpiration, cleanupInterval time.Duration, items map[string]Item) *Cache { + return newCacheWithJanitor(defaultExpiration, cleanupInterval, items) +} diff --git a/vendor/github.com/patrickmn/go-cache/sharded.go b/vendor/github.com/patrickmn/go-cache/sharded.go new file mode 100644 index 000000000..bcc0538bc --- /dev/null +++ b/vendor/github.com/patrickmn/go-cache/sharded.go @@ -0,0 +1,192 @@ +package cache + +import ( + "crypto/rand" + "math" + "math/big" + insecurerand "math/rand" + "os" + "runtime" + "time" +) + +// This is an experimental and unexported (for now) attempt at making a cache +// with better algorithmic complexity than the standard one, namely by +// preventing write locks of the entire cache when an item is added. As of the +// time of writing, the overhead of selecting buckets results in cache +// operations being about twice as slow as for the standard cache with small +// total cache sizes, and faster for larger ones. +// +// See cache_test.go for a few benchmarks. + +type unexportedShardedCache struct { + *shardedCache +} + +type shardedCache struct { + seed uint32 + m uint32 + cs []*cache + janitor *shardedJanitor +} + +// djb2 with better shuffling. 5x faster than FNV with the hash.Hash overhead. +func djb33(seed uint32, k string) uint32 { + var ( + l = uint32(len(k)) + d = 5381 + seed + l + i = uint32(0) + ) + // Why is all this 5x faster than a for loop? + if l >= 4 { + for i < l-4 { + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + d = (d * 33) ^ uint32(k[i+2]) + d = (d * 33) ^ uint32(k[i+3]) + i += 4 + } + } + switch l - i { + case 1: + case 2: + d = (d * 33) ^ uint32(k[i]) + case 3: + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + case 4: + d = (d * 33) ^ uint32(k[i]) + d = (d * 33) ^ uint32(k[i+1]) + d = (d * 33) ^ uint32(k[i+2]) + } + return d ^ (d >> 16) +} + +func (sc *shardedCache) bucket(k string) *cache { + return sc.cs[djb33(sc.seed, k)%sc.m] +} + +func (sc *shardedCache) Set(k string, x interface{}, d time.Duration) { + sc.bucket(k).Set(k, x, d) +} + +func (sc *shardedCache) Add(k string, x interface{}, d time.Duration) error { + return sc.bucket(k).Add(k, x, d) +} + +func (sc *shardedCache) Replace(k string, x interface{}, d time.Duration) error { + return sc.bucket(k).Replace(k, x, d) +} + +func (sc *shardedCache) Get(k string) (interface{}, bool) { + return sc.bucket(k).Get(k) +} + +func (sc *shardedCache) Increment(k string, n int64) error { + return sc.bucket(k).Increment(k, n) +} + +func (sc *shardedCache) IncrementFloat(k string, n float64) error { + return sc.bucket(k).IncrementFloat(k, n) +} + +func (sc *shardedCache) Decrement(k string, n int64) error { + return sc.bucket(k).Decrement(k, n) +} + +func (sc *shardedCache) Delete(k string) { + sc.bucket(k).Delete(k) +} + +func (sc *shardedCache) DeleteExpired() { + for _, v := range sc.cs { + v.DeleteExpired() + } +} + +// Returns the items in the cache. 
This may include items that have expired, +// but have not yet been cleaned up. If this is significant, the Expiration +// fields of the items should be checked. Note that explicit synchronization +// is needed to use a cache and its corresponding Items() return values at +// the same time, as the maps are shared. +func (sc *shardedCache) Items() []map[string]Item { + res := make([]map[string]Item, len(sc.cs)) + for i, v := range sc.cs { + res[i] = v.Items() + } + return res +} + +func (sc *shardedCache) Flush() { + for _, v := range sc.cs { + v.Flush() + } +} + +type shardedJanitor struct { + Interval time.Duration + stop chan bool +} + +func (j *shardedJanitor) Run(sc *shardedCache) { + j.stop = make(chan bool) + tick := time.Tick(j.Interval) + for { + select { + case <-tick: + sc.DeleteExpired() + case <-j.stop: + return + } + } +} + +func stopShardedJanitor(sc *unexportedShardedCache) { + sc.janitor.stop <- true +} + +func runShardedJanitor(sc *shardedCache, ci time.Duration) { + j := &shardedJanitor{ + Interval: ci, + } + sc.janitor = j + go j.Run(sc) +} + +func newShardedCache(n int, de time.Duration) *shardedCache { + max := big.NewInt(0).SetUint64(uint64(math.MaxUint32)) + rnd, err := rand.Int(rand.Reader, max) + var seed uint32 + if err != nil { + os.Stderr.Write([]byte("WARNING: go-cache's newShardedCache failed to read from the system CSPRNG (/dev/urandom or equivalent.) Your system's security may be compromised. Continuing with an insecure seed.\n")) + seed = insecurerand.Uint32() + } else { + seed = uint32(rnd.Uint64()) + } + sc := &shardedCache{ + seed: seed, + m: uint32(n), + cs: make([]*cache, n), + } + for i := 0; i < n; i++ { + c := &cache{ + defaultExpiration: de, + items: map[string]Item{}, + } + sc.cs[i] = c + } + return sc +} + +func unexportedNewSharded(defaultExpiration, cleanupInterval time.Duration, shards int) *unexportedShardedCache { + if defaultExpiration == 0 { + defaultExpiration = -1 + } + sc := newShardedCache(shards, defaultExpiration) + SC := &unexportedShardedCache{sc} + if cleanupInterval > 0 { + runShardedJanitor(sc, cleanupInterval) + runtime.SetFinalizer(SC, stopShardedJanitor) + } + return SC +} diff --git a/vendor/vendor.json b/vendor/vendor.json index 1df7f7efb..f0c2cb3f8 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -196,6 +196,12 @@ "path": "github.com/namsral/flag", "revision": "67f268f20922975c067ed799e4be6bacf152208c" }, + { + "checksumSHA1": "W8mzTLRjnooGtHwWaxSX8eq8hlY=", + "path": "github.com/patrickmn/go-cache", + "revision": "5633e0862627c011927fa39556acae8b1f1df58a", + "revisionTime": "2018-08-15T05:31:27Z" + }, { "checksumSHA1": "ljd3FhYRJ91cLZz3wsH9BQQ2JbA=", "path": "github.com/pkg/errors", -- GitLab
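
For orientation, a minimal sketch of how the go-cache API vendored above (New, Set, Get, Delete, OnEvicted, and the janitor-driven expiry) is typically consumed. This sketch is not part of the patch; the key, value, and durations are illustrative assumptions.

package main

import (
	"fmt"
	"time"

	cache "github.com/patrickmn/go-cache"
)

func main() {
	// New(defaultExpiration, cleanupInterval): items expire after 30s by
	// default; the janitor goroutine purges expired items every minute.
	c := cache.New(30*time.Second, time.Minute)

	// Optional eviction hook, called when an item expires and is cleaned
	// up or is deleted manually (but not when it is overwritten).
	c.OnEvicted(func(k string, v interface{}) {
		fmt.Printf("evicted %s\n", k)
	})

	// Store a value using the cache's default expiration, then read it back.
	c.Set("example.gitlab.io", "lookup-result", cache.DefaultExpiration)
	if v, found := c.Get("example.gitlab.io"); found {
		fmt.Println(v.(string))
	}

	// Entries can also be removed explicitly.
	c.Delete("example.gitlab.io")
}

As the constructor comment in the vendored cache.go explains, New returns a wrapper Cache whose runtime finalizer stops the janitor goroutine, so an abandoned cache can still be garbage collected even while periodic cleanup is enabled.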