diff --git a/app.go b/app.go
index ddafb0bf254e1fb7cd458e56e7603e0b514e0754..d8a48b1cff1c1a809e38f26c29a3710ec4e9e473 100644
--- a/app.go
+++ b/app.go
@@ -13,6 +13,9 @@ import (
 	"syscall"
 	"time"
 
+	"gitlab.com/gitlab-org/gitlab-pages/internal/feature"
+	"gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter"
+
 	ghandlers "github.com/gorilla/handlers"
 	"github.com/hashicorp/go-multierror"
 	"github.com/rs/cors"
@@ -51,6 +54,7 @@ var (
 type theApp struct {
 	config         *cfg.Config
 	source         source.Source
+	tlsConfig      *cryptotls.Config
 	Artifact       *artifact.Artifact
 	Auth           *auth.Auth
 	Handlers       *handlers.Handlers
@@ -62,7 +66,9 @@ func (a *theApp) isReady() bool {
 	return true
 }
 
-func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) {
+// GetCertificate is the certificate callback of the TLS config built by
+// getTLSConfig. It returns nil, nil when the client sent no server name.
+func (a *theApp) GetCertificate(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate, error) {
 	if ch.ServerName == "" {
 		return nil, nil
 	}
@@ -75,6 +81,39 @@ func (a *theApp) ServeTLS(ch *cryptotls.ClientHelloInfo) (*cryptotls.Certificate
 	return nil, nil
 }
 
+// getTLSConfig returns the TLS config shared by all HTTPS listeners. It is
+// built once and cached so that every listener uses the same domain TLS rate
+// limiter around GetCertificate.
+func (a *theApp) getTLSConfig() (*cryptotls.Config, error) {
+	if a.tlsConfig != nil {
+		return a.tlsConfig, nil
+	}
+
+	tlsRateLimiter := ratelimiter.New(
+		"tls",
+		ratelimiter.WithCacheMaxSize(ratelimiter.DefaultSourceIPCacheSize),
+		ratelimiter.WithCachedEntriesMetric(metrics.RateLimitDomainTLSCachedEntries),
+		ratelimiter.WithCachedRequestsMetric(metrics.RateLimitDomainTLSCacheRequests),
+		ratelimiter.WithBlockedCountMetric(metrics.RateLimitDomainTLSBlockedCount),
+		ratelimiter.WithLimitPerSecond(a.config.RateLimit.DomainTLSLimitPerSecond),
+		ratelimiter.WithBurstSize(a.config.RateLimit.DomainTLSBurst),
+		ratelimiter.WithEnforce(feature.EnforceDomainTLSRateLimits.Enabled()),
+	)
+
+	getCertificate := tlsRateLimiter.GetCertificateMiddleware(a.GetCertificate)
+
+	tlsConfig, err := tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, getCertificate,
+		a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion)
+	if err != nil {
+		// Do not cache a failed build.
+		return nil, err
+	}
+
+	a.tlsConfig = tlsConfig
+
+	return a.tlsConfig, nil
+}
+
 func (a *theApp) redirectToHTTPS(w http.ResponseWriter, r *http.Request, statusCode int) {
 	u := *r.URL
 	u.Scheme = request.SchemeHTTPS
@@ -306,7 +345,7 @@ func (a *theApp) Run() {
 
 	// Listen for HTTPS
 	for _, addr := range a.config.ListenHTTPSStrings.Split() {
-		tlsConfig, err := a.TLSConfig()
+		tlsConfig, err := a.getTLSConfig()
 		if err != nil {
 			log.WithError(err).Fatal("Unable to retrieve tls config")
 		}
@@ -334,7 +373,7 @@ func (a *theApp) Run() {
 
 	// Listen for HTTPS PROXYv2 requests
 	for _, addr := range a.config.ListenHTTPSProxyv2Strings.Split() {
-		tlsConfig, err := a.TLSConfig()
+		tlsConfig, err := a.getTLSConfig()
 		if err != nil {
 			log.WithError(err).Fatal("Unable to retrieve tls config")
 		}
@@ -478,11 +517,6 @@ func fatal(err error, message string) {
 	log.WithError(err).Fatal(message)
 }
 
-func (a *theApp) TLSConfig() (*cryptotls.Config, error) {
-	return tls.Create(a.config.General.RootCertificate, a.config.General.RootKey, a.ServeTLS,
-		a.config.General.InsecureCiphers, a.config.TLS.MinVersion, a.config.TLS.MaxVersion)
-}
-
 // handlePanicMiddleware logs and captures the recover() information from any panic
 func handlePanicMiddleware(handler http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
diff --git a/internal/config/config.go b/internal/config/config.go
index 3bb7b1262509cd689524223a2145571477009e10..bcd0cfc7b2804934cb4c422ebdc889a1c832f6a6 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -58,10 +58,12 @@ type General struct {
 
 // RateLimit config struct
 type RateLimit struct {
-	SourceIPLimitPerSecond float64
-	SourceIPBurst          int
-	DomainLimitPerSecond   float64
-	DomainBurst            int
+	SourceIPLimitPerSecond  float64
+	SourceIPBurst           int
+	DomainLimitPerSecond    float64
+	DomainBurst             int
+	DomainTLSLimitPerSecond float64
+	DomainTLSBurst          int
 }
 
 // ArtifactsServer groups settings related to configuring Artifacts
@@ -179,10 +181,12 @@ func loadConfig() (*Config, error) {
 			ShowVersion: *showVersion,
 		},
 		RateLimit: RateLimit{
-			SourceIPLimitPerSecond: *rateLimitSourceIP,
-			SourceIPBurst:          *rateLimitSourceIPBurst,
-			DomainLimitPerSecond:   *rateLimitDomain,
-			DomainBurst:            *rateLimitDomainBurst,
+			SourceIPLimitPerSecond:  *rateLimitSourceIP,
+			SourceIPBurst:           *rateLimitSourceIPBurst,
+			DomainLimitPerSecond:    *rateLimitDomain,
+			DomainBurst:             *rateLimitDomainBurst,
+			DomainTLSLimitPerSecond: *rateLimitDomainTLS,
+			DomainTLSBurst:          *rateLimitDomainTLSBurst,
 		},
 		GitLab: GitLab{
 			ClientHTTPTimeout: *gitlabClientHTTPTimeout,
diff --git a/internal/config/flags.go b/internal/config/flags.go
index 93228827ce614dde91a752d74cfcd52835f5c2a6..bd40f3626a967e4e9dd090cb933fdd6067c5819d 100644
--- a/internal/config/flags.go
+++ b/internal/config/flags.go
@@ -19,6 +19,8 @@ var (
 	rateLimitSourceIPBurst  = flag.Int("rate-limit-source-ip-burst", 100, "Rate limit per source IP maximum burst allowed per second")
 	rateLimitDomain         = flag.Float64("rate-limit-domain", 0.0, "Rate limit per domain in number of requests per second, 0 means is disabled")
 	rateLimitDomainBurst    = flag.Int("rate-limit-domain-burst", 100, "Rate limit per domain maximum burst allowed per second")
+	rateLimitDomainTLS      = flag.Float64("rate-limit-domain-tls", 0.0, "Rate limit per domain in number of new TLS connections per second, 0 means is disabled")
+	rateLimitDomainTLSBurst = flag.Int("rate-limit-domain-tls-burst", 100, "Rate limit per domain maximum burst of TLS connections allowed per second")
 	artifactsServer         = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'")
 	artifactsServerTimeout  = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server")
 	pagesStatus             = flag.String("pages-status", "", "The url path for a status page, e.g., /@status")
diff --git a/internal/feature/feature.go b/internal/feature/feature.go
index 81eef9a0965f1897fe83c172204e4302829b28a9..3c3f1ee36ebc32b46fe1e462a3b4950b9a3527d4 100644
--- a/internal/feature/feature.go
+++ b/internal/feature/feature.go
@@ -19,6 +19,12 @@ var EnforceDomainRateLimits = Feature{
 	EnvVariable: "FF_ENFORCE_DOMAIN_RATE_LIMITS",
 }
 
+// EnforceDomainTLSRateLimits enforces domain rate limits on establishing new TLS connections
+// TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/655
+var EnforceDomainTLSRateLimits = Feature{
+	EnvVariable: "FF_ENFORCE_DOMAIN_TLS_RATE_LIMITS",
+}
+
 // RedirectsPlaceholders enables support for placeholders in redirects file
 // TODO: remove https://gitlab.com/gitlab-org/gitlab-pages/-/issues/620
 var RedirectsPlaceholders = Feature{
diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go
index af7b0881dc18f9e594ac19e304a3db9ac0df1ab8..2faaac08775f6b4a57c1106837ccd0786f04db72 100644
--- a/internal/ratelimiter/middleware.go
+++ b/internal/ratelimiter/middleware.go
@@ -8,7 +8,6 @@ import (
 	"gitlab.com/gitlab-org/labkit/correlation"
 	"gitlab.com/gitlab-org/labkit/log"
 
-	"gitlab.com/gitlab-org/gitlab-pages/internal/feature"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/httperrors"
 	"gitlab.com/gitlab-org/gitlab-pages/internal/request"
 )
@@ -34,7 +33,7 @@ func (rl *RateLimiter) Middleware(handler http.Handler) http.Handler {
 			rl.logRateLimitedRequest(r)
 
 			if rl.blockedCount != nil {
-				rl.blockedCount.WithLabelValues(strconv.FormatBool(feature.EnforceIPRateLimits.Enabled())).Inc()
+				rl.blockedCount.WithLabelValues(strconv.FormatBool(rl.enforce)).Inc()
 			}
 
 			if rl.enforce {
@@ -59,7 +58,7 @@ func (rl *RateLimiter) logRateLimitedRequest(r *http.Request) {
 		"x_forwarded_proto":             r.Header.Get(headerXForwardedProto),
 		"x_forwarded_for":               r.Header.Get(headerXForwardedFor),
 		"gitlab_real_ip":                r.Header.Get(headerGitLabRealIP),
-		"rate_limiter_enabled":          feature.EnforceIPRateLimits.Enabled(),
+		"rate_limiter_enabled":          rl.enforce,
 		"rate_limiter_limit_per_second": rl.limitPerSecond,
 		"rate_limiter_burst_size":       rl.burstSize,
 	}). // TODO: change to Debug with https://gitlab.com/gitlab-org/gitlab-pages/-/issues/629
diff --git a/internal/ratelimiter/ratelimiter.go b/internal/ratelimiter/ratelimiter.go
index feeb8cb4c217e0125522eba1efcb9ece87f2c509..24fc05fe4e107aa7a3244abf2492618776f83b1a 100644
--- a/internal/ratelimiter/ratelimiter.go
+++ b/internal/ratelimiter/ratelimiter.go
@@ -138,6 +138,12 @@ func (rl *RateLimiter) limiter(key string) *rate.Limiter {
 // requestAllowed checks if request is within the rate-limit
 func (rl *RateLimiter) requestAllowed(r *http.Request) bool {
 	rateLimitedKey := rl.keyFunc(r)
+
+	return rl.allowed(rateLimitedKey)
+}
+
+// allowed checks if an event for rateLimitedKey is within the rate-limit
+func (rl *RateLimiter) allowed(rateLimitedKey string) bool {
 	limiter := rl.limiter(rateLimitedKey)
 
 	// AllowN allows us to use the rl.now function, so we can test this more easily.
diff --git a/internal/ratelimiter/tls.go b/internal/ratelimiter/tls.go
new file mode 100644
index 0000000000000000000000000000000000000000..e8d6dc72663a1ca265abb938abe63e079a513071
--- /dev/null
+++ b/internal/ratelimiter/tls.go
@@ -0,0 +1,76 @@
+package ratelimiter
+
+import (
+	"crypto/tls"
+	"errors"
+	"net"
+	"strconv"
+
+	"github.com/sirupsen/logrus"
+	"gitlab.com/gitlab-org/labkit/log"
+
+	tlsconfig "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls"
+)
+
+// ErrTLSRateLimited is returned by the GetCertificate callback when a new TLS
+// connection exceeds the rate limit and the limit is enforced.
+var ErrTLSRateLimited = errors.New("TLS connection is being rate-limited")
+
+// GetCertificateMiddleware wraps getCertificate so that new TLS connections are
+// rate-limited per requested server name (SNI).
+//
+// A connection over the limit is logged and counted in the blocked metric, and
+// is rejected with ErrTLSRateLimited only when enforcement is enabled. A limit
+// of 0 or less disables rate limiting and returns getCertificate unchanged.
+func (rl *RateLimiter) GetCertificateMiddleware(getCertificate tlsconfig.GetCertificateFunc) tlsconfig.GetCertificateFunc {
+	if rl.limitPerSecond <= 0 {
+		return getCertificate
+	}
+
+	return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
+		if rl.allowed(info.ServerName) {
+			return getCertificate(info)
+		}
+
+		rl.logRateLimitedTLS(info)
+
+		if rl.blockedCount != nil {
+			rl.blockedCount.WithLabelValues(strconv.FormatBool(rl.enforce)).Inc()
+		}
+
+		// When not enforcing, only log/count the event; do not reject the connection.
+		if !rl.enforce {
+			return getCertificate(info)
+		}
+
+		return nil, ErrTLSRateLimited
+	}
+}
+
+// logRateLimitedTLS logs a TLS handshake that exceeded the rate limit.
+func (rl *RateLimiter) logRateLimitedTLS(info *tls.ClientHelloInfo) {
+	log.WithFields(logrus.Fields{
+		"rate_limiter_name":             rl.name,
+		"source_ip":                     getRemoteAddrFromHelloInfo(info),
+		"req_host":                      info.ServerName,
+		"rate_limiter_limit_per_second": rl.limitPerSecond,
+		"rate_limiter_burst_size":       rl.burstSize,
+	}).Info("TLS connection rate-limited")
+}
+
+// getRemoteAddrFromHelloInfo returns the host part of the client address, or
+// the full address when it cannot be split into host and port.
+func getRemoteAddrFromHelloInfo(info *tls.ClientHelloInfo) string {
+	if info.Conn == nil {
+		return ""
+	}
+
+	addr := info.Conn.RemoteAddr().String()
+
+	host, _, err := net.SplitHostPort(addr)
+	if err != nil {
+		return addr
+	}
+
+	return host
+}
diff --git a/metrics/metrics.go b/metrics/metrics.go
index e0e4ab2925473df89cba4d153d6b94b69f9c670e..4954e6d10fcd7a3badd9266f919e7de426b47aff 100644
--- a/metrics/metrics.go
+++ b/metrics/metrics.go
@@ -248,6 +248,37 @@ var (
 		},
 		[]string{"enforced"},
 	)
+
+	// TODO(review): make sure the domain TLS collectors below are also listed in
+	// MustRegister; collectors that are never registered are not exported.
+
+	// RateLimitDomainTLSCacheRequests is the number of domain TLS rate limiter cache hits/misses
+	RateLimitDomainTLSCacheRequests = prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: "gitlab_pages_rate_limit_domain_tls_cache_requests",
+			Help: "The number of domain TLS rate limiter cache hits/misses",
+		},
+		[]string{"op", "cache"},
+	)
+
+	// RateLimitDomainTLSCachedEntries is the number of entries in the domain TLS rate limiter cache
+	RateLimitDomainTLSCachedEntries = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "gitlab_pages_rate_limit_domain_tls_cached_entries",
+			Help: "The number of entries in the domain TLS rate limiter cache",
+		},
+		[]string{"op"},
+	)
+
+	// RateLimitDomainTLSBlockedCount is the number of TLS connections that have been blocked by the
+	// domain TLS rate limiter
+	RateLimitDomainTLSBlockedCount = prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "gitlab_pages_rate_limit_domain_tls_blocked_count",
+			Help: "The number of TLS connections that have been blocked by the domain TLS rate limiter",
+		},
+		[]string{"enforced"},
+	)
 )
 
 // MustRegister collectors with the
Prometheus client diff --git a/test/acceptance/artifacts_test.go b/test/acceptance/artifacts_test.go index f087581c8e96be61e47f82aae22d88858ffc5fb9..653e31cfea864b8e7bfbc9f2aaf833041d66aaa8 100644 --- a/test/acceptance/artifacts_test.go +++ b/test/acceptance/artifacts_test.go @@ -14,9 +14,6 @@ import ( ) func TestArtifactProxyRequest(t *testing.T) { - transport := (TestHTTPSClient.Transport).(*http.Transport).Clone() - transport.ResponseHeaderTimeout = 5 * time.Second - content := "