From cab08fb78cc040fa19d74ce221fe7718dddc064a Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Sun, 12 Sep 2021 12:55:49 +1000 Subject: [PATCH 1/9] feat: wip ratelimiter (cherry picked from commit 6744061806137931c0429c84b8ec02b3992d81e0) --- internal/ratelimiter/ratelimit.go | 163 +++++++++++++++++++++++++ internal/ratelimiter/ratelimit_test.go | 1 + 2 files changed, 164 insertions(+) create mode 100644 internal/ratelimiter/ratelimit.go create mode 100644 internal/ratelimiter/ratelimit_test.go diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go new file mode 100644 index 000000000..afb7a1038 --- /dev/null +++ b/internal/ratelimiter/ratelimit.go @@ -0,0 +1,163 @@ +package ratelimiter + +import ( + "errors" + "sync" + "time" + + "gitlab.com/gitlab-org/labkit/log" +) + +const ( + DefaultCleanupInterval = time.Second + DefaultWindowPerDomain = time.Second + DefaultPerDomainMaxCount = 100 +) + +var ( + errDomainCounterNotFound = errors.New("domain counter not found") +) + +type counter struct { + count int64 + lastSeen time.Time +} + +type Option func(*RateLimiter) + +type RateLimiter struct { + now func() time.Time + cleanupTimer *time.Ticker + domainWindow time.Duration + maxCountPerDomain int64 + domainMux *sync.RWMutex + // TODO: this could be an LRU cache like what we do in the zip VFS + perDomain map[string]counter +} + +// New creates a new RateLimiter with default values +func New(opts ...Option) *RateLimiter { + rl := &RateLimiter{ + now: time.Now, + cleanupTimer: time.NewTicker(DefaultCleanupInterval), + domainWindow: DefaultWindowPerDomain, + maxCountPerDomain: DefaultPerDomainMaxCount, + domainMux: &sync.RWMutex{}, + perDomain: make(map[string]counter), + } + + for _, opt := range opts { + opt(rl) + } + + go rl.cleanup() + + return rl +} + +func WithNow(now func() time.Time) Option { + return func(rl *RateLimiter) { + rl.now = now + } +} + +func WithCleanupInterval(d time.Duration) Option { + return func(rl *RateLimiter) { 
+ rl.cleanupTimer.Reset(d) + } +} + +func WithDomainWindow(d time.Duration) Option { + return func(rl *RateLimiter) { + rl.domainWindow = d + } +} +func WithDomainMaxCount(c int64) Option { + return func(rl *RateLimiter) { + rl.maxCountPerDomain = c + } +} + +// AddDomain to the current RateLimiter per domain count +func (rl *RateLimiter) AddDomain(domain string) { + rl.domainMux.Lock() + defer rl.domainMux.Unlock() + + // TODO: add metrics + currentCounter, ok := rl.perDomain[domain] + if !ok { + newCounter := counter{ + lastSeen: rl.now(), + count: 1, + } + + rl.perDomain[domain] = newCounter + return + } + + currentCounter.count++ +} + +// DomainAllowed checks that the requested domain can be accessed within +// the maxCountPerDomain in the given domainWindow. +func (rl *RateLimiter) DomainAllowed(domain string) bool { + // increment counter for this domain regardless if allowed or not + defer rl.AddDomain(domain) + + domainCounter, err := rl.getDomainCounter(domain) + if err != nil && errors.Is(err, errDomainCounterNotFound) { + // we haven't seen this domain so it should be allowed + log.WithError(err).Warn("DomainAllowed did not find the requested domain") + return true + } + + now := rl.now() + lastSeen := domainCounter.lastSeen + count := domainCounter.count + + //if requested within time window and the count is less thant the max count + // e.g. 
maxCount = 10 and window is 10s + // now is 1s, count is 1 -> true + // now is 11s, count is < 10 -> true + // now is 2s, count > 10 -> false + if now.Sub(lastSeen) < rl.domainWindow { + if count < rl.maxCountPerDomain { + return true + } + } + + return false +} + +func (rl *RateLimiter) getDomainCounter(domain string) (counter, error) { + rl.domainMux.RLock() + defer rl.domainMux.RUnlock() + + currentCounter, ok := rl.perDomain[domain] + if !ok { + return counter{}, errDomainCounterNotFound + } + + return currentCounter, nil +} + +func (rl *RateLimiter) cleanup() { + select { + case t := <-rl.cleanupTimer.C: + log.WithField("cleanup", t).Info("cleaning rate limiter") + go func() { + rl.domainMux.Lock() + defer rl.domainMux.Unlock() + for _, counter := range rl.perDomain { + if rl.now().Sub(counter.lastSeen) > rl.domainWindow { + counter.count -= rl.maxCountPerDomain + if counter.count < 0 { + counter.count = 0 + } + } + } + }() + default: + + } +} diff --git a/internal/ratelimiter/ratelimit_test.go b/internal/ratelimiter/ratelimit_test.go new file mode 100644 index 000000000..631185b31 --- /dev/null +++ b/internal/ratelimiter/ratelimit_test.go @@ -0,0 +1 @@ +package ratelimiter -- GitLab From 1c0025bef6121da179f3ef664d39ec71b735caee Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Mon, 13 Sep 2021 12:10:18 +1000 Subject: [PATCH 2/9] feat: use Golang's rate.Limiter --- go.mod | 1 + go.sum | 1 + internal/ratelimiter/ratelimit.go | 138 ++++++++++--------------- internal/ratelimiter/ratelimit_test.go | 53 ++++++++++ 4 files changed, 111 insertions(+), 82 deletions(-) diff --git a/go.mod b/go.mod index 5f17d2536..0ff516edd 100644 --- a/go.mod +++ b/go.mod @@ -25,4 +25,5 @@ require ( golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 + golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 ) diff --git a/go.sum b/go.sum index f004e86f1..c57b16add 100644 --- 
a/go.sum +++ b/go.sum @@ -398,6 +398,7 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181221001348-537d06c36207/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go index afb7a1038..60e8f1524 100644 --- a/internal/ratelimiter/ratelimit.go +++ b/internal/ratelimiter/ratelimit.go @@ -1,49 +1,56 @@ package ratelimiter import ( - "errors" + "fmt" "sync" "time" "gitlab.com/gitlab-org/labkit/log" + "golang.org/x/time/rate" ) const ( - DefaultCleanupInterval = time.Second - DefaultWindowPerDomain = time.Second - DefaultPerDomainMaxCount = 100 -) - -var ( - errDomainCounterNotFound = errors.New("domain counter not found") + // DefaultCleanupInterval is the time at which cleanup will run + DefaultCleanupInterval = 30 * time.Second + // DefaultMaxTimePerDomain is the maximum time to keep a domain in the rate limiter map + DefaultMaxTimePerDomain = 30 * time.Second + + // DefaultRatePerDomainPerSecond transformed to rate.Limit = 1 / DefaultRatePerDomainPerSecond. + // The default value is equivalent to 100 requests per second per domain + DefaultRatePerDomainPerSecond = 0.01 + // DefaultPerDomainMaxBurstPerSecond is the maximum burst in requests. 
TODO need to understand and test this + DefaultPerDomainMaxBurstPerSecond = 100 ) type counter struct { - count int64 + limiter *rate.Limiter lastSeen time.Time } +// Option function to configure a RateLimiter type Option func(*RateLimiter) type RateLimiter struct { - now func() time.Time - cleanupTimer *time.Ticker - domainWindow time.Duration - maxCountPerDomain int64 - domainMux *sync.RWMutex - // TODO: this could be an LRU cache like what we do in the zip VFS - perDomain map[string]counter + now func() time.Time + cleanupTimer *time.Ticker + maxTimePerDomain time.Duration + domainRatePerSecond float64 + perDomainBurstPerSecond int + domainMux *sync.RWMutex + // TODO: this could be an LRU cache like what we do in the zip VFS instead of cleaning manually ? + perDomain map[string]*counter } // New creates a new RateLimiter with default values func New(opts ...Option) *RateLimiter { rl := &RateLimiter{ - now: time.Now, - cleanupTimer: time.NewTicker(DefaultCleanupInterval), - domainWindow: DefaultWindowPerDomain, - maxCountPerDomain: DefaultPerDomainMaxCount, - domainMux: &sync.RWMutex{}, - perDomain: make(map[string]counter), + now: time.Now, + cleanupTimer: time.NewTicker(DefaultCleanupInterval), + maxTimePerDomain: DefaultMaxTimePerDomain, + domainRatePerSecond: DefaultRatePerDomainPerSecond, + perDomainBurstPerSecond: DefaultPerDomainMaxBurstPerSecond, + domainMux: &sync.RWMutex{}, + perDomain: make(map[string]*counter), } for _, opt := range opts { @@ -67,97 +74,64 @@ func WithCleanupInterval(d time.Duration) Option { } } -func WithDomainWindow(d time.Duration) Option { +func WithDomainRatePerSecond(r float64) Option { return func(rl *RateLimiter) { - rl.domainWindow = d + rl.domainRatePerSecond = r } } -func WithDomainMaxCount(c int64) Option { + +func WithDomainBurstPerSecond(burst int) Option { return func(rl *RateLimiter) { - rl.maxCountPerDomain = c + rl.perDomainBurstPerSecond = burst } } -// AddDomain to the current RateLimiter per domain count -func 
(rl *RateLimiter) AddDomain(domain string) { +func (rl *RateLimiter) getDomainCounter(domain string) *counter { rl.domainMux.Lock() defer rl.domainMux.Unlock() // TODO: add metrics currentCounter, ok := rl.perDomain[domain] if !ok { - newCounter := counter{ + newCounter := &counter{ lastSeen: rl.now(), - count: 1, + limiter: rate.NewLimiter(rate.Limit(rl.domainRatePerSecond), rl.perDomainBurstPerSecond), } rl.perDomain[domain] = newCounter - return + return newCounter } - currentCounter.count++ + currentCounter.lastSeen = rl.now() + return currentCounter } // DomainAllowed checks that the requested domain can be accessed within // the maxCountPerDomain in the given domainWindow. -func (rl *RateLimiter) DomainAllowed(domain string) bool { - // increment counter for this domain regardless if allowed or not - defer rl.AddDomain(domain) - - domainCounter, err := rl.getDomainCounter(domain) - if err != nil && errors.Is(err, errDomainCounterNotFound) { - // we haven't seen this domain so it should be allowed - log.WithError(err).Warn("DomainAllowed did not find the requested domain") - return true - } - - now := rl.now() - lastSeen := domainCounter.lastSeen - count := domainCounter.count - - //if requested within time window and the count is less thant the max count - // e.g. 
maxCount = 10 and window is 10s - // now is 1s, count is 1 -> true - // now is 11s, count is < 10 -> true - // now is 2s, count > 10 -> false - if now.Sub(lastSeen) < rl.domainWindow { - if count < rl.maxCountPerDomain { - return true - } - } - - return false -} - -func (rl *RateLimiter) getDomainCounter(domain string) (counter, error) { - rl.domainMux.RLock() - defer rl.domainMux.RUnlock() +func (rl *RateLimiter) DomainAllowed(domain string) (res bool) { - currentCounter, ok := rl.perDomain[domain] - if !ok { - return counter{}, errDomainCounterNotFound - } + counter := rl.getDomainCounter(domain) + defer func() { + fmt.Printf("limiter info: limit: %f - burst: %d\n", counter.limiter.Limit(), counter.limiter.Burst()) + fmt.Printf("calling DomainAllowed for: %q returned: %t\n", domain, res) + }() - return currentCounter, nil + // TODO: we could use Wait(ctx) if we want to moderate the request rate rather than denying requests + return counter.limiter.Allow() } func (rl *RateLimiter) cleanup() { select { case t := <-rl.cleanupTimer.C: - log.WithField("cleanup", t).Info("cleaning rate limiter") - go func() { - rl.domainMux.Lock() - defer rl.domainMux.Unlock() - for _, counter := range rl.perDomain { - if rl.now().Sub(counter.lastSeen) > rl.domainWindow { - counter.count -= rl.maxCountPerDomain - if counter.count < 0 { - counter.count = 0 - } - } + log.WithField("cleanup", t).Debug("cleaning perDomain rate") + + rl.domainMux.Lock() + for domain, counter := range rl.perDomain { + if time.Since(counter.lastSeen) > rl.maxTimePerDomain { + delete(rl.perDomain, domain) } - }() + } + rl.domainMux.Unlock() default: - } } diff --git a/internal/ratelimiter/ratelimit_test.go b/internal/ratelimiter/ratelimit_test.go index 631185b31..628d5e6b2 100644 --- a/internal/ratelimiter/ratelimit_test.go +++ b/internal/ratelimiter/ratelimit_test.go @@ -1 +1,54 @@ package ratelimiter + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func mockNow(tb testing.TB, 
now string) func() time.Time { + tb.Helper() + + return func() time.Time { + parsedT, err := time.Parse(time.RFC3339, now) + require.NoError(tb, err) + + return parsedT + } +} + +func TestDomainAllowed(t *testing.T) { + now := "2021-09-13T15:00:00Z" + + tcs := map[string]struct { + now string + domainRatePerSecond float64 + perDomainBurstPerSecond int + domain string + reqNum int + }{ + "some test": { + domainRatePerSecond: 1, // 1 per second + perDomainBurstPerSecond: 1, + reqNum: 1, + domain: "rate.gitlab.io", + }, + } + + for tn, tc := range tcs { + t.Run(tn, func(t *testing.T) { + rl := New( + WithNow(mockNow(t, now)), + WithDomainRatePerSecond(tc.domainRatePerSecond), + WithDomainBurstPerSecond(tc.perDomainBurstPerSecond), + ) + + for i := 0; i < tc.reqNum; i++ { + got := rl.DomainAllowed(tc.domain) + require.True(t, got, "req num: %d failed", i+1) + } + + }) + } +} -- GitLab From bcdd0a70895b99eb06000713519f418b928a7caf Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Mon, 13 Sep 2021 18:12:29 +1000 Subject: [PATCH 3/9] feat: add middleware and configure it --- app.go | 6 ++ internal/config/config.go | 18 +++--- internal/config/flags.go | 1 + internal/httperrors/httperrors.go | 12 ++++ internal/middleware/ratelimit.go | 28 +++++++++ internal/ratelimiter/ratelimit.go | 21 +++---- internal/ratelimiter/ratelimit_test.go | 84 +++++++++++++++++++++----- 7 files changed, 134 insertions(+), 36 deletions(-) create mode 100644 internal/middleware/ratelimit.go diff --git a/app.go b/app.go index 389ce0b47..f21263238 100644 --- a/app.go +++ b/app.go @@ -15,6 +15,8 @@ import ( "github.com/rs/cors" "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + "gitlab.com/gitlab-org/go-mimedb" "gitlab.com/gitlab-org/labkit/correlation" "gitlab.com/gitlab-org/labkit/errortracking" @@ -337,6 +339,10 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { metricsMiddleware := 
labmetrics.NewHandlerFactory(labmetrics.WithNamespace("gitlab_pages")) handler = metricsMiddleware(handler) + if !a.config.General.DisableRateLimiter { + handler = middleware.DomainRateLimiter(ratelimiter.New())(handler) + } + handler = a.routingMiddleware(handler) // Health Check diff --git a/internal/config/config.go b/internal/config/config.go index 61767a57e..4ecbc4912 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,14 +45,15 @@ type Config struct { // General groups settings that are general to GitLab Pages and can not // be categorized under other head. type General struct { - Domain string - MaxConns int - MetricsAddress string - RedirectHTTP bool - RootCertificate []byte - RootDir string - RootKey []byte - StatusPath string + Domain string + DisableRateLimiter bool + MaxConns int + MetricsAddress string + RedirectHTTP bool + RootCertificate []byte + RootDir string + RootKey []byte + StatusPath string DisableCrossOriginRequests bool InsecureCiphers bool @@ -180,6 +181,7 @@ func loadConfig() (*Config, error) { config := &Config{ General: General{ Domain: strings.ToLower(*pagesDomain), + DisableRateLimiter: *disableRateLimiter, MaxConns: *maxConns, MetricsAddress: *metricsAddress, RedirectHTTP: *redirectHTTP, diff --git a/internal/config/flags.go b/internal/config/flags.go index aa5bf1c50..84adf8c77 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -12,6 +12,7 @@ var ( pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") + disableRateLimiter = flag.Bool("disable-rate-limiter", false, "Disable in-built rate limiter") _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") pagesRoot = flag.String("pages-root", "shared/pages", "The directory 
where pages are stored") pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") diff --git a/internal/httperrors/httperrors.go b/internal/httperrors/httperrors.go index ed56ee103..6aea01759 100644 --- a/internal/httperrors/httperrors.go +++ b/internal/httperrors/httperrors.go @@ -34,6 +34,13 @@ var (

Make sure the address is correct and that the page hasn't moved.

Please contact your GitLab administrator if you think this is a mistake.

`, } + content429 = content{ + http.StatusTooManyRequests, + "Too many requests (429)", + "429", + "Too many requests.", + `

The resource that you are attempting to access is being rate limited.

`, + } content500 = content{ http.StatusInternalServerError, "Something went wrong (500)", @@ -176,6 +183,11 @@ func Serve404(w http.ResponseWriter) { serveErrorPage(w, content404) } +// Serve429 returns a 429 error response / HTML page to the http.ResponseWriter +func Serve429(w http.ResponseWriter) { + serveErrorPage(w, content429) +} + // Serve500 returns a 500 error response / HTML page to the http.ResponseWriter func Serve500(w http.ResponseWriter) { serveErrorPage(w, content500) diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go new file mode 100644 index 000000000..cdd8c7f56 --- /dev/null +++ b/internal/middleware/ratelimit.go @@ -0,0 +1,28 @@ +package middleware + +import ( + "net/http" + + "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + "gitlab.com/gitlab-org/gitlab-pages/internal/request" +) + +// DomainRateLimiter middleware ensures that the requested domain can be served by the current +// rate limit. See -rate-limiter +func DomainRateLimiter(rl *ratelimiter.RateLimiter) func(http.Handler) http.Handler { + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + d := request.GetDomain(r) + if d != nil { + if !rl.DomainAllowed(d.Name) { + //w.WriteHeader(http.StatusTooManyRequests) + httperrors.Serve429(w) + return + } + } + + handler.ServeHTTP(w, r) + }) + } +} diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go index 60e8f1524..5b039f1af 100644 --- a/internal/ratelimiter/ratelimit.go +++ b/internal/ratelimiter/ratelimit.go @@ -1,7 +1,6 @@ package ratelimiter import ( - "fmt" "sync" "time" @@ -15,10 +14,10 @@ const ( // DefaultMaxTimePerDomain is the maximum time to keep a domain in the rate limiter map DefaultMaxTimePerDomain = 30 * time.Second - // DefaultRatePerDomainPerSecond transformed to rate.Limit = 1 / DefaultRatePerDomainPerSecond. 
- // The default value is equivalent to 100 requests per second per domain - DefaultRatePerDomainPerSecond = 0.01 - // DefaultPerDomainMaxBurstPerSecond is the maximum burst in requests. TODO need to understand and test this + // DefaultRatePerDomainPerSecond the maximum number of requests per second to be allowed per domain + DefaultRatePerDomainPerSecond = 100 + // DefaultPerDomainMaxBurstPerSecond is the maximum burst in requests. It means the maximum number of requests + // at any given time, including DefaultRatePerDomainPerSecond DefaultPerDomainMaxBurstPerSecond = 100 ) @@ -95,7 +94,8 @@ func (rl *RateLimiter) getDomainCounter(domain string) *counter { if !ok { newCounter := &counter{ lastSeen: rl.now(), - limiter: rate.NewLimiter(rate.Limit(rl.domainRatePerSecond), rl.perDomainBurstPerSecond), + // the first argument is the number of requests per second that will be allowed, + limiter: rate.NewLimiter(rate.Limit(rl.domainRatePerSecond), rl.perDomainBurstPerSecond), } rl.perDomain[domain] = newCounter @@ -108,16 +108,11 @@ func (rl *RateLimiter) getDomainCounter(domain string) *counter { // DomainAllowed checks that the requested domain can be accessed within // the maxCountPerDomain in the given domainWindow. 
-func (rl *RateLimiter) DomainAllowed(domain string) (res bool) { - +func (rl *RateLimiter) DomainAllowed(domain string) bool { counter := rl.getDomainCounter(domain) - defer func() { - fmt.Printf("limiter info: limit: %f - burst: %d\n", counter.limiter.Limit(), counter.limiter.Burst()) - fmt.Printf("calling DomainAllowed for: %q returned: %t\n", domain, res) - }() // TODO: we could use Wait(ctx) if we want to moderate the request rate rather than denying requests - return counter.limiter.Allow() + return counter.limiter.AllowN(rl.now(), 1) } func (rl *RateLimiter) cleanup() { diff --git a/internal/ratelimiter/ratelimit_test.go b/internal/ratelimiter/ratelimit_test.go index 628d5e6b2..97445508c 100644 --- a/internal/ratelimiter/ratelimit_test.go +++ b/internal/ratelimiter/ratelimit_test.go @@ -1,26 +1,24 @@ package ratelimiter import ( + "fmt" "testing" "time" "github.com/stretchr/testify/require" ) -func mockNow(tb testing.TB, now string) func() time.Time { - tb.Helper() - - return func() time.Time { - parsedT, err := time.Parse(time.RFC3339, now) - require.NoError(tb, err) +var ( + now = "2021-09-13T15:00:00Z" + validTime, _ = time.Parse(time.RFC3339, now) +) - return parsedT - } +func mockNow() time.Time { + validTime = validTime.Add(time.Millisecond) + return validTime } func TestDomainAllowed(t *testing.T) { - now := "2021-09-13T15:00:00Z" - tcs := map[string]struct { now string domainRatePerSecond float64 @@ -28,10 +26,28 @@ func TestDomainAllowed(t *testing.T) { domain string reqNum int }{ - "some test": { + "one_request_per_second": { domainRatePerSecond: 1, // 1 per second perDomainBurstPerSecond: 1, - reqNum: 1, + reqNum: 2, + domain: "rate.gitlab.io", + }, + "one_request_per_second_but_big_bucket": { + domainRatePerSecond: 1, // 1 per second + perDomainBurstPerSecond: 10, + reqNum: 11, + domain: "rate.gitlab.io", + }, + "three_req_per_second_bucket_size_one": { + domainRatePerSecond: 3, // 3 per second + perDomainBurstPerSecond: 1, // max burst 1 means 1 
at a time + reqNum: 3, + domain: "rate.gitlab.io", + }, + "10_requests_per_second": { + domainRatePerSecond: 10, + perDomainBurstPerSecond: 10, + reqNum: 11, domain: "rate.gitlab.io", }, } @@ -39,16 +55,54 @@ func TestDomainAllowed(t *testing.T) { for tn, tc := range tcs { t.Run(tn, func(t *testing.T) { rl := New( - WithNow(mockNow(t, now)), + WithNow(mockNow), WithDomainRatePerSecond(tc.domainRatePerSecond), WithDomainBurstPerSecond(tc.perDomainBurstPerSecond), ) for i := 0; i < tc.reqNum; i++ { got := rl.DomainAllowed(tc.domain) - require.True(t, got, "req num: %d failed", i+1) + if i < tc.perDomainBurstPerSecond { + require.Truef(t, got, "expected true for request no. %d", i+1) + } else { + require.False(t, got, "expected false for request no. %d", i+1) + } } - }) } } + +func TestDomainAllowedWitSleeps(t *testing.T) { + rate := 100.0 + fmt.Printf("what: %f\n", rate) + rl := New( + WithNow(mockNow), + WithDomainRatePerSecond(rate), + WithDomainBurstPerSecond(2), + ) + domain := "test.gitlab.io" + + t.Run("one request every millisecond with burst 1", func(t *testing.T) { + for i := 0; i < 10; i++ { + got := rl.DomainAllowed(domain) + require.Truef(t, got, "expected true for request no. %d", i+1) + time.Sleep(10 * time.Millisecond) + } + }) + + t.Run("requests start failing after reaching burst", func(t *testing.T) { + //now := mockNow() + for i := 0; i < 5; i++ { + got := rl.DomainAllowed(domain) + fmt.Printf("for:%d got: %t\n", i, got) + //require.True(t, true) + if i < 2 { + require.Truef(t, got, "expected true for request no. %d", i) + } else { + require.False(t, got, "expected false for request no. 
%d", i) + } + + time.Sleep(3 * time.Millisecond) + } + }) +} -- GitLab From f139c6df556f7dbe8b0777e7069c9b9db9ca1ace Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Mon, 13 Sep 2021 18:45:47 +1000 Subject: [PATCH 4/9] test: add acceptance test --- app.go | 6 +++++- internal/config/config.go | 2 ++ internal/config/flags.go | 3 +++ internal/middleware/ratelimit.go | 11 ++++------- internal/ratelimiter/ratelimit.go | 10 ++++++++-- test/acceptance/serving_test.go | 33 +++++++++++++++++++++++++++++++ 6 files changed, 55 insertions(+), 10 deletions(-) diff --git a/app.go b/app.go index f21263238..75df3c78e 100644 --- a/app.go +++ b/app.go @@ -340,7 +340,11 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { handler = metricsMiddleware(handler) if !a.config.General.DisableRateLimiter { - handler = middleware.DomainRateLimiter(ratelimiter.New())(handler) + handler = middleware.DomainRateLimiter( + ratelimiter.New( + ratelimiter.WithDomainRatePerSecond(a.config.General.RateLimitPerDomain), + ), + )(handler) } handler = a.routingMiddleware(handler) diff --git a/internal/config/config.go b/internal/config/config.go index 4ecbc4912..1f43cec59 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -47,6 +47,7 @@ type Config struct { type General struct { Domain string DisableRateLimiter bool + RateLimitPerDomain float64 MaxConns int MetricsAddress string RedirectHTTP bool @@ -182,6 +183,7 @@ func loadConfig() (*Config, error) { General: General{ Domain: strings.ToLower(*pagesDomain), DisableRateLimiter: *disableRateLimiter, + RateLimitPerDomain: *reqDomainPerSecond, MaxConns: *maxConns, MetricsAddress: *metricsAddress, RedirectHTTP: *redirectHTTP, diff --git a/internal/config/flags.go b/internal/config/flags.go index 84adf8c77..3c5c303a3 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -5,6 +5,8 @@ import ( "github.com/namsral/flag" + "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" + 
"gitlab.com/gitlab-org/gitlab-pages/internal/config/tls" ) @@ -13,6 +15,7 @@ var ( pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") disableRateLimiter = flag.Bool("disable-rate-limiter", false, "Disable in-built rate limiter") + reqDomainPerSecond = flag.Float64("req-domain-per-second", ratelimiter.DefaultRatePerDomainPerSecond, "Requests per domain limit per second") _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index cdd8c7f56..cebf194c3 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -13,13 +13,10 @@ import ( func DomainRateLimiter(rl *ratelimiter.RateLimiter) func(http.Handler) http.Handler { return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - d := request.GetDomain(r) - if d != nil { - if !rl.DomainAllowed(d.Name) { - //w.WriteHeader(http.StatusTooManyRequests) - httperrors.Serve429(w) - return - } + host := request.GetHostWithoutPort(r) + if !rl.DomainAllowed(host) { + httperrors.Serve429(w) + return } handler.ServeHTTP(w, r) diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go index 5b039f1af..20e1f8f7e 100644 --- a/internal/ratelimiter/ratelimit.go +++ b/internal/ratelimiter/ratelimit.go @@ -1,6 +1,7 @@ package ratelimiter import ( + "fmt" "sync" "time" @@ -108,9 +109,14 @@ func (rl *RateLimiter) getDomainCounter(domain string) *counter { // DomainAllowed checks that the requested domain can be accessed within // the maxCountPerDomain in the given domainWindow. 
-func (rl *RateLimiter) DomainAllowed(domain string) bool { - counter := rl.getDomainCounter(domain) +func (rl *RateLimiter) DomainAllowed(domain string) (res bool) { + defer func() { + fmt.Printf("was domain: %q allowed? - %t\n", domain, res) + }() + counter := rl.getDomainCounter(domain) + counter.limiter.Reserve() + fmt.Printf("COUNTER DETAILS? now: %s :limit: %f burst: %d\n", rl.now(), counter.limiter.Limit(), counter.limiter.Burst()) // TODO: we could use Wait(ctx) if we want to moderate the request rate rather than denying requests return counter.limiter.AllowN(rl.now(), 1) } diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index c6a7d3ef5..37e33c1f0 100644 --- a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -243,6 +243,39 @@ func TestCustomHeaders(t *testing.T) { } } +func TestRateLimitMiddleware(t *testing.T) { + RunPagesProcess(t, + withListeners([]ListenSpec{httpListener}), + // PerDomain 1 request per 1s, so making 2 requests in a row should fail + withExtraArgument("req-domain-per-second", "1"), + ) + + rsp1, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp1.Body.Close() + require.Equal(t, http.StatusOK, rsp1.StatusCode, "group.gitlab-example.com") + + // make another request right away should fail + rsp2, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp2.Body.Close() + require.Equal(t, http.StatusTooManyRequests, rsp2.StatusCode, "group.gitlab-example.com without waiting") + + // wait for ratelimiter to clear + time.Sleep(time.Second) + rsp3, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp3.Body.Close() + require.Equal(t, http.StatusOK, rsp3.StatusCode, "group.gitlab-example.com after waiting 1s") + + // request another domain + rsp4, err := GetPageFromListener(t, httpListener, 
"CapitalGroup.gitlab-example.com", "/") + require.NoError(t, err) + defer rsp4.Body.Close() + require.Equal(t, http.StatusOK, rsp4.StatusCode, "CapitalGroup.gitlab-example.com for another domain") + +} + func TestKnownHostWithPortReturns200(t *testing.T) { RunPagesProcess(t) -- GitLab From 39cabcc5e7ed97aef71519f1567712243ecaf306 Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Mon, 13 Sep 2021 20:56:22 +1000 Subject: [PATCH 5/9] test: make acceptance work --- app.go | 1 + internal/config/config.go | 4 ++- internal/config/flags.go | 3 +- internal/middleware/ratelimit.go | 2 +- internal/ratelimiter/ratelimit.go | 21 +++++++----- internal/ratelimiter/ratelimit_test.go | 36 ++++++++++----------- test/acceptance/ratelimit_test.go | 44 ++++++++++++++++++++++++++ test/acceptance/serving_test.go | 37 ++-------------------- 8 files changed, 84 insertions(+), 64 deletions(-) create mode 100644 test/acceptance/ratelimit_test.go diff --git a/app.go b/app.go index 75df3c78e..d10c7c5af 100644 --- a/app.go +++ b/app.go @@ -343,6 +343,7 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { handler = middleware.DomainRateLimiter( ratelimiter.New( ratelimiter.WithDomainRatePerSecond(a.config.General.RateLimitPerDomain), + ratelimiter.WithDomainBurstPerSecond(a.config.General.RateLimitMax), ), )(handler) } diff --git a/internal/config/config.go b/internal/config/config.go index 1f43cec59..ee1891491 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -47,7 +47,8 @@ type Config struct { type General struct { Domain string DisableRateLimiter bool - RateLimitPerDomain float64 + RateLimitPerDomain time.Duration + RateLimitMax int MaxConns int MetricsAddress string RedirectHTTP bool @@ -184,6 +185,7 @@ func loadConfig() (*Config, error) { Domain: strings.ToLower(*pagesDomain), DisableRateLimiter: *disableRateLimiter, RateLimitPerDomain: *reqDomainPerSecond, + RateLimitMax: *reqDomainBucketSize, MaxConns: *maxConns, MetricsAddress: *metricsAddress, 
RedirectHTTP: *redirectHTTP, diff --git a/internal/config/flags.go b/internal/config/flags.go index 3c5c303a3..5385f8a9b 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -15,7 +15,8 @@ var ( pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") disableRateLimiter = flag.Bool("disable-rate-limiter", false, "Disable in-built rate limiter") - reqDomainPerSecond = flag.Float64("req-domain-per-second", ratelimiter.DefaultRatePerDomainPerSecond, "Requests per domain limit per second") + reqDomainPerSecond = flag.Duration("req-domain-per-second", ratelimiter.DefaultRatePerDomainPerSecond, "Requests per domain limit per second") + reqDomainBucketSize = flag.Int("req-domain-bucket-size", ratelimiter.DefaultPerDomainMaxBurstPerSecond, "Bucket size = max number of tokens held by the limiter") _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index cebf194c3..954656ebd 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -14,7 +14,7 @@ func DomainRateLimiter(rl *ratelimiter.RateLimiter) func(http.Handler) http.Hand return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { host := request.GetHostWithoutPort(r) - if !rl.DomainAllowed(host) { + if host != "127.0.0.1" && !rl.DomainAllowed(host) { httperrors.Serve429(w) return } diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go index 20e1f8f7e..3ec80f74a 100644 --- a/internal/ratelimiter/ratelimit.go +++ b/internal/ratelimiter/ratelimit.go @@ 
-15,11 +15,15 @@ const ( // DefaultMaxTimePerDomain is the maximum time to keep a domain in the rate limiter map DefaultMaxTimePerDomain = 30 * time.Second + //example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html + // DefaultRatePerDomainPerSecond the maximum number of requests per second to be allowed per domain - DefaultRatePerDomainPerSecond = 100 - // DefaultPerDomainMaxBurstPerSecond is the maximum burst in requests. It means the maximum number of requests - // at any given time, including DefaultRatePerDomainPerSecond - DefaultPerDomainMaxBurstPerSecond = 100 + // 1 request every 25ms = 40 rps + DefaultRatePerDomainPerSecond = 25 * time.Millisecond + // DefaultPerDomainMaxBurstPerSecond is the maximum burst allowed per rate limiter + // 40 items in the bucket is the max + // so if there are 40 rquests in 25 milliseconds they will succeed, but request 41st will fail + DefaultPerDomainMaxBurstPerSecond = 40 ) type counter struct { @@ -34,7 +38,7 @@ type RateLimiter struct { now func() time.Time cleanupTimer *time.Ticker maxTimePerDomain time.Duration - domainRatePerSecond float64 + domainRatePerSecond time.Duration perDomainBurstPerSecond int domainMux *sync.RWMutex // TODO: this could be an LRU cache like what we do in the zip VFS instead of cleaning manually ? 
@@ -74,9 +78,9 @@ func WithCleanupInterval(d time.Duration) Option { } } -func WithDomainRatePerSecond(r float64) Option { +func WithDomainRatePerSecond(d time.Duration) Option { return func(rl *RateLimiter) { - rl.domainRatePerSecond = r + rl.domainRatePerSecond = d } } @@ -96,7 +100,7 @@ func (rl *RateLimiter) getDomainCounter(domain string) *counter { newCounter := &counter{ lastSeen: rl.now(), // the first argument is the number of requests per second that will be allowed, - limiter: rate.NewLimiter(rate.Limit(rl.domainRatePerSecond), rl.perDomainBurstPerSecond), + limiter: rate.NewLimiter(rate.Every(rl.domainRatePerSecond), rl.perDomainBurstPerSecond), } rl.perDomain[domain] = newCounter @@ -117,6 +121,7 @@ func (rl *RateLimiter) DomainAllowed(domain string) (res bool) { counter := rl.getDomainCounter(domain) counter.limiter.Reserve() fmt.Printf("COUNTER DETAILS? now: %s :limit: %f burst: %d\n", rl.now(), counter.limiter.Limit(), counter.limiter.Burst()) + counter.limiter.Burst() // TODO: we could use Wait(ctx) if we want to moderate the request rate rather than denying requests return counter.limiter.AllowN(rl.now(), 1) } diff --git a/internal/ratelimiter/ratelimit_test.go b/internal/ratelimiter/ratelimit_test.go index 97445508c..8a00c1b47 100644 --- a/internal/ratelimiter/ratelimit_test.go +++ b/internal/ratelimiter/ratelimit_test.go @@ -21,31 +21,31 @@ func mockNow() time.Time { func TestDomainAllowed(t *testing.T) { tcs := map[string]struct { now string - domainRatePerSecond float64 + domainRate time.Duration perDomainBurstPerSecond int domain string reqNum int }{ "one_request_per_second": { - domainRatePerSecond: 1, // 1 per second + domainRate: 1, // 1 per second perDomainBurstPerSecond: 1, reqNum: 2, domain: "rate.gitlab.io", }, "one_request_per_second_but_big_bucket": { - domainRatePerSecond: 1, // 1 per second + domainRate: 1, // 1 per second perDomainBurstPerSecond: 10, reqNum: 11, domain: "rate.gitlab.io", }, "three_req_per_second_bucket_size_one": 
{ - domainRatePerSecond: 3, // 3 per second + domainRate: 3, // 3 per second perDomainBurstPerSecond: 1, // max burst 1 means 1 at a time reqNum: 3, domain: "rate.gitlab.io", }, "10_requests_per_second": { - domainRatePerSecond: 10, + domainRate: 10, perDomainBurstPerSecond: 10, reqNum: 11, domain: "rate.gitlab.io", @@ -56,7 +56,7 @@ func TestDomainAllowed(t *testing.T) { t.Run(tn, func(t *testing.T) { rl := New( WithNow(mockNow), - WithDomainRatePerSecond(tc.domainRatePerSecond), + WithDomainRatePerSecond(tc.domainRate), WithDomainBurstPerSecond(tc.perDomainBurstPerSecond), ) @@ -73,12 +73,11 @@ func TestDomainAllowed(t *testing.T) { } func TestDomainAllowedWitSleeps(t *testing.T) { - rate := 100.0 - fmt.Printf("what: %f\n", rate) + rate := 10 * time.Millisecond rl := New( WithNow(mockNow), WithDomainRatePerSecond(rate), - WithDomainBurstPerSecond(2), + WithDomainBurstPerSecond(1), ) domain := "test.gitlab.io" @@ -86,23 +85,22 @@ func TestDomainAllowedWitSleeps(t *testing.T) { for i := 0; i < 10; i++ { got := rl.DomainAllowed(domain) require.Truef(t, got, "expected true for request no. %d", i+1) - time.Sleep(10 * time.Millisecond) + time.Sleep(rate) } }) t.Run("requests start failing after reaching burst", func(t *testing.T) { //now := mockNow() - for i := 0; i < 5; i++ { - got := rl.DomainAllowed(domain) + for i := 0; i < 10; i++ { + got := rl.DomainAllowed(domain + ".diff") fmt.Printf("for:%d got: %t\n", i, got) //require.True(t, true) - if i < 2 { - require.Truef(t, got, "expected true for request no. %d", i) - } else { - require.False(t, got, "expected false for request no. %d", i) - } - - time.Sleep(3 * time.Millisecond) + //if i < 2 { + require.Truef(t, got, "expected true for request no. %d", i) + //} else { + // require.False(t, got, "expected false for request no. 
%d", i) + //} + time.Sleep(time.Nanosecond) } }) } diff --git a/test/acceptance/ratelimit_test.go b/test/acceptance/ratelimit_test.go new file mode 100644 index 000000000..cd56f5dc6 --- /dev/null +++ b/test/acceptance/ratelimit_test.go @@ -0,0 +1,44 @@ +package acceptance_test + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateLimitMiddleware(t *testing.T) { + RunPagesProcess(t, + withListeners([]ListenSpec{httpListener}), + // allows a max of 2 tokens every 10ms -> 2rp 10ms + withExtraArgument("req-domain-per-second", "100ms"), + withExtraArgument("req-domain-bucket-size", "2"), + ) + + rsp1, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp1.Body.Close() + + // make another request right away should fail + rsp2, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp2.Body.Close() + + // wait for ratelimiter to clear + time.Sleep(300 * time.Millisecond) + + rsp3, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp3.Body.Close() + + // request another domain + rsp4, err := GetPageFromListener(t, httpListener, "CapitalGroup.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp4.Body.Close() + + require.Equal(t, http.StatusOK, rsp1.StatusCode, "group.gitlab-example.com first request") + require.Equal(t, http.StatusTooManyRequests, rsp2.StatusCode, "group.gitlab-example.com without waiting") + require.Equal(t, http.StatusOK, rsp3.StatusCode, "rsp3 group.gitlab-example.com after waiting 1s") + require.Equal(t, http.StatusOK, rsp4.StatusCode, "CapitalGroup.gitlab-example.com for another domain") +} diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index 37e33c1f0..dd6e92868 100644 --- a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -180,7 
+180,9 @@ func TestCORSWhenDisabled(t *testing.T) { } func TestCORSAllowsMethod(t *testing.T) { - RunPagesProcess(t) + RunPagesProcess(t, + withExtraArgument("disable-rate-limiter", "true"), + ) tests := []struct { name string @@ -243,39 +245,6 @@ func TestCustomHeaders(t *testing.T) { } } -func TestRateLimitMiddleware(t *testing.T) { - RunPagesProcess(t, - withListeners([]ListenSpec{httpListener}), - // PerDomain 1 request per 1s, so making 2 requests in a row should fail - withExtraArgument("req-domain-per-second", "1"), - ) - - rsp1, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp1.Body.Close() - require.Equal(t, http.StatusOK, rsp1.StatusCode, "group.gitlab-example.com") - - // make another request right away should fail - rsp2, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp2.Body.Close() - require.Equal(t, http.StatusTooManyRequests, rsp2.StatusCode, "group.gitlab-example.com without waiting") - - // wait for ratelimiter to clear - time.Sleep(time.Second) - rsp3, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp3.Body.Close() - require.Equal(t, http.StatusOK, rsp3.StatusCode, "group.gitlab-example.com after waiting 1s") - - // request another domain - rsp4, err := GetPageFromListener(t, httpListener, "CapitalGroup.gitlab-example.com", "/") - require.NoError(t, err) - defer rsp4.Body.Close() - require.Equal(t, http.StatusOK, rsp4.StatusCode, "CapitalGroup.gitlab-example.com for another domain") - -} - func TestKnownHostWithPortReturns200(t *testing.T) { RunPagesProcess(t) -- GitLab From 90bd58ebebdfad76d406fce0fc1ce00fa9f279c2 Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Tue, 14 Sep 2021 11:50:38 +1000 Subject: [PATCH 6/9] test: make test clearer --- test/acceptance/ratelimit_test.go | 50 ++++++++++++++----------------- 1 file changed, 
22 insertions(+), 28 deletions(-) diff --git a/test/acceptance/ratelimit_test.go b/test/acceptance/ratelimit_test.go index cd56f5dc6..e7fc59a7a 100644 --- a/test/acceptance/ratelimit_test.go +++ b/test/acceptance/ratelimit_test.go @@ -1,6 +1,7 @@ package acceptance_test import ( + "fmt" "net/http" "testing" "time" @@ -11,34 +12,27 @@ import ( func TestRateLimitMiddleware(t *testing.T) { RunPagesProcess(t, withListeners([]ListenSpec{httpListener}), - // allows a max of 2 tokens every 10ms -> 2rp 10ms - withExtraArgument("req-domain-per-second", "100ms"), - withExtraArgument("req-domain-bucket-size", "2"), + //refills 1 token every 50ms, bound by the burst/bucket size + withExtraArgument("req-domain-per-second", "50ms"), + // allows a max of 10 tokens at a time PER SECOND + withExtraArgument("req-domain-bucket-size", "10"), ) - rsp1, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp1.Body.Close() - - // make another request right away should fail - rsp2, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp2.Body.Close() - - // wait for ratelimiter to clear - time.Sleep(300 * time.Millisecond) - - rsp3, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp3.Body.Close() - - // request another domain - rsp4, err := GetPageFromListener(t, httpListener, "CapitalGroup.gitlab-example.com", "project/") - require.NoError(t, err) - defer rsp4.Body.Close() - - require.Equal(t, http.StatusOK, rsp1.StatusCode, "group.gitlab-example.com first request") - require.Equal(t, http.StatusTooManyRequests, rsp2.StatusCode, "group.gitlab-example.com without waiting") - require.Equal(t, http.StatusOK, rsp3.StatusCode, "rsp3 group.gitlab-example.com after waiting 1s") - require.Equal(t, http.StatusOK, rsp4.StatusCode, "CapitalGroup.gitlab-example.com for another domain") + for i := 0; 
i < 20; i++ { + rsp1, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") + require.NoError(t, err) + defer rsp1.Body.Close() + fmt.Printf("req: %d - status: %d\n", i, rsp1.StatusCode) + + // every ~10th request should fail + if (i+1)%10 == 0 { + require.Equal(t, http.StatusTooManyRequests, rsp1.StatusCode, "group.gitlab-example.com request: %d failed", i) + time.Sleep(500 * time.Millisecond) + continue + } + + require.Equal(t, http.StatusOK, rsp1.StatusCode, "group.gitlab-example.com request: %d failed", i) + // sleep almost close to req-domain-per-second + time.Sleep(49 * time.Millisecond) + } } -- GitLab From d2c7e3aac5cf3634a5fa24f250c94841e31a157f Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Tue, 14 Sep 2021 15:54:41 +1000 Subject: [PATCH 7/9] chore: move things around and cleanup --- app.go | 8 +- internal/config/config.go | 50 ++++++------- internal/config/flags.go | 68 ++++++++--------- .../middleware.go} | 19 +++-- internal/ratelimiter/ratelimit.go | 74 +++++++++---------- internal/ratelimiter/ratelimit_test.go | 32 ++++---- test/acceptance/ratelimit_test.go | 23 +++--- test/acceptance/serving_test.go | 4 +- 8 files changed, 139 insertions(+), 139 deletions(-) rename internal/{middleware/ratelimit.go => ratelimiter/middleware.go} (64%) diff --git a/app.go b/app.go index d10c7c5af..a6e8e6ffb 100644 --- a/app.go +++ b/app.go @@ -339,11 +339,11 @@ func (a *theApp) buildHandlerPipeline() (http.Handler, error) { metricsMiddleware := labmetrics.NewHandlerFactory(labmetrics.WithNamespace("gitlab_pages")) handler = metricsMiddleware(handler) - if !a.config.General.DisableRateLimiter { - handler = middleware.DomainRateLimiter( + if a.config.General.EnableRateLimiter { + handler = ratelimiter.DomainRateLimiter( ratelimiter.New( - ratelimiter.WithDomainRatePerSecond(a.config.General.RateLimitPerDomain), - ratelimiter.WithDomainBurstPerSecond(a.config.General.RateLimitMax), + 
ratelimiter.WithPerDomainFrequency(a.config.General.RateLimitPerDomainFrequency), + ratelimiter.WithPerDomainBurstSize(a.config.General.RateLimitPerDomainBurstSize), ), )(handler) } diff --git a/internal/config/config.go b/internal/config/config.go index ee1891491..d2ed2363c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,17 +45,17 @@ type Config struct { // General groups settings that are general to GitLab Pages and can not // be categorized under other head. type General struct { - Domain string - DisableRateLimiter bool - RateLimitPerDomain time.Duration - RateLimitMax int - MaxConns int - MetricsAddress string - RedirectHTTP bool - RootCertificate []byte - RootDir string - RootKey []byte - StatusPath string + Domain string + EnableRateLimiter bool + RateLimitPerDomainFrequency time.Duration + RateLimitPerDomainBurstSize int + MaxConns int + MetricsAddress string + RedirectHTTP bool + RootCertificate []byte + RootDir string + RootKey []byte + StatusPath string DisableCrossOriginRequests bool InsecureCiphers bool @@ -182,20 +182,20 @@ func setGitLabAPISecretKey(secretFile string, config *Config) error { func loadConfig() (*Config, error) { config := &Config{ General: General{ - Domain: strings.ToLower(*pagesDomain), - DisableRateLimiter: *disableRateLimiter, - RateLimitPerDomain: *reqDomainPerSecond, - RateLimitMax: *reqDomainBucketSize, - MaxConns: *maxConns, - MetricsAddress: *metricsAddress, - RedirectHTTP: *redirectHTTP, - RootDir: *pagesRoot, - StatusPath: *pagesStatus, - DisableCrossOriginRequests: *disableCrossOriginRequests, - InsecureCiphers: *insecureCiphers, - PropagateCorrelationID: *propagateCorrelationID, - CustomHeaders: header.Split(), - ShowVersion: *showVersion, + Domain: strings.ToLower(*pagesDomain), + EnableRateLimiter: *enableRateLimiter, + RateLimitPerDomainFrequency: *rateLimitPerDomain, + RateLimitPerDomainBurstSize: *rateLimitPerDomainBurstSize, + MaxConns: *maxConns, + MetricsAddress: *metricsAddress, + 
RedirectHTTP: *redirectHTTP, + RootDir: *pagesRoot, + StatusPath: *pagesStatus, + DisableCrossOriginRequests: *disableCrossOriginRequests, + InsecureCiphers: *insecureCiphers, + PropagateCorrelationID: *propagateCorrelationID, + CustomHeaders: header.Split(), + ShowVersion: *showVersion, }, GitLab: GitLab{ ClientHTTPTimeout: *gitlabClientHTTPTimeout, diff --git a/internal/config/flags.go b/internal/config/flags.go index 5385f8a9b..7e2781be9 100644 --- a/internal/config/flags.go +++ b/internal/config/flags.go @@ -11,40 +11,40 @@ import ( ) var ( - pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") - pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") - redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") - disableRateLimiter = flag.Bool("disable-rate-limiter", false, "Disable in-built rate limiter") - reqDomainPerSecond = flag.Duration("req-domain-per-second", ratelimiter.DefaultRatePerDomainPerSecond, "Requests per domain limit per second") - reqDomainBucketSize = flag.Int("req-domain-bucket-size", ratelimiter.DefaultPerDomainMaxBurstPerSecond, "Bucket size = max number of tokens held by the limiter") - _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") - pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") - pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") - artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") - artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") - pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") - metricsAddress = flag.String("metrics-address", "", "The address to listen on for 
metrics requests") - sentryDSN = flag.String("sentry-dsn", "", "The address for sending sentry crash reporting to") - sentryEnvironment = flag.String("sentry-environment", "", "The environment for sentry crash reporting") - daemonUID = flag.Uint("daemon-uid", 0, "Drop privileges to this user") - daemonGID = flag.Uint("daemon-gid", 0, "Drop privileges to this group") - _ = flag.Bool("daemon-enable-jail", false, "DEPRECATED and ignored, will be removed in 15.0") - _ = flag.Bool("daemon-inplace-chroot", false, "DEPRECATED and ignored, will be removed in 15.0") // TODO: https://gitlab.com/gitlab-org/gitlab-pages/-/issues/599 - propagateCorrelationID = flag.Bool("propagate-correlation-id", false, "Reuse existing Correlation-ID from the incoming request header `X-Request-ID` if present") - logFormat = flag.String("log-format", "json", "The log output format: 'text' or 'json'") - logVerbose = flag.Bool("log-verbose", false, "Verbose logging") - secret = flag.String("auth-secret", "", "Cookie store hash key, should be at least 32 bytes long") - publicGitLabServer = flag.String("gitlab-server", "", "Public GitLab server, for example https://www.gitlab.com") - internalGitLabServer = flag.String("internal-gitlab-server", "", "Internal GitLab server used for API requests, useful if you want to send that traffic over an internal load balancer, example value https://gitlab.example.internal (defaults to value of gitlab-server)") - gitLabAPISecretKey = flag.String("api-secret-key", "", "File with secret key used to authenticate with the GitLab API") - gitlabClientHTTPTimeout = flag.Duration("gitlab-client-http-timeout", 10*time.Second, "GitLab API HTTP client connection timeout in seconds (default: 10s)") - gitlabClientJWTExpiry = flag.Duration("gitlab-client-jwt-expiry", 30*time.Second, "JWT Token expiry time in seconds (default: 30s)") - gitlabCacheExpiry = flag.Duration("gitlab-cache-expiry", 10*time.Minute, "The maximum time a domain's configuration is stored in the cache") - 
gitlabCacheRefresh = flag.Duration("gitlab-cache-refresh", time.Minute, "The interval at which a domain's configuration is set to be due to refresh") - gitlabCacheCleanup = flag.Duration("gitlab-cache-cleanup", time.Minute, "The interval at which expired items are removed from the cache") - gitlabRetrievalTimeout = flag.Duration("gitlab-retrieval-timeout", 30*time.Second, "The maximum time to wait for a response from the GitLab API per request") - gitlabRetrievalInterval = flag.Duration("gitlab-retrieval-interval", time.Second, "The interval to wait before retrying to resolve a domain's configuration via the GitLab API") - gitlabRetrievalRetries = flag.Int("gitlab-retrieval-retries", 3, "The maximum number of times to retry to resolve a domain's configuration via the API") + pagesRootCert = flag.String("root-cert", "", "The default path to file certificate to serve static pages") + pagesRootKey = flag.String("root-key", "", "The default path to file certificate to serve static pages") + redirectHTTP = flag.Bool("redirect-http", false, "Redirect pages from HTTP to HTTPS") + enableRateLimiter = flag.Bool("enable-rate-limiter", false, "Enable in-built rate limiter") + rateLimitPerDomain = flag.Duration("rate-limit-per-domain", ratelimiter.DefaultPerDomainFrequency, "Rate of allowed requests as duration. E.g. 25ms allows 1 request every 25ms or 40 requests per second") + rateLimitPerDomainBurstSize = flag.Int("rate-limit-per-domain-burst-size", ratelimiter.DefaultPerDomainBurstSize, "Maximum number of requests to allow in a rate-limit-per-domain period. E.g. 
setting value to 10 will allow up to 10 requests within 25ms, then only 1 req per 25ms will be allowed.") + _ = flag.Bool("use-http2", true, "DEPRECATED: HTTP2 is always enabled for pages") + pagesRoot = flag.String("pages-root", "shared/pages", "The directory where pages are stored") + pagesDomain = flag.String("pages-domain", "gitlab-example.com", "The domain to serve static pages") + artifactsServer = flag.String("artifacts-server", "", "API URL to proxy artifact requests to, e.g.: 'https://gitlab.com/api/v4'") + artifactsServerTimeout = flag.Int("artifacts-server-timeout", 10, "Timeout (in seconds) for a proxied request to the artifacts server") + pagesStatus = flag.String("pages-status", "", "The url path for a status page, e.g., /@status") + metricsAddress = flag.String("metrics-address", "", "The address to listen on for metrics requests") + sentryDSN = flag.String("sentry-dsn", "", "The address for sending sentry crash reporting to") + sentryEnvironment = flag.String("sentry-environment", "", "The environment for sentry crash reporting") + daemonUID = flag.Uint("daemon-uid", 0, "Drop privileges to this user") + daemonGID = flag.Uint("daemon-gid", 0, "Drop privileges to this group") + _ = flag.Bool("daemon-enable-jail", false, "DEPRECATED and ignored, will be removed in 15.0") + _ = flag.Bool("daemon-inplace-chroot", false, "DEPRECATED and ignored, will be removed in 15.0") // TODO: https://gitlab.com/gitlab-org/gitlab-pages/-/issues/599 + propagateCorrelationID = flag.Bool("propagate-correlation-id", false, "Reuse existing Correlation-ID from the incoming request header `X-Request-ID` if present") + logFormat = flag.String("log-format", "json", "The log output format: 'text' or 'json'") + logVerbose = flag.Bool("log-verbose", false, "Verbose logging") + secret = flag.String("auth-secret", "", "Cookie store hash key, should be at least 32 bytes long") + publicGitLabServer = flag.String("gitlab-server", "", "Public GitLab server, for example 
https://www.gitlab.com") + internalGitLabServer = flag.String("internal-gitlab-server", "", "Internal GitLab server used for API requests, useful if you want to send that traffic over an internal load balancer, example value https://gitlab.example.internal (defaults to value of gitlab-server)") + gitLabAPISecretKey = flag.String("api-secret-key", "", "File with secret key used to authenticate with the GitLab API") + gitlabClientHTTPTimeout = flag.Duration("gitlab-client-http-timeout", 10*time.Second, "GitLab API HTTP client connection timeout in seconds (default: 10s)") + gitlabClientJWTExpiry = flag.Duration("gitlab-client-jwt-expiry", 30*time.Second, "JWT Token expiry time in seconds (default: 30s)") + gitlabCacheExpiry = flag.Duration("gitlab-cache-expiry", 10*time.Minute, "The maximum time a domain's configuration is stored in the cache") + gitlabCacheRefresh = flag.Duration("gitlab-cache-refresh", time.Minute, "The interval at which a domain's configuration is set to be due to refresh") + gitlabCacheCleanup = flag.Duration("gitlab-cache-cleanup", time.Minute, "The interval at which expired items are removed from the cache") + gitlabRetrievalTimeout = flag.Duration("gitlab-retrieval-timeout", 30*time.Second, "The maximum time to wait for a response from the GitLab API per request") + gitlabRetrievalInterval = flag.Duration("gitlab-retrieval-interval", time.Second, "The interval to wait before retrying to resolve a domain's configuration via the GitLab API") + gitlabRetrievalRetries = flag.Int("gitlab-retrieval-retries", 3, "The maximum number of times to retry to resolve a domain's configuration via the API") _ = flag.String("domain-config-source", "gitlab", "DEPRECATED and has not affect, see https://gitlab.com/gitlab-org/gitlab-pages/-/merge_requests/541") enableDisk = flag.Bool("enable-disk", true, "Enable disk access, shall be disabled in environments where shared disk storage isn't available") diff --git a/internal/middleware/ratelimit.go 
b/internal/ratelimiter/middleware.go similarity index 64% rename from internal/middleware/ratelimit.go rename to internal/ratelimiter/middleware.go index 954656ebd..f78ecd99c 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/ratelimiter/middleware.go @@ -1,19 +1,19 @@ -package middleware +package ratelimiter import ( + "net" "net/http" "gitlab.com/gitlab-org/gitlab-pages/internal/httperrors" - "gitlab.com/gitlab-org/gitlab-pages/internal/ratelimiter" - "gitlab.com/gitlab-org/gitlab-pages/internal/request" ) // DomainRateLimiter middleware ensures that the requested domain can be served by the current // rate limit. See -rate-limiter -func DomainRateLimiter(rl *ratelimiter.RateLimiter) func(http.Handler) http.Handler { +func DomainRateLimiter(rl *RateLimiter) func(http.Handler) http.Handler { return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - host := request.GetHostWithoutPort(r) + host := getHost(r) + if host != "127.0.0.1" && !rl.DomainAllowed(host) { httperrors.Serve429(w) return @@ -23,3 +23,12 @@ func DomainRateLimiter(rl *ratelimiter.RateLimiter) func(http.Handler) http.Hand }) } } + +func getHost(r *http.Request) string { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + + return host +} diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go index 3ec80f74a..4494e6386 100644 --- a/internal/ratelimiter/ratelimit.go +++ b/internal/ratelimiter/ratelimit.go @@ -1,7 +1,6 @@ package ratelimiter import ( - "fmt" "sync" "time" @@ -15,15 +14,13 @@ const ( // DefaultMaxTimePerDomain is the maximum time to keep a domain in the rate limiter map DefaultMaxTimePerDomain = 30 * time.Second - //example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html - - // DefaultRatePerDomainPerSecond the maximum number of requests per second to be allowed per domain - // 1 request every 
25ms = 40 rps - DefaultRatePerDomainPerSecond = 25 * time.Millisecond - // DefaultPerDomainMaxBurstPerSecond is the maximum burst allowed per rate limiter - // 40 items in the bucket is the max - // so if there are 40 rquests in 25 milliseconds they will succeed, but request 41st will fail - DefaultPerDomainMaxBurstPerSecond = 40 + // DefaultPerDomainFrequency the maximum number of requests per second to be allowed per domain. + // The default value of 25ms equals 1 request every 25ms -> 40 rps + DefaultPerDomainFrequency = 25 * time.Millisecond + // DefaultPerDomainBurstSize is the maximum burst allowed per rate limiter + // E.g. The first 40 requests within 25ms will succeed, but the 41st will fail until the next + // refill occurs at DefaultPerDomainFrequency, allowing only 1 request per rate frequency. + DefaultPerDomainBurstSize = 40 ) type counter struct { @@ -34,27 +31,32 @@ type counter struct { // Option function to configure a RateLimiter type Option func(*RateLimiter) +// RateLimiter holds a map ot domain names with counters that enable rate limiting per domain. +// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per domain entry. +// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html +// Cleanup runs every cleanupTimer iteration over all domains and removing them if +// the time since counter.lastSeen is greater than the domainMaxTTL. type RateLimiter struct { - now func() time.Time - cleanupTimer *time.Ticker - maxTimePerDomain time.Duration - domainRatePerSecond time.Duration - perDomainBurstPerSecond int - domainMux *sync.RWMutex + now func() time.Time + cleanupTimer *time.Ticker + domainMaxTTL time.Duration + perDomainFrequency time.Duration + perDomainBurstSize int + domainMux *sync.RWMutex // TODO: this could be an LRU cache like what we do in the zip VFS instead of cleaning manually ? 
perDomain map[string]*counter } -// New creates a new RateLimiter with default values +// New creates a new RateLimiter with default values that can be configured via Option functions func New(opts ...Option) *RateLimiter { rl := &RateLimiter{ - now: time.Now, - cleanupTimer: time.NewTicker(DefaultCleanupInterval), - maxTimePerDomain: DefaultMaxTimePerDomain, - domainRatePerSecond: DefaultRatePerDomainPerSecond, - perDomainBurstPerSecond: DefaultPerDomainMaxBurstPerSecond, - domainMux: &sync.RWMutex{}, - perDomain: make(map[string]*counter), + now: time.Now, + cleanupTimer: time.NewTicker(DefaultCleanupInterval), + domainMaxTTL: DefaultMaxTimePerDomain, + perDomainFrequency: DefaultPerDomainFrequency, + perDomainBurstSize: DefaultPerDomainBurstSize, + domainMux: &sync.RWMutex{}, + perDomain: make(map[string]*counter), } for _, opt := range opts { @@ -66,27 +68,31 @@ func New(opts ...Option) *RateLimiter { return rl } +// WithNow replaces the RateLimiter now function func WithNow(now func() time.Time) Option { return func(rl *RateLimiter) { rl.now = now } } +// WithCleanupInterval replaces the RateLimiter cleanup interval func WithCleanupInterval(d time.Duration) Option { return func(rl *RateLimiter) { rl.cleanupTimer.Reset(d) } } -func WithDomainRatePerSecond(d time.Duration) Option { +// WithPerDomainFrequency allows configuring perDomain frequency for the RateLimiter +func WithPerDomainFrequency(d time.Duration) Option { return func(rl *RateLimiter) { - rl.domainRatePerSecond = d + rl.perDomainFrequency = d } } -func WithDomainBurstPerSecond(burst int) Option { +// WithPerDomainBurstSize configures burst per domain for the RateLimiter +func WithPerDomainBurstSize(burst int) Option { return func(rl *RateLimiter) { - rl.perDomainBurstPerSecond = burst + rl.perDomainBurstSize = burst } } @@ -100,7 +106,7 @@ func (rl *RateLimiter) getDomainCounter(domain string) *counter { newCounter := &counter{ lastSeen: rl.now(), // the first argument is the number of requests per 
second that will be allowed, - limiter: rate.NewLimiter(rate.Every(rl.domainRatePerSecond), rl.perDomainBurstPerSecond), + limiter: rate.NewLimiter(rate.Every(rl.perDomainFrequency), rl.perDomainBurstSize), } rl.perDomain[domain] = newCounter @@ -114,14 +120,8 @@ func (rl *RateLimiter) getDomainCounter(domain string) *counter { // DomainAllowed checks that the requested domain can be accessed within // the maxCountPerDomain in the given domainWindow. func (rl *RateLimiter) DomainAllowed(domain string) (res bool) { - defer func() { - fmt.Printf("was domain: %q allowed? - %t\n", domain, res) - }() - counter := rl.getDomainCounter(domain) - counter.limiter.Reserve() - fmt.Printf("COUNTER DETAILS? now: %s :limit: %f burst: %d\n", rl.now(), counter.limiter.Limit(), counter.limiter.Burst()) - counter.limiter.Burst() + // TODO: we could use Wait(ctx) if we want to moderate the request rate rather than denying requests return counter.limiter.AllowN(rl.now(), 1) } @@ -133,7 +133,7 @@ func (rl *RateLimiter) cleanup() { rl.domainMux.Lock() for domain, counter := range rl.perDomain { - if time.Since(counter.lastSeen) > rl.maxTimePerDomain { + if time.Since(counter.lastSeen) > rl.domainMaxTTL { delete(rl.perDomain, domain) } } diff --git a/internal/ratelimiter/ratelimit_test.go b/internal/ratelimiter/ratelimit_test.go index 8a00c1b47..73ba9e030 100644 --- a/internal/ratelimiter/ratelimit_test.go +++ b/internal/ratelimiter/ratelimit_test.go @@ -1,10 +1,10 @@ package ratelimiter import ( - "fmt" "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -14,7 +14,6 @@ var ( ) func mockNow() time.Time { - validTime = validTime.Add(time.Millisecond) return validTime } @@ -56,8 +55,8 @@ func TestDomainAllowed(t *testing.T) { t.Run(tn, func(t *testing.T) { rl := New( WithNow(mockNow), - WithDomainRatePerSecond(tc.domainRate), - WithDomainBurstPerSecond(tc.perDomainBurstPerSecond), + WithPerDomainFrequency(tc.domainRate), + 
WithPerDomainBurstSize(tc.perDomainBurstPerSecond), ) for i := 0; i < tc.reqNum; i++ { @@ -75,32 +74,27 @@ func TestDomainAllowed(t *testing.T) { func TestDomainAllowedWitSleeps(t *testing.T) { rate := 10 * time.Millisecond rl := New( - WithNow(mockNow), - WithDomainRatePerSecond(rate), - WithDomainBurstPerSecond(1), + WithPerDomainFrequency(rate), + WithPerDomainBurstSize(1), ) domain := "test.gitlab.io" - t.Run("one request every millisecond with burst 1", func(t *testing.T) { + t.Run("one request every 10ms with burst 1", func(t *testing.T) { for i := 0; i < 10; i++ { got := rl.DomainAllowed(domain) - require.Truef(t, got, "expected true for request no. %d", i+1) + assert.Truef(t, got, "expected true for request no. %d", i+1) time.Sleep(rate) } }) t.Run("requests start failing after reaching burst", func(t *testing.T) { - //now := mockNow() - for i := 0; i < 10; i++ { + for i := 0; i < 5; i++ { got := rl.DomainAllowed(domain + ".diff") - fmt.Printf("for:%d got: %t\n", i, got) - //require.True(t, true) - //if i < 2 { - require.Truef(t, got, "expected true for request no. %d", i) - //} else { - // require.False(t, got, "expected false for request no. %d", i) - //} - time.Sleep(time.Nanosecond) + if i < 1 { + require.Truef(t, got, "expected true for request no. %d", i) + } else { + require.False(t, got, "expected false for request no. 
%d", i) + } } }) } diff --git a/test/acceptance/ratelimit_test.go b/test/acceptance/ratelimit_test.go index e7fc59a7a..04518fe5e 100644 --- a/test/acceptance/ratelimit_test.go +++ b/test/acceptance/ratelimit_test.go @@ -1,7 +1,6 @@ package acceptance_test import ( - "fmt" "net/http" "testing" "time" @@ -12,27 +11,27 @@ import ( func TestRateLimitMiddleware(t *testing.T) { RunPagesProcess(t, withListeners([]ListenSpec{httpListener}), - //refills 1 token every 50ms, bound by the burst/bucket size - withExtraArgument("req-domain-per-second", "50ms"), - // allows a max of 10 tokens at a time PER SECOND - withExtraArgument("req-domain-bucket-size", "10"), + withExtraArgument("enable-rate-limiter", "true"), + //refills 1 token every 10ms, bound by the burst/bucket size + withExtraArgument("rate-limit-per-domain", "10ms"), + // allows a max of 1 token at a time PER instance of time + withExtraArgument("rate-limit-per-domain-burst-size", "1"), ) for i := 0; i < 20; i++ { rsp1, err := GetPageFromListener(t, httpListener, "group.gitlab-example.com", "project/") require.NoError(t, err) - defer rsp1.Body.Close() - fmt.Printf("req: %d - status: %d\n", i, rsp1.StatusCode) + rsp1.Body.Close() - // every ~10th request should fail - if (i+1)%10 == 0 { + // every other request should fail + if i%2 != 0 { require.Equal(t, http.StatusTooManyRequests, rsp1.StatusCode, "group.gitlab-example.com request: %d failed", i) - time.Sleep(500 * time.Millisecond) + // wait for another token to become available + time.Sleep(10 * time.Millisecond) continue } require.Equal(t, http.StatusOK, rsp1.StatusCode, "group.gitlab-example.com request: %d failed", i) - // sleep almost close to req-domain-per-second - time.Sleep(49 * time.Millisecond) + time.Sleep(time.Millisecond) } } diff --git a/test/acceptance/serving_test.go b/test/acceptance/serving_test.go index dd6e92868..c6a7d3ef5 100644 --- a/test/acceptance/serving_test.go +++ b/test/acceptance/serving_test.go @@ -180,9 +180,7 @@ func 
TestCORSWhenDisabled(t *testing.T) { } func TestCORSAllowsMethod(t *testing.T) { - RunPagesProcess(t, - withExtraArgument("disable-rate-limiter", "true"), - ) + RunPagesProcess(t) tests := []struct { name string -- GitLab From 9d8e2ec112f0d2b33705511bb9aed5d33ab1938e Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Wed, 15 Sep 2021 12:21:32 +1000 Subject: [PATCH 8/9] feat: use lru cache instead --- internal/ratelimiter/lru_cache.go | 64 +++++++++++++++++++++++++++++ internal/ratelimiter/ratelimit.go | 67 +++++++++++-------------------- 2 files changed, 88 insertions(+), 43 deletions(-) create mode 100644 internal/ratelimiter/lru_cache.go diff --git a/internal/ratelimiter/lru_cache.go b/internal/ratelimiter/lru_cache.go new file mode 100644 index 000000000..34c66c811 --- /dev/null +++ b/internal/ratelimiter/lru_cache.go @@ -0,0 +1,64 @@ +package ratelimiter + +import ( + "time" + + "github.com/karlseguin/ccache/v2" +) + +// lruCacheGetsPerPromote is the number of gets after which an item is promoted +// it is taken arbitrarily as a sane value indicating that the item +// was frequently picked +// promotion moves the item to the front of the LRU list +const lruCacheGetsPerPromote = 64 + +// lruCacheItemsToPruneDiv is a value that indicates how many items +// need to be pruned when the cache exceeds its max size, this prunes 1/16 of items +const lruCacheItemsToPruneDiv = 16 + +type lruCache struct { + op string + duration time.Duration + cache *ccache.Cache +} + +func newLruCache(op string, maxEntries int64, duration time.Duration) *lruCache { + configuration := ccache.Configure() + configuration.MaxSize(maxEntries) + configuration.ItemsToPrune(uint32(maxEntries) / lruCacheItemsToPruneDiv) + configuration.GetsPerPromote(lruCacheGetsPerPromote) // if item gets requested frequently promote it + configuration.OnDelete(func(*ccache.Item) { + // TODO: add metrics + //metrics.ZipCachedEntries.WithLabelValues(op).Dec() + }) + + return &lruCache{ + op: op, + cache: ccache.New(configuration), + duration: 
duration, + } +} + +func (c *lruCache) findOrFetch(cacheNamespace, key string, fetchFn func() (interface{}, error)) (interface{}, error) { + item := c.cache.Get(cacheNamespace + key) + + if item != nil && !item.Expired() { + // TODO: add metrics + //metrics.ZipCacheRequests.WithLabelValues(c.op, "hit").Inc() + return item.Value(), nil + } + + value, err := fetchFn() + if err != nil { + // TODO: add metrics + //metrics.ZipCacheRequests.WithLabelValues(c.op, "error").Inc() + return nil, err + } + + // TODO: add metrics + //metrics.ZipCacheRequests.WithLabelValues(c.op, "miss").Inc() + //metrics.ZipCachedEntries.WithLabelValues(c.op).Inc() + + c.cache.Set(cacheNamespace+key, value, c.duration) + return value, nil +} diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go index 4494e6386..63342cea7 100644 --- a/internal/ratelimiter/ratelimit.go +++ b/internal/ratelimiter/ratelimit.go @@ -1,7 +1,6 @@ package ratelimiter import ( - "sync" "time" "gitlab.com/gitlab-org/labkit/log" @@ -21,6 +20,11 @@ const ( // E.g. The first 40 requests within 25ms will succeed, but the 41st will fail until the next // refill occurs at DefaultPerDomainFrequency, allowing only 1 request per rate frequency. DefaultPerDomainBurstSize = 40 + + // avg of ~18,000 unique domains per hour + // https://log.gprd.gitlab.net/app/lens#/edit/3c45a610-15c9-11ec-a012-eb2e5674cacf?_g=h@e78830b + defaultDomainsItems = 20000 + defaultDomainsExpirationInterval = time.Hour ) type counter struct { @@ -42,9 +46,9 @@ type RateLimiter struct { domainMaxTTL time.Duration perDomainFrequency time.Duration perDomainBurstSize int - domainMux *sync.RWMutex - // TODO: this could be an LRU cache like what we do in the zip VFS instead of cleaning manually ? 
- perDomain map[string]*counter + //domainMux *sync.RWMutex + domainsCache *lruCache + // TODO: add sourceIPCache } // New creates a new RateLimiter with default values that can be configured via Option functions @@ -55,16 +59,14 @@ func New(opts ...Option) *RateLimiter { domainMaxTTL: DefaultMaxTimePerDomain, perDomainFrequency: DefaultPerDomainFrequency, perDomainBurstSize: DefaultPerDomainBurstSize, - domainMux: &sync.RWMutex{}, - perDomain: make(map[string]*counter), + //domainMux: &sync.RWMutex{}, + domainsCache: newLruCache("domains", defaultDomainsItems, defaultDomainsExpirationInterval), } for _, opt := range opts { opt(rl) } - go rl.cleanup() - return rl } @@ -96,48 +98,27 @@ func WithPerDomainBurstSize(burst int) Option { } } -func (rl *RateLimiter) getDomainCounter(domain string) *counter { - rl.domainMux.Lock() - defer rl.domainMux.Unlock() - - // TODO: add metrics - currentCounter, ok := rl.perDomain[domain] - if !ok { - newCounter := &counter{ - lastSeen: rl.now(), - // the first argument is the number of requests per second that will be allowed, - limiter: rate.NewLimiter(rate.Every(rl.perDomainFrequency), rl.perDomainBurstSize), - } - - rl.perDomain[domain] = newCounter - return newCounter +func (rl *RateLimiter) getDomainCounter(domain string) (*rate.Limiter, error) { + limiterI, err := rl.domainsCache.findOrFetch(domain, domain, func() (interface{}, error) { + return rate.NewLimiter(rate.Every(rl.perDomainFrequency), rl.perDomainBurstSize), nil + }) + if err != nil { + return nil, err } - currentCounter.lastSeen = rl.now() - return currentCounter + return limiterI.(*rate.Limiter), nil } // DomainAllowed checks that the requested domain can be accessed within // the maxCountPerDomain in the given domainWindow. 
func (rl *RateLimiter) DomainAllowed(domain string) (res bool) { - counter := rl.getDomainCounter(domain) + limiter, err := rl.getDomainCounter(domain) + if err != nil { + // TODO: return and handle error appropriately + log.WithError(err).Warnf("failed to get rate limiter for domain: %s", domain) + return true + } // TODO: we could use Wait(ctx) if we want to moderate the request rate rather than denying requests - return counter.limiter.AllowN(rl.now(), 1) -} - -func (rl *RateLimiter) cleanup() { - select { - case t := <-rl.cleanupTimer.C: - log.WithField("cleanup", t).Debug("cleaning perDomain rate") - - rl.domainMux.Lock() - for domain, counter := range rl.perDomain { - if time.Since(counter.lastSeen) > rl.domainMaxTTL { - delete(rl.perDomain, domain) - } - } - rl.domainMux.Unlock() - default: - } + return limiter.AllowN(rl.now(), 1) } -- GitLab From a954b0d0832b3abff37e26380f2aaac8b1f65fb7 Mon Sep 17 00:00:00 2001 From: Jaime Martinez Date: Wed, 15 Sep 2021 12:24:13 +1000 Subject: [PATCH 9/9] chore: cleanup unused fields --- internal/ratelimiter/ratelimit.go | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/internal/ratelimiter/ratelimit.go b/internal/ratelimiter/ratelimit.go index 63342cea7..70a30294f 100644 --- a/internal/ratelimiter/ratelimit.go +++ b/internal/ratelimiter/ratelimit.go @@ -8,11 +8,6 @@ import ( ) const ( - // DefaultCleanupInterval is the time at which cleanup will run - DefaultCleanupInterval = 30 * time.Second - // DefaultMaxTimePerDomain is the maximum time to keep a domain in the rate limiter map - DefaultMaxTimePerDomain = 30 * time.Second - // DefaultPerDomainFrequency the maximum number of requests per second to be allowed per domain. 
// The default value of 25ms equals 1 request every 25ms -> 40 rps DefaultPerDomainFrequency = 25 * time.Millisecond @@ -27,11 +22,6 @@ const ( defaultDomainsExpirationInterval = time.Hour ) -type counter struct { - limiter *rate.Limiter - lastSeen time.Time -} - // Option function to configure a RateLimiter type Option func(*RateLimiter) @@ -42,8 +32,6 @@ type Option func(*RateLimiter) // the time since counter.lastSeen is greater than the domainMaxTTL. type RateLimiter struct { now func() time.Time - cleanupTimer *time.Ticker - domainMaxTTL time.Duration perDomainFrequency time.Duration perDomainBurstSize int //domainMux *sync.RWMutex @@ -55,8 +43,6 @@ type RateLimiter struct { func New(opts ...Option) *RateLimiter { rl := &RateLimiter{ now: time.Now, - cleanupTimer: time.NewTicker(DefaultCleanupInterval), - domainMaxTTL: DefaultMaxTimePerDomain, perDomainFrequency: DefaultPerDomainFrequency, perDomainBurstSize: DefaultPerDomainBurstSize, //domainMux: &sync.RWMutex{}, @@ -77,13 +63,6 @@ func WithNow(now func() time.Time) Option { } } -// WithCleanupInterval replaces the RateLimiter cleanup interval -func WithCleanupInterval(d time.Duration) Option { - return func(rl *RateLimiter) { - rl.cleanupTimer.Reset(d) - } -} - // WithPerDomainFrequency allows configuring perDomain frequency for the RateLimiter func WithPerDomainFrequency(d time.Duration) Option { return func(rl *RateLimiter) { -- GitLab