diff --git a/app.go b/app.go index 389ce0b47960b87f6b4c6076df59e8644a4434ea..a6e8e6ffbe46645d894a8afc548de74e129e9fb9 100644 --- a/app.go +++ b/app.go @@ -15,6 +15,8 @@ import ( "github.com/rs/cors" "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + "gitlab.com/gitlab-org/go-mimedb" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/errortracking" @@ -337,6 +339,15 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { metricsMiddleware := labmetrics.NewHandlerFactory(labmetrics.WithNamespace("gitlab_pages")) handler = metricsMiddleware(handler) + if a.config.General.EnableRateLimiter { + handler = ratelimiter.DomainRateLimiter( + ratelimiter.New( + ratelimiter.WithPerDomainFrequency(a.config.General.RateLimitPerDomainFrequency), + ratelimiter.WithPerDomainBurstSize(a.config.General.RateLimitPerDomainBurstSize), + ), + )(handler) + } + handler = a.routingMiddleware(handler) // Health Check diff --git a/go.mod b/go.mod index 5f17d2536b040d9bfffa49fd9dcb33cd3cdd5bc9..0ff516edd6ec2c8059b4f93cc86704a9f4165215 100644 --- a/go.mod +++ b/go.mod @@ -25,4 +25,5 @@ require ( golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 + golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 ) diff --git a/go.sum b/go.sum index f004e86f126eafb5acfcbd147379a8866c178c11..c57b16add36a4855972bc03b3375be203edefaa2 100644 --- a/go.sum +++ b/go.sum @@ -398,6 +398,7 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= 
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/config/config.go b/internal/config/config.go index 61767a57e9cbc543fb107beffa2d1d4d3837dccd..d2ed2363c8d517e587680997c7e2ea9f6b8b1147 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,14 +45,17 @@ type Config struct { // General groups settings that are general to GitLab Pages and can not // be categorized under other head. type General struct { - Domain string - MaxConns int - MetricsAddress string - RedirectHTTP bool - RootCertificate []byte - RootDir string - RootKey []byte - StatusPath string + Domain string + EnableRateLimiter bool + RateLimitPerDomainFrequency time.Duration + RateLimitPerDomainBurstSize int + MaxConns int + MetricsAddress string + RedirectHTTP bool + RootCertificate []byte + RootDir string + RootKey []byte + StatusPath string DisableCrossOriginRequests bool InsecureCiphers bool @@ -179,17 +182,20 @@ func setGitLabAPISecretKey(secretFile string, config *Config) error { func loadConfig() (*Config, error) { config := &Config{ General: General{ - Domain: strings.ToLower(*pagesDomain), - MaxConns: *maxConns, - MetricsAddress: *metricsAddress, - RedirectHTTP: *redirectHTTP, - RootDir: *pagesRoot, - StatusPath: *pagesStatus, - DisableCrossOriginRequests: *disableCrossOriginRequests, - InsecureCiphers: *insecureCiphers, - PropagateCorrelationID: *propagateCorrelationID, - CustomHeaders: header.Split(), - ShowVersion: *showVersion, + Domain: strings.ToLower(*pagesDomain), + EnableRateLimiter: *enableRateLimiter, + RateLimitPerDomainFrequency: *rateLimitPerDomain, + RateLimitPerDomainBurstSize: *rateLimitPerDomainBurstSize, + MaxConns: *maxConns, + MetricsAddress: *metricsAddress, + 
RedirectHTTP: *redirectHTTP, + RootDir: *pagesRoot, + StatusPath: *pagesStatus, + DisableCrossOriginRequests: *disableCrossOriginRequests, + InsecureCiphers: *insecureCiphers, + PropagateCorrelationID: *propagateCorrelationID, + CustomHeaders: header.Split(), + ShowVersion: *showVersion, }, GitLab: GitLab{ ClientHTTPTimeout: *gitlabClientHTTPTimeout, diff --git a/internal/config/flags.go b/internal/config/flags.go index aa5bf1c5096aa235f46864c439cde70ba192dd35..7e2781be955676f1537517d0ac8b89e5f98f119a 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -5,41 +5,46 @@ import ( "github.com/namsral/flag" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + "gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) var ( - pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") - pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") - redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") - _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") - pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") - pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") - artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") - artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") - pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") - metricsAddress = flag.String("metrics-address", "", "The address to listen on for metrics requests") - sentryDSN = flag.String("sentry-dsn", "", "The address for sending sentry crash reporting to") - sentryEnvironment = flag.String("sentry-environment", "", "The environment for sentry 
crash reporting") - daemonUID = flag.Uint("daemon-uid", 0, "Drop privileges to this user") - daemonGID = flag.Uint("daemon-gid", 0, "Drop privileges to this group") - _ = flag.Bool("daemon-enable-jail", false, "DEPRECATED and ignored, will be removed in 15.0") - _ = flag.Bool("daemon-inplace-chroot", false, "DEPRECATED and ignored, will be removed in 15.0") // TODO: https://gitlab.com/gitlab-org/gitlab-pages/-/issues/599 - propagateCorrelationID = flag.Bool("propagate-correlation-id", false, "Reuse existing Correlation-ID from the incoming request header `X-Request-ID` if present") - logFormat = flag.String("log-format", "json", "The log output format: 'text' or 'json'") - logVerbose = flag.Bool("log-verbose", false, "Verbose logging") - secret = flag.String("auth-secret", "", "Cookie store hash key, should be at least 32 bytes long") - publicGitLabServer = flag.String("gitlab-server", "", "Public GitLab server, for example https://www.gitlab.com") - internalGitLabServer = flag.String("internal-gitlab-server", "", "Internal GitLab server used for API requests, useful if you want to send that traffic over an internal load balancer, example value https://gitlab.example.internal (defaults to value of gitlab-server)") - gitLabAPISecretKey = flag.String("api-secret-key", "", "File with secret key used to authenticate with the GitLab API") - gitlabClientHTTPTimeout = flag.Duration("gitlab-client-http-timeout", 10*time.Second, "GitLab API HTTP client connection timeout in seconds (default: 10s)") - gitlabClientJWTExpiry = flag.Duration("gitlab-client-jwt-expiry", 30*time.Second, "JWT Token expiry time in seconds (default: 30s)") - gitlabCacheExpiry = flag.Duration("gitlab-cache-expiry", 10*time.Minute, "The maximum time a domain's configuration is stored in the cache") - gitlabCacheRefresh = flag.Duration("gitlab-cache-refresh", time.Minute, "The interval at which a domain's configuration is set to be due to refresh") - gitlabCacheCleanup = 
flag.Duration("gitlab-cache-cleanup", time.Minute, "The interval at which expired items are removed from the cache") - gitlabRetrievalTimeout = flag.Duration("gitlab-retrieval-timeout", 30*time.Second, "The maximum time to wait for a response from the GitLab API per request") - gitlabRetrievalInterval = flag.Duration("gitlab-retrieval-interval", time.Second, "The interval to wait before retrying to resolve a domain's configuration via the GitLab API") - gitlabRetrievalRetries = flag.Int("gitlab-retrieval-retries", 3, "The maximum number of times to retry to resolve a domain's configuration via the API") + pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") + pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") + redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") + enableRateLimiter = flag.Bool("enable-rate-limiter", false, "Enable in-built rate limiter") + rateLimitPerDomain = flag.Duration("rate-limit-per-domain", ratelimiter.DefaultPerDomainFrequency, "Rate of allowed requests as duration. E.g. 25ms allows 1 request every 25ms or 40 requests per second") + rateLimitPerDomainBurstSize = flag.Int("rate-limit-per-domain-burst-size", ratelimiter.DefaultPerDomainBurstSize, "Maximum number of requests to allow in a rate-limit-per-domain period. E.g. 
setting value to 10 will allow up to 10 requests within 25ms, then only 1 req per 25ms will be allowed.") + _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") + pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") + pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") + artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") + artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") + pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") + metricsAddress = flag.String("metrics-address", "", "The address to listen on for metrics requests") + sentryDSN = flag.String("sentry-dsn", "", "The address for sending sentry crash reporting to") + sentryEnvironment = flag.String("sentry-environment", "", "The environment for sentry crash reporting") + daemonUID = flag.Uint("daemon-uid", 0, "Drop privileges to this user") + daemonGID = flag.Uint("daemon-gid", 0, "Drop privileges to this group") + _ = flag.Bool("daemon-enable-jail", false, "DEPRECATED and ignored, will be removed in 15.0") + _ = flag.Bool("daemon-inplace-chroot", false, "DEPRECATED and ignored, will be removed in 15.0") // TODO: https://gitlab.com/gitlab-org/gitlab-pages/-/issues/599 + propagateCorrelationID = flag.Bool("propagate-correlation-id", false, "Reuse existing Correlation-ID from the incoming request header `X-Request-ID` if present") + logFormat = flag.String("log-format", "json", "The log output format: 'text' or 'json'") + logVerbose = flag.Bool("log-verbose", false, "Verbose logging") + secret = flag.String("auth-secret", "", "Cookie store hash key, should be at least 32 bytes long") + publicGitLabServer = flag.String("gitlab-server", "", "Public GitLab server, for example 
https://www.gitlab.com") + internalGitLabServer = flag.String("internal-gitlab-server", "", "Internal GitLab server used for API requests, useful if you want to send that traffic over an internal load balancer, example value https://gitlab.example.internal (defaults to value of gitlab-server)") + gitLabAPISecretKey = flag.String("api-secret-key", "", "File with secret key used to authenticate with the GitLab API") + gitlabClientHTTPTimeout = flag.Duration("gitlab-client-http-timeout", 10*time.Second, "GitLab API HTTP client connection timeout in seconds (default: 10s)") + gitlabClientJWTExpiry = flag.Duration("gitlab-client-jwt-expiry", 30*time.Second, "JWT Token expiry time in seconds (default: 30s)") + gitlabCacheExpiry = flag.Duration("gitlab-cache-expiry", 10*time.Minute, "The maximum time a domain's configuration is stored in the cache") + gitlabCacheRefresh = flag.Duration("gitlab-cache-refresh", time.Minute, "The interval at which a domain's configuration is set to be due to refresh") + gitlabCacheCleanup = flag.Duration("gitlab-cache-cleanup", time.Minute, "The interval at which expired items are removed from the cache") + gitlabRetrievalTimeout = flag.Duration("gitlab-retrieval-timeout", 30*time.Second, "The maximum time to wait for a response from the GitLab API per request") + gitlabRetrievalInterval = flag.Duration("gitlab-retrieval-interval", time.Second, "The interval to wait before retrying to resolve a domain's configuration via the GitLab API") + gitlabRetrievalRetries = flag.Int("gitlab-retrieval-retries", 3, "The maximum number of times to retry to resolve a domain's configuration via the API") _ = flag.String("domain-config-source", "gitlab", "DEPRECATED and has not affect, see https://gitlab.com/gitlab-org/gitlab-pages/-/merge_requests/541") enableDisk = flag.Bool("enable-disk", true, "Enable disk access, shall be disabled in environments where shared disk storage isn't available") diff --git a/internal/httperrors/httperrors.go 
b/internal/httperrors/httperrors.go index ed56ee103ecbbd8cfbe2a6a2f62e7232cb134d6a..6aea017594ee743eac5953fd2a1c4fa8598e4eb2 100644 --- a/internal/httperrors/httperrors.go +++ b/internal/httperrors/httperrors.go @@ -34,6 +34,13 @@ var (

Make sure the address is correct and that the page hasn't moved.

Please contact your GitLab administrator if you think this is a mistake.

`, } + content429 = content{ + http.StatusTooManyRequests, + "Too many requests (429)", + "429", + "Too many requests.", + `

The resource that you are attempting to access is being rate limited.

`, + } content500 = content{ http.StatusInternalServerError, "Something went wrong (500)", @@ -176,6 +183,11 @@ func Serve404(w http.ResponseWriter) { serveErrorPage(w, content404) } +// Serve429 returns a 429 error response / HTML page to the http.ResponseWriter +func Serve429(w http.ResponseWriter) { + serveErrorPage(w, content429) +} + // Serve500 returns a 500 error response / HTML page to the http.ResponseWriter func Serve500(w http.ResponseWriter) { serveErrorPage(w, content500) diff --git a/internal/ratelimiter/lru_cache.go b/internal/ratelimiter/lru_cache.go new file mode 100644 index 0000000000000000000000000000000000000000..34c66c811c9040ac1e92e10da3b98f2a0742d33d --- /dev/null +++ b/internal/ratelimiter/lru_cache.go @@ -0,0 +1,64 @@ +package ratelimiter + +import ( + "time" + + "github.com/karlseguin/ccache/v2" +) + +// lruCacheGetPerPromote is a value that makes the item to be promoted +// it is taken arbitrarily as a sane value indicating that the item +// was frequently picked +// promotion moves the item to the front of the LRU list +const lruCacheGetsPerPromote = 64 + +// lruCacheItemsToPruneDiv is a value that indicates how many items +// need to be pruned on OOM, this prunes 1/16 of items +const lruCacheItemsToPruneDiv = 16 + +type lruCache struct { + op string + duration time.Duration + cache *ccache.Cache +} + +func newLruCache(op string, maxEntries int64, duration time.Duration) *lruCache { + configuration := ccache.Configure() + configuration.MaxSize(maxEntries) + configuration.ItemsToPrune(uint32(maxEntries) / lruCacheItemsToPruneDiv) + configuration.GetsPerPromote(lruCacheGetsPerPromote) // if item gets requested frequently promote it + configuration.OnDelete(func(*ccache.Item) { + // TODO: add metrics + //metrics.ZipCachedEntries.WithLabelValues(op).Dec() + }) + + return &lruCache{ + op: op, + cache: ccache.New(configuration), + duration: duration, + } +} + +func (c *lruCache) findOrFetch(cacheNamespace, key string, fetchFn func() 
(interface{}, error)) (interface{}, error) { + item := c.cache.Get(cacheNamespace + key) + + if item != nil && !item.Expired() { + // TODO: add metrics + //metrics.ZipCacheRequests.WithLabelValues(c.op, "hit").Inc() + return item.Value(), nil + } + + value, err := fetchFn() + if err != nil { + // TODO: add metrics + //metrics.ZipCacheRequests.WithLabelValues(c.op, "error").Inc() + return nil, err + } + + // TODO: add metrics + //metrics.ZipCacheRequests.WithLabelValues(c.op, "miss").Inc() + //metrics.ZipCachedEntries.WithLabelValues(c.op).Inc() + + c.cache.Set(cacheNamespace+key, value, c.duration) + return value, nil +} diff --git a/internal/ratelimiter/middleware.go b/internal/ratelimiter/middleware.go new file mode 100644 index 0000000000000000000000000000000000000000..f78ecd99c4f939b9cabd91c0485a5f62556508cb --- /dev/null +++ b/internal/ratelimiter/middleware.go @@ -0,0 +1,34 @@ +package ratelimiter + +import ( + "net" + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" +) + +// DomainRateLimiter middleware ensures that the requested domain can be served by the current +// rate limit. 
See -enable-rate-limiter +func DomainRateLimiter(rl *RateLimiter) func(http.Handler) http.Handler { + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := getHost(r) + + if host != "127.0.0.1" && !rl.DomainAllowed(host) { + httperrors.Serve429(w) + return + } + + handler.ServeHTTP(w, r) + }) + } +} + +func getHost(r *http.Request) string { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + + return host +} diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go new file mode 100644 index 0000000000000000000000000000000000000000..70a30294f6abaf047c11e934a4be8ff94cff372b --- /dev/null +++ b/internal/ratelimiter/ratelimit.go @@ -0,0 +1,103 @@ +package ratelimiter + +import ( + "time" + + "gitlab.com/gitlab-org/labkit/log" + "golang.org/x/time/rate" +) + +const ( + // DefaultPerDomainFrequency is the interval at which one request is allowed per domain. + // The default value of 25ms equals 1 request every 25ms -> 40 rps + DefaultPerDomainFrequency = 25 * time.Millisecond + // DefaultPerDomainBurstSize is the maximum burst allowed per rate limiter + // E.g. The first 40 requests within 25ms will succeed, but the 41st will fail until the next + // refill occurs at DefaultPerDomainFrequency, allowing only 1 request per rate frequency. + DefaultPerDomainBurstSize = 40 + + // avg of ~18,000 unique domains per hour + // https://log.gprd.gitlab.net/app/lens#/edit/3c45a610-15c9-11ec-a012-eb2e5674cacf?_g=h@e78830b + defaultDomainsItems = 20000 + defaultDomainsExpirationInterval = time.Hour +) + +// Option function to configure a RateLimiter +type Option func(*RateLimiter) + +// RateLimiter holds a map of domain names with counters that enable rate limiting per domain. +// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per domain entry. 
+// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html +// Per-domain limiters are kept in an LRU cache of up to defaultDomainsItems entries and +// expire after defaultDomainsExpirationInterval. +type RateLimiter struct { + now func() time.Time + perDomainFrequency time.Duration + perDomainBurstSize int + //domainMux *sync.RWMutex + domainsCache *lruCache + // TODO: add sourceIPCache +} + +// New creates a new RateLimiter with default values that can be configured via Option functions +func New(opts ...Option) *RateLimiter { + rl := &RateLimiter{ + now: time.Now, + perDomainFrequency: DefaultPerDomainFrequency, + perDomainBurstSize: DefaultPerDomainBurstSize, + //domainMux: &sync.RWMutex{}, + domainsCache: newLruCache("domains", defaultDomainsItems, defaultDomainsExpirationInterval), + } + + for _, opt := range opts { + opt(rl) + } + + return rl +} + +// WithNow replaces the RateLimiter now function +func WithNow(now func() time.Time) Option { + return func(rl *RateLimiter) { + rl.now = now + } +} + +// WithPerDomainFrequency allows configuring perDomain frequency for the RateLimiter +func WithPerDomainFrequency(d time.Duration) Option { + return func(rl *RateLimiter) { + rl.perDomainFrequency = d + } +} + +// WithPerDomainBurstSize configures burst per domain for the RateLimiter +func WithPerDomainBurstSize(burst int) Option { + return func(rl *RateLimiter) { + rl.perDomainBurstSize = burst + } +} + +func (rl *RateLimiter) getDomainCounter(domain string) (*rate.Limiter, error) { + limiterI, err := rl.domainsCache.findOrFetch(domain, domain, func() (interface{}, error) { + return rate.NewLimiter(rate.Every(rl.perDomainFrequency), rl.perDomainBurstSize), nil + }) + if err != nil { + return nil, err + } + + return limiterI.(*rate.Limiter), nil +} + +// DomainAllowed reports whether a request for the given domain is within +// the configured per-domain frequency and burst size. 
+func (rl *RateLimiter) DomainAllowed(domain string) (res bool) { + limiter, err := rl.getDomainCounter(domain) + if err != nil { + // TODO: return and handle error appropriately + log.WithError(err).Warnf("failed to get rate limiter for domain: %s", domain) + return true + } + + // TODO: we could use Wait(ctx) if we want to moderate the request rate rather than denying requests + return limiter.AllowN(rl.now(), 1) +} diff --git a/internal/ratelimiter/ratelimit_test.go b/internal/ratelimiter/ratelimit_test.go new file mode 100644 index 0000000000000000000000000000000000000000..73ba9e030eab8e11d3caee6dca705a4d86738c22 --- /dev/null +++ b/internal/ratelimiter/ratelimit_test.go @@ -0,0 +1,100 @@ +package ratelimiter + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + now = "2021-09-13T15:00:00Z" + validTime, _ = time.Parse(time.RFC3339, now) +) + +func mockNow() time.Time { + return validTime +} + +func TestDomainAllowed(t *testing.T) { + tcs := map[string]struct { + now string + domainRate time.Duration + perDomainBurstPerSecond int + domain string + reqNum int + }{ + "one_request_per_second": { + domainRate: 1, // 1 per second + perDomainBurstPerSecond: 1, + reqNum: 2, + domain: "rate.gitlab.io", + }, + "one_request_per_second_but_big_bucket": { + domainRate: 1, // 1 per second + perDomainBurstPerSecond: 10, + reqNum: 11, + domain: "rate.gitlab.io", + }, + "three_req_per_second_bucket_size_one": { + domainRate: 3, // 3 per second + perDomainBurstPerSecond: 1, // max burst 1 means 1 at a time + reqNum: 3, + domain: "rate.gitlab.io", + }, + "10_requests_per_second": { + domainRate: 10, + perDomainBurstPerSecond: 10, + reqNum: 11, + domain: "rate.gitlab.io", + }, + } + + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + rl := New( + WithNow(mockNow), + WithPerDomainFrequency(tc.domainRate), + WithPerDomainBurstSize(tc.perDomainBurstPerSecond), + ) + + for i := 0; i < tc.reqNum; i++ { 
+ got := rl.DomainAllowed(tc.domain) + if i < tc.perDomainBurstPerSecond { + require.Truef(t, got, "expected true for request no. %d", i+1) + } else { + require.False(t, got, "expected false for request no. %d", i+1) + } + } + }) + } +} + +func TestDomainAllowedWitSleeps(t *testing.T) { + rate := 10 * time.Millisecond + rl := New( + WithPerDomainFrequency(rate), + WithPerDomainBurstSize(1), + ) + domain := "test.gitlab.io" + + t.Run("one request every 10ms with burst 1", func(t *testing.T) { + for i := 0; i < 10; i++ { + got := rl.DomainAllowed(domain) + assert.Truef(t, got, "expected true for request no. %d", i+1) + time.Sleep(rate) + } + }) + + t.Run("requests start failing after reaching burst", func(t *testing.T) { + for i := 0; i < 5; i++ { + got := rl.DomainAllowed(domain + ".diff") + if i < 1 { + require.Truef(t, got, "expected true for request no. %d", i) + } else { + require.False(t, got, "expected false for request no. %d", i) + } + } + }) +} diff --git a/test/acceptance/ratelimit_test.go b/test/acceptance/ratelimit_test.go new file mode 100644 index 0000000000000000000000000000000000000000..04518fe5e697c3f688520b0c4933ca757809bb56 --- /dev/null +++ b/test/acceptance/ratelimit_test.go @@ -0,0 +1,37 @@ +package acceptance_test + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateLimitMiddleware(t *testing.T) { + RunPagesProcess(t, + withListeners([]ListenSpec{httpListener}), + withExtraArgument("enable-rate-limiter", "true"), + //refills 1 token every 10ms, bound by the burst/bucket size + withExtraArgument("rate-limit-per-domain", "10ms"), + // allows a max of 1 token at a time PER instance of time + withExtraArgument("rate-limit-per-domain-burst-size", "1"), + ) + + for i := 0; i < 20; i++ { + rsp1, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + rsp1.Body.Close() + + // every other request should fail + if i%2 != 0 { + require.Equal(t, 
http.StatusTooManyRequests, rsp1.StatusCode, "group.gitlab-example.com request: %d failed", i) + // wait for another token to become available + time.Sleep(10 * time.Millisecond) + continue + } + + require.Equal(t, http.StatusOK, rsp1.StatusCode, "group.gitlab-example.com request: %d failed", i) + time.Sleep(time.Millisecond) + } +}