diff --git a/internal/cli/gitaly/serve.go b/internal/cli/gitaly/serve.go index 87ad0ebe309579fa096e1a96dc5f3030e563b1b5..038fd929703ea829e9a94eb9b4e4686a8487001e 100644 --- a/internal/cli/gitaly/serve.go +++ b/internal/cli/gitaly/serve.go @@ -314,7 +314,7 @@ func run(appCtx *cli.Command, cfg config.Cfg, logger log.Logger) error { ) prometheus.MustRegister(perRPCLimitHandler) - var packObjectLimit *limiter.AdaptiveLimit + var packObjectLimit, packObjectLimitUnauthenticated *limiter.AdaptiveLimit if cfg.PackObjectsLimiting.Adaptive { packObjectLimit = limiter.NewAdaptiveLimit("packObjects", limiter.AdaptiveSetting{ Initial: cfg.PackObjectsLimiting.InitialLimit, @@ -329,6 +329,20 @@ func run(appCtx *cli.Command, cfg config.Cfg, logger log.Logger) error { }) } + if cfg.PackObjectsLimiting.Unauthenticated.Adaptive { + packObjectLimitUnauthenticated = limiter.NewAdaptiveLimit("packObjectsUnauthenticated", limiter.AdaptiveSetting{ + Initial: cfg.PackObjectsLimiting.Unauthenticated.InitialLimit, + Max: cfg.PackObjectsLimiting.Unauthenticated.MaxLimit, + Min: cfg.PackObjectsLimiting.Unauthenticated.MinLimit, + BackoffFactor: limiter.DefaultBackoffFactor, + }) + adaptiveLimits = append(adaptiveLimits, packObjectLimitUnauthenticated) + } else { + packObjectLimitUnauthenticated = limiter.NewAdaptiveLimit("packObjectsUnauthenticated", limiter.AdaptiveSetting{ + Initial: cfg.PackObjectsLimiting.Unauthenticated.MaxConcurrency, + }) + } + packObjectsMonitor := limiter.NewPackObjectsConcurrencyMonitor( cfg.Prometheus.GRPCLatencyBuckets, ) @@ -338,6 +352,12 @@ func run(appCtx *cli.Command, cfg config.Cfg, logger log.Logger) error { cfg.PackObjectsLimiting.MaxQueueWait.Duration(), packObjectsMonitor, ) + packObjectsLimiterUnauthenticated := limiter.NewConcurrencyLimiter( + packObjectLimitUnauthenticated, + cfg.PackObjectsLimiting.Unauthenticated.MaxQueueLength, + cfg.PackObjectsLimiting.Unauthenticated.MaxQueueWait.Duration(), + packObjectsMonitor, + ) prometheus.MustRegister(packObjectsMonitor) // Enable the adaptive calculator only if there is any limit needed to be adaptive. 
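Reviewer note: the hunk above mirrors the existing pack-objects limiter wiring for a second, unauthenticated-only limiter driven by `cfg.PackObjectsLimiting.Unauthenticated`. As a minimal sketch of the configuration shape that exercises both paths, using the `config` types introduced later in this diff (all numeric values are invented for illustration):

```go
// Reviewer sketch, not part of the change. All numbers are invented.
package example

import "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config"

// examplePackObjectsLimits keeps a static limit for authenticated traffic and
// a tighter adaptive limit for unauthenticated traffic; serve.go above builds
// packObjectsLimiter from the embedded limits and
// packObjectsLimiterUnauthenticated from the Unauthenticated block.
func examplePackObjectsLimits() config.PackObjectsLimiting {
	return config.PackObjectsLimiting{
		ConcurrencyLimits: config.ConcurrencyLimits{
			MaxConcurrency: 200, // static limit, used because Adaptive is false
			MaxQueueLength: 200,
		},
		Unauthenticated: config.ConcurrencyLimits{
			Adaptive:     true, // dynamic limit within [MinLimit, MaxLimit]
			InitialLimit: 20,
			MinLimit:     5,
			MaxLimit:     50,
		},
	}
}
```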
@@ -692,28 +712,29 @@ func run(appCtx *cli.Command, cfg config.Cfg, logger log.Logger) error { } setup.RegisterAll(srv, &service.Dependencies{ - Logger: logger, - Cfg: cfg, - GitalyHookManager: hookManager, - TransactionManager: transactionManager, - StorageLocator: locator, - ClientPool: conns, - GitCmdFactory: gitCmdFactory, - CatfileCache: catfileCache, - DiskCache: diskCache, - PackObjectsCache: packObjectStreamCache, - PackObjectsLimiter: packObjectsLimiter, - RepositoryCounter: repoCounter, - UpdaterWithHooks: updaterWithHooks, - Node: node, - TransactionRegistry: txRegistry, - HousekeepingManager: housekeepingManager, - BackupSink: backupSink, - BackupLocator: backupLocator, - LocalRepositoryFactory: localrepoFactory, - BundleURIManager: bundleURIManager, - MigrationStateManager: migration.NewStateManager(&migrations), - ArchiveCache: archiveStreamCache, + Logger: logger, + Cfg: cfg, + GitalyHookManager: hookManager, + TransactionManager: transactionManager, + StorageLocator: locator, + ClientPool: conns, + GitCmdFactory: gitCmdFactory, + CatfileCache: catfileCache, + DiskCache: diskCache, + PackObjectsCache: packObjectStreamCache, + PackObjectsLimiter: packObjectsLimiter, + PackObjectsLimiterUnauthenticated: packObjectsLimiterUnauthenticated, + RepositoryCounter: repoCounter, + UpdaterWithHooks: updaterWithHooks, + Node: node, + TransactionRegistry: txRegistry, + HousekeepingManager: housekeepingManager, + BackupSink: backupSink, + BackupLocator: backupLocator, + LocalRepositoryFactory: localrepoFactory, + BundleURIManager: bundleURIManager, + MigrationStateManager: migration.NewStateManager(&migrations), + ArchiveCache: archiveStreamCache, }) b.RegisterStarter(starter.New(c, srv, logger)) } diff --git a/internal/featureflag/ff_unauthenticated_concurrency.go b/internal/featureflag/ff_unauthenticated_concurrency.go new file mode 100644 index 0000000000000000000000000000000000000000..d68a5c13d78ec831048f98f6ac8fdc1c1e4d66ec --- /dev/null +++ b/internal/featureflag/ff_unauthenticated_concurrency.go @@ -0,0 +1,10 @@ +package featureflag + +// LimitUnauthenticated allows the concurrency limiter to limit unauthenticated +// requests separately from authenticated requests. +var LimitUnauthenticated = NewFeatureFlag( + "limit_unauthenticated", + "v18.6.0", + "https://gitlab.com/gitlab-org/gitaly/-/issues/6955", + true, +) diff --git a/internal/gitaly/config/config.go b/internal/gitaly/config/config.go index 6652cdda5760646fe4ce3f1e479ad7113302e35f..1088c254998af720441bb80a7194c3baea72d242 100644 --- a/internal/gitaly/config/config.go +++ b/internal/gitaly/config/config.go @@ -525,8 +525,31 @@ type Logging struct { // Requests that come in after the maximum number of concurrent requests are in progress will wait // in a queue that is bounded by MaxQueueSize. type Concurrency struct { + ConcurrencyLimits // RPC is the name of the RPC to set concurrency limits for RPC string `json:"rpc" toml:"rpc"` + // Unauthenticated sets the limits for unauthenticated requests + Unauthenticated ConcurrencyLimits `json:"unauthenticated" toml:"unauthenticated"` +} + +// ConcurrencyLimits configures the static and adaptive concurrency limits for a class of requests +type ConcurrencyLimits struct { + // MaxPerRepo is the maximum number of concurrent calls for a given repository. This config is used only + // if Adaptive is false. + MaxPerRepo int `json:"max_per_repo" toml:"max_per_repo"` + // MaxConcurrency is the static maximum number of concurrent processes for a given key. This config + // is used only if Adaptive is false.
This field is provided for compatibility with pack-objects specific configuration + // and is treated the same as MaxPerRepo. + MaxConcurrency int `json:"max_concurrency,omitempty" toml:"max_concurrency,omitempty"` + // MaxQueueSize is the maximum number of requests in the queue waiting to be picked up + // after which subsequent requests will return with an error. + MaxQueueSize int `json:"max_queue_size" toml:"max_queue_size"` + // MaxQueueLength is the maximum length of the request queue. This field is provided for compatibility with + // pack-objects specific configuration and is treated the same as MaxQueueSize. + MaxQueueLength int `json:"max_queue_length,omitempty" toml:"max_queue_length,omitempty"` + // MaxQueueWait is the maximum time a request can remain in the concurrency queue + // waiting to be picked up by Gitaly + MaxQueueWait duration.Duration `json:"max_queue_wait" toml:"max_queue_wait"` // Adaptive determines the behavior of the concurrency limit. If set to true, the concurrency limit is dynamic // and starts at InitialLimit, then adjusts within the range [MinLimit, MaxLimit] based on current resource // usage. If set to false, the concurrency limit is static and is set to MaxPerRepo. @@ -537,33 +560,87 @@ type Concurrency struct { MaxLimit int `json:"max_limit,omitempty" toml:"max_limit,omitempty"` // MinLimit is the mini adaptive concurrency limit. MinLimit int `json:"min_limit,omitempty" toml:"min_limit,omitempty"` - // MaxPerRepo is the maximum number of concurrent calls for a given repository. This config is used only - // if Adaptive is false. - MaxPerRepo int `json:"max_per_repo" toml:"max_per_repo"` - // MaxQueueSize is the maximum number of requests in the queue waiting to be picked up - // after which subsequent requests will return with an error. - MaxQueueSize int `json:"max_queue_size" toml:"max_queue_size"` - // MaxQueueWait is the maximum time a request can remain in the concurrency queue - // waiting to be picked up by Gitaly - MaxQueueWait duration.Duration `json:"max_queue_wait" toml:"max_queue_wait"` } -// Validate runs validation on all fields and compose all found errors. -func (c Concurrency) Validate() error { - errs := cfgerror.New(). - Append(cfgerror.Comparable(c.MaxPerRepo).GreaterOrEqual(0), "max_per_repo"). - Append(cfgerror.Comparable(c.MaxQueueSize).GreaterThan(0), "max_queue_size"). 
- Append(cfgerror.Comparable(c.MaxQueueWait.Duration()).GreaterOrEqual(0), "max_queue_wait") +// validateWithQueueField validates the concurrency limits. preferQueueLength selects whether an unset +// queue bound is reported against max_queue_length (pack-objects style) or max_queue_size (per-RPC style). +func (cl ConcurrencyLimits) validateWithQueueField(preferQueueLength bool) error { + errs := cfgerror.New() - if c.Adaptive { + // Validate that MaxConcurrency and MaxPerRepo are not set to different values. + // Sanitize() will mirror one to the other, so they should have the same value if both are set. + if cl.MaxConcurrency != 0 && cl.MaxPerRepo != 0 && cl.MaxConcurrency != cl.MaxPerRepo { + errs = errs.Append( + cfgerror.NewValidationError( + errors.New("max_concurrency and max_per_repo cannot be set to different values"), + "max_concurrency", + "max_per_repo", + ), + ) + } + + // Validate that MaxQueueSize and MaxQueueLength are not set to different values. + // Sanitize() will mirror one to the other, so they should have the same value if both are set. + if cl.MaxQueueSize != 0 && cl.MaxQueueLength != 0 && cl.MaxQueueSize != cl.MaxQueueLength { + errs = errs.Append( + cfgerror.NewValidationError( + errors.New("max_queue_size and max_queue_length cannot be set to different values"), + "max_queue_size", + "max_queue_length", + ), + ) + } + + // Validate MaxConcurrency or MaxPerRepo (whichever is set). + if cl.MaxConcurrency != 0 { + errs = errs.Append(cfgerror.Comparable(cl.MaxConcurrency).GreaterOrEqual(0), "max_concurrency") + } else if cl.MaxPerRepo != 0 { + errs = errs.Append(cfgerror.Comparable(cl.MaxPerRepo).GreaterOrEqual(0), "max_per_repo") + } + + // Validate MaxQueueSize or MaxQueueLength (whichever is set). + if cl.MaxQueueSize != 0 { + errs = errs.Append(cfgerror.Comparable(cl.MaxQueueSize).GreaterThan(0), "max_queue_size") + } else if cl.MaxQueueLength != 0 { + errs = errs.Append(cfgerror.Comparable(cl.MaxQueueLength).GreaterThan(0), "max_queue_length") + } else { + // Neither is set; report the error against the field name preferred by the caller + // (max_queue_length for PackObjectsLimiting, max_queue_size for Concurrency). + if preferQueueLength { + errs = errs.Append(cfgerror.Comparable(cl.MaxQueueLength).GreaterThan(0), "max_queue_length") + } else { + errs = errs.Append(cfgerror.Comparable(cl.MaxQueueSize).GreaterThan(0), "max_queue_size") + } + } + + errs = errs.Append(cfgerror.Comparable(cl.MaxQueueWait.Duration()).GreaterOrEqual(0), "max_queue_wait") + + // Validate the adaptive limiting bounds. The relationships between MinLimit, InitialLimit, + // and MaxLimit must hold whether or not Adaptive is enabled (they all default to zero), so + // the checks run unconditionally rather than in duplicated branches. errs = errs. - Append(cfgerror.Comparable(c.MinLimit).GreaterThan(0), "min_limit"). - Append(cfgerror.Comparable(c.MaxLimit).GreaterOrEqual(c.InitialLimit), "max_limit"). - Append(cfgerror.Comparable(c.InitialLimit).GreaterOrEqual(c.MinLimit), "initial_limit") + Append(cfgerror.Comparable(cl.MinLimit).GreaterOrEqual(0), "min_limit"). + Append(cfgerror.Comparable(cl.MaxLimit).GreaterOrEqual(cl.InitialLimit), "max_limit"). + Append(cfgerror.Comparable(cl.InitialLimit).GreaterOrEqual(cl.MinLimit), "initial_limit") + return errs.AsError() } +// Validate runs validation on all fields and composes all found errors. +// Defaults to reporting errors against max_queue_size.
+func (cl ConcurrencyLimits) Validate() error { + return cl.validateWithQueueField(false) +} + +// Validate runs validation on all fields and composes all found errors. +func (c Concurrency) Validate() error { + return c.ConcurrencyLimits.validateWithQueueField(false) +} + // AdaptiveLimiting defines a set of global config for the adaptive limiter. This config customizes how the resource // watchers and calculator works. Specific limits for each RPC or pack-objects operation should be configured // individually using the Concurrency and PackObjectsLimiting structs respectively. @@ -590,36 +667,14 @@ func (c AdaptiveLimiting) Validate() error { // Requests that come in after the maximum number of concurrent pack objects // processes have been reached will wait. type PackObjectsLimiting struct { - // Adaptive determines the behavior of the concurrency limit. If set to true, the concurrency limit is dynamic - // and starts at InitialLimit, then adjusts within the range [MinLimit, MaxLimit] based on current resource - // usage. If set to false, the concurrency limit is static and is set to MaxConcurrency. - Adaptive bool `json:"adaptive,omitempty" toml:"adaptive,omitempty"` - // InitialLimit is the concurrency limit to start with. - InitialLimit int `json:"initial_limit,omitempty" toml:"initial_limit,omitempty"` - // MaxLimit is the minimum adaptive concurrency limit. - MaxLimit int `json:"max_limit,omitempty" toml:"max_limit,omitempty"` - // MinLimit is the mini adaptive concurrency limit. - MinLimit int `json:"min_limit,omitempty" toml:"min_limit,omitempty"` - // MaxConcurrency is the static maximum number of concurrent pack objects processes for a given key. This config - // is used only if Adaptive is false. - MaxConcurrency int `json:"max_concurrency,omitempty" toml:"max_concurrency,omitempty"` - // MaxQueueWait is the maximum time a request can remain in the concurrency queue - // waiting to be picked up by Gitaly. - MaxQueueWait duration.Duration `json:"max_queue_wait,omitempty" toml:"max_queue_wait,omitempty"` - // MaxQueueLength is the maximum length of the request queue - MaxQueueLength int `json:"max_queue_length,omitempty" toml:"max_queue_length,omitempty"` + ConcurrencyLimits + // Unauthenticated sets the limits for unauthenticated requests + Unauthenticated ConcurrencyLimits `json:"unauthenticated" toml:"unauthenticated"` } // Validate runs validation on all fields and compose all found errors. func (pol PackObjectsLimiting) Validate() error { - return cfgerror.New(). - Append(cfgerror.Comparable(pol.MaxConcurrency).GreaterOrEqual(0), "max_concurrency"). - Append(cfgerror.Comparable(pol.MaxQueueLength).GreaterThan(0), "max_queue_length"). - Append(cfgerror.Comparable(pol.MaxQueueWait.Duration()).GreaterOrEqual(0), "max_queue_wait"). - Append(cfgerror.Comparable(pol.MinLimit).GreaterOrEqual(0), "min_limit"). - Append(cfgerror.Comparable(pol.MaxLimit).GreaterOrEqual(pol.InitialLimit), "max_limit"). - Append(cfgerror.Comparable(pol.InitialLimit).GreaterOrEqual(pol.MinLimit), "initial_limit"). - AsError() + return pol.ConcurrencyLimits.validateWithQueueField(true) } // BackupConfig configures server-side and write-ahead log backups.
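Reviewer note: one consequence of the refactored validation worth calling out is that the legacy field pairs (max_per_repo/max_concurrency and max_queue_size/max_queue_length) are aliases, and Validate() rejects configurations where both names are set to conflicting values. A minimal sketch, assuming the config package is importable as in this diff; the literal values are made up:

```go
package example

import (
	"fmt"

	"gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config"
)

func demonstrateAliasValidation() {
	// Conflicting aliases: MaxPerRepo and MaxConcurrency disagree, so
	// Validate() reports an error rather than silently preferring one.
	conflicting := config.ConcurrencyLimits{
		MaxPerRepo:     10,
		MaxConcurrency: 20,
		MaxQueueSize:   100,
	}
	fmt.Println(conflicting.Validate()) // max_concurrency/max_per_repo mismatch error

	// Matching aliases validate cleanly; Sanitize() mirrors one into the
	// other when only one of them is set.
	matching := config.ConcurrencyLimits{
		MaxPerRepo:     10,
		MaxConcurrency: 10,
		MaxQueueSize:   100,
	}
	fmt.Println(matching.Validate()) // <nil>
}
```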
@@ -742,10 +797,12 @@ func defaultPackObjectsCacheConfig() StreamCacheConfig { func defaultPackObjectsLimiting() PackObjectsLimiting { return PackObjectsLimiting{ - MaxConcurrency: defaultPackObjectsLimitingConcurrency, - MaxQueueLength: defaultPackObjectsLimitingQueueSize, - // Requests can stay in the queue as long as they want - MaxQueueWait: 0, + ConcurrencyLimits: ConcurrencyLimits{ + // Requests can stay in the queue as long as they want + MaxQueueWait: 0, + MaxConcurrency: defaultPackObjectsLimitingConcurrency, + MaxQueueLength: defaultPackObjectsLimitingQueueSize, + }, } } @@ -909,8 +966,24 @@ func (cfg *Cfg) Sanitize() error { } } - if cfg.PackObjectsLimiting.MaxQueueLength == 0 { + // Mirror MaxConcurrency <-> MaxPerRepo for PackObjectsLimiting + if cfg.PackObjectsLimiting.MaxConcurrency != 0 && cfg.PackObjectsLimiting.MaxPerRepo == 0 { + cfg.PackObjectsLimiting.MaxPerRepo = cfg.PackObjectsLimiting.MaxConcurrency + } else if cfg.PackObjectsLimiting.MaxPerRepo != 0 && cfg.PackObjectsLimiting.MaxConcurrency == 0 { + cfg.PackObjectsLimiting.MaxConcurrency = cfg.PackObjectsLimiting.MaxPerRepo + } + + // Mirror MaxQueueLength <-> MaxQueueSize for PackObjectsLimiting + if cfg.PackObjectsLimiting.MaxQueueLength != 0 && cfg.PackObjectsLimiting.MaxQueueSize == 0 { + cfg.PackObjectsLimiting.MaxQueueSize = cfg.PackObjectsLimiting.MaxQueueLength + } else if cfg.PackObjectsLimiting.MaxQueueSize != 0 && cfg.PackObjectsLimiting.MaxQueueLength == 0 { + cfg.PackObjectsLimiting.MaxQueueLength = cfg.PackObjectsLimiting.MaxQueueSize + } + + // Set default if neither is set + if cfg.PackObjectsLimiting.MaxQueueLength == 0 && cfg.PackObjectsLimiting.MaxQueueSize == 0 { cfg.PackObjectsLimiting.MaxQueueLength = defaultPackObjectsLimitingQueueSize + cfg.PackObjectsLimiting.MaxQueueSize = defaultPackObjectsLimitingQueueSize } if cfg.ArchiveCache.Enabled { diff --git a/internal/gitaly/config/config_test.go b/internal/gitaly/config/config_test.go index 57516e83379f178bd73d969a7db7930cc084f7fe..31de1bb19707504fdbb643d3e0361106e8023588 100644 --- a/internal/gitaly/config/config_test.go +++ b/internal/gitaly/config/config_test.go @@ -1874,9 +1874,13 @@ func TestPackObjectsLimiting(t *testing.T) { max_queue_wait = "10s" `, expectedCfg: PackObjectsLimiting{ - MaxConcurrency: 20, - MaxQueueLength: 100, - MaxQueueWait: duration.Duration(10 * time.Second), + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxConcurrency: 20, + MaxQueueSize: 100, + MaxQueueLength: 100, + MaxQueueWait: duration.Duration(10 * time.Second), + }, }, }, { @@ -1887,9 +1891,13 @@ func TestPackObjectsLimiting(t *testing.T) { max_queue_wait = "1m" `, expectedCfg: PackObjectsLimiting{ - MaxConcurrency: 10, - MaxQueueLength: 100, - MaxQueueWait: duration.Duration(1 * time.Minute), + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 10, + MaxConcurrency: 10, + MaxQueueSize: 100, + MaxQueueLength: 100, + MaxQueueWait: duration.Duration(1 * time.Minute), + }, }, }, { @@ -1899,9 +1907,13 @@ func TestPackObjectsLimiting(t *testing.T) { max_queue_wait = "1m" `, expectedCfg: PackObjectsLimiting{ - MaxConcurrency: 10, - MaxQueueLength: 200, - MaxQueueWait: duration.Duration(1 * time.Minute), + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 10, + MaxConcurrency: 10, + MaxQueueSize: 200, + MaxQueueLength: 200, + MaxQueueWait: duration.Duration(1 * time.Minute), + }, }, }, } @@ -1927,18 +1939,21 @@ func TestPackObjectsLimiting_defaultPackObjectsLimiting(t *testing.T) { cfg := defaultPackObjectsLimiting() require.Equal(t, 
PackObjectsLimiting{ - MaxConcurrency: 200, - MaxQueueWait: 0, - MaxQueueLength: 200, + ConcurrencyLimits: ConcurrencyLimits{ + MaxQueueWait: 0, + MaxConcurrency: 200, + MaxQueueLength: 200, + }, }, cfg) + // Note: MaxPerRepo and MaxQueueSize are not set in the default; Sanitize() mirrors them + // from MaxConcurrency and MaxQueueLength. } func TestPackObjectsLimiting_Validate(t *testing.T) { t.Parallel() - require.NoError(t, PackObjectsLimiting{MaxConcurrency: 0, MaxQueueLength: 1}.Validate()) - require.NoError(t, PackObjectsLimiting{MaxConcurrency: 1, MaxQueueLength: 1}.Validate()) - require.NoError(t, PackObjectsLimiting{MaxConcurrency: 100, MaxQueueLength: 1}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxConcurrency: 0, MaxQueueLength: 1}}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxConcurrency: 1, MaxQueueLength: 1}}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxConcurrency: 100, MaxQueueLength: 1}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -1947,12 +1962,12 @@ func TestPackObjectsLimiting_Validate(t *testing.T) { "max_concurrency", ), }, - PackObjectsLimiting{MaxConcurrency: -1, MaxQueueLength: 1}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxConcurrency: -1, MaxQueueLength: 1}}.Validate(), ) - require.NoError(t, PackObjectsLimiting{Adaptive: true, InitialLimit: 0, MinLimit: 0, MaxLimit: 100, MaxQueueLength: 100}.Validate()) - require.NoError(t, PackObjectsLimiting{Adaptive: true, InitialLimit: 10, MinLimit: 0, MaxLimit: 100, MaxQueueLength: 100}.Validate()) - require.NoError(t, PackObjectsLimiting{Adaptive: true, InitialLimit: 100, MinLimit: 0, MaxLimit: 100, MaxQueueLength: 100}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 0, MinLimit: 0, MaxLimit: 100, MaxQueueLength: 100}}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 0, MaxLimit: 100, MaxQueueLength: 100}}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 100, MinLimit: 0, MaxLimit: 100, MaxQueueLength: 100}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -1961,7 +1976,7 @@ func TestPackObjectsLimiting_Validate(t *testing.T) { "initial_limit", ), }, - PackObjectsLimiting{Adaptive: true, InitialLimit: -1, MinLimit: 0, MaxLimit: 100, MaxQueueLength: 100}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: -1, MinLimit: 0, MaxLimit: 100, MaxQueueLength: 100}}.Validate(), ) require.Equal( t, @@ -1971,7 +1986,7 @@ func TestPackObjectsLimiting_Validate(t *testing.T) { "initial_limit", ), }, - PackObjectsLimiting{Adaptive: true, InitialLimit: 10, MinLimit: 11, MaxLimit: 100, MaxQueueLength: 100}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 11, MaxLimit: 100, MaxQueueLength: 100}}.Validate(), ) require.Equal( t, @@ -1981,7 +1996,7 @@ func TestPackObjectsLimiting_Validate(t *testing.T) { "max_limit", ), }, - PackObjectsLimiting{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: 3, MaxQueueLength: 100}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: 3, MaxQueueLength: 100}}.Validate(), ) require.Equal( t, @@ -1991,7 +2006,7 @@ func
TestPackObjectsLimiting_Validate(t *testing.T) { "min_limit", ), }, - PackObjectsLimiting{Adaptive: true, InitialLimit: 5, MinLimit: -1, MaxLimit: 99, MaxQueueLength: 100}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 5, MinLimit: -1, MaxLimit: 99, MaxQueueLength: 100}}.Validate(), ) require.Equal( t, @@ -2001,11 +2016,11 @@ func TestPackObjectsLimiting_Validate(t *testing.T) { "max_limit", ), }, - PackObjectsLimiting{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: -1, MaxQueueLength: 100}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: -1, MaxQueueLength: 100}}.Validate(), ) - require.NoError(t, PackObjectsLimiting{MaxQueueLength: 1}.Validate()) - require.NoError(t, PackObjectsLimiting{MaxQueueLength: 100}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxQueueLength: 1}}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxQueueLength: 100}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2014,7 +2029,7 @@ func TestPackObjectsLimiting_Validate(t *testing.T) { "max_queue_length", ), }, - PackObjectsLimiting{MaxQueueLength: 0}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxQueueLength: 0}}.Validate(), ) require.Equal( t, @@ -2024,10 +2039,10 @@ func TestPackObjectsLimiting_Validate(t *testing.T) { "max_queue_length", ), }, - PackObjectsLimiting{MaxQueueLength: -1}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxQueueLength: -1}}.Validate(), ) - require.NoError(t, PackObjectsLimiting{MaxQueueWait: duration.Duration(1), MaxQueueLength: 1}.Validate()) + require.NoError(t, PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxQueueWait: duration.Duration(1), MaxQueueLength: 1}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2036,7 +2051,7 @@ func TestPackObjectsLimiting_Validate(t *testing.T) { "max_queue_wait", ), }, - PackObjectsLimiting{MaxQueueWait: duration.Duration(-time.Minute), MaxQueueLength: 1}.Validate(), + PackObjectsLimiting{ConcurrencyLimits: ConcurrencyLimits{MaxQueueWait: duration.Duration(-time.Minute), MaxQueueLength: 1}}.Validate(), ) } @@ -2056,9 +2071,11 @@ func TestConcurrency(t *testing.T) { max_per_repo = 20 `, expectedCfg: []Concurrency{{ - RPC: "/gitaly.CommitService/ListCommitsByOid", - MaxPerRepo: 20, - MaxQueueSize: 500, + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxQueueSize: 500, + }, + RPC: "/gitaly.CommitService/ListCommitsByOid", }}, }, { @@ -2070,10 +2087,12 @@ func TestConcurrency(t *testing.T) { max_queue_wait = "10s" `, expectedCfg: []Concurrency{{ - RPC: "/gitaly.CommitService/ListCommitsByOid", - MaxPerRepo: 20, - MaxQueueSize: 100, - MaxQueueWait: duration.Duration(10 * time.Second), + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxQueueSize: 100, + MaxQueueWait: duration.Duration(10 * time.Second), + }, + RPC: "/gitaly.CommitService/ListCommitsByOid", }}, }, { @@ -2085,10 +2104,12 @@ func TestConcurrency(t *testing.T) { max_queue_wait = "1m" `, expectedCfg: []Concurrency{{ - RPC: "/gitaly.CommitService/ListCommitsByOid", - MaxPerRepo: 20, - MaxQueueSize: 100, - MaxQueueWait: duration.Duration(1 * time.Minute), + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxQueueSize: 100, + MaxQueueWait: duration.Duration(1 * time.Minute), + }, + RPC: "/gitaly.CommitService/ListCommitsByOid", }}, }, { @@ -2106,15 
+2127,19 @@ func TestConcurrency(t *testing.T) { `, expectedCfg: []Concurrency{ { - RPC: "/gitaly.CommitService/ListCommits", - MaxPerRepo: 20, - MaxQueueSize: 20, + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 20, + MaxQueueSize: 20, + }, + RPC: "/gitaly.CommitService/ListCommits", }, { - RPC: "/gitaly.CommitService/ListCommitsByOid", - MaxPerRepo: 30, - MaxQueueSize: 500, - MaxQueueWait: duration.Duration(10 * time.Second), + ConcurrencyLimits: ConcurrencyLimits{ + MaxPerRepo: 30, + MaxQueueSize: 500, + MaxQueueWait: duration.Duration(10 * time.Second), + }, + RPC: "/gitaly.CommitService/ListCommitsByOid", }, }, }, @@ -2130,13 +2155,15 @@ func TestConcurrency(t *testing.T) { initial_limit = 40 `, expectedCfg: []Concurrency{{ - RPC: "/gitaly.SmartHTTPService/PostUploadPack", - MaxQueueSize: 100, - MaxQueueWait: duration.Duration(1 * time.Minute), - Adaptive: true, - MinLimit: 10, - MaxLimit: 60, - InitialLimit: 40, + ConcurrencyLimits: ConcurrencyLimits{ + MaxQueueSize: 100, + MaxQueueWait: duration.Duration(1 * time.Minute), + Adaptive: true, + MinLimit: 10, + MaxLimit: 60, + InitialLimit: 40, + }, + RPC: "/gitaly.SmartHTTPService/PostUploadPack", }}, }, } @@ -2160,9 +2187,9 @@ func TestConcurrency(t *testing.T) { func TestConcurrency_Validate(t *testing.T) { t.Parallel() - require.NoError(t, Concurrency{MaxPerRepo: 0, MaxQueueSize: 1}.Validate()) - require.NoError(t, Concurrency{MaxPerRepo: 1, MaxQueueSize: 1}.Validate()) - require.NoError(t, Concurrency{MaxPerRepo: 100, MaxQueueSize: 100}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: 0, MaxQueueSize: 1}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: 1, MaxQueueSize: 1}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: 100, MaxQueueSize: 100}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2171,22 +2198,13 @@ func TestConcurrency_Validate(t *testing.T) { "max_per_repo", ), }, - Concurrency{MaxPerRepo: -1, MaxQueueSize: 1}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: -1, MaxQueueSize: 1}}.Validate(), ) - require.NoError(t, Concurrency{Adaptive: true, InitialLimit: 1, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}.Validate()) - require.NoError(t, Concurrency{Adaptive: true, InitialLimit: 10, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}.Validate()) - require.NoError(t, Concurrency{Adaptive: true, InitialLimit: 100, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}.Validate()) - require.Equal( - t, - cfgerror.ValidationErrors{ - cfgerror.NewValidationError( - fmt.Errorf("%w: 0 is not greater than 0", cfgerror.ErrNotInRange), - "min_limit", - ), - }, - Concurrency{Adaptive: true, InitialLimit: 0, MinLimit: 0, MaxLimit: 100, MaxQueueSize: 100}.Validate(), - ) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 1, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 100, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 0, MinLimit: 0, MaxLimit: 100, MaxQueueSize: 100}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ 
-2195,7 +2213,7 @@ func TestConcurrency_Validate(t *testing.T) { "initial_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: -1, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: -1, MinLimit: 1, MaxLimit: 100, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, @@ -2205,7 +2223,7 @@ func TestConcurrency_Validate(t *testing.T) { "initial_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 10, MinLimit: 11, MaxLimit: 100, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 11, MaxLimit: 100, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, @@ -2215,17 +2233,17 @@ func TestConcurrency_Validate(t *testing.T) { "max_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: 3, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: 3, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, cfgerror.ValidationErrors{ cfgerror.NewValidationError( - fmt.Errorf("%w: -1 is not greater than 0", cfgerror.ErrNotInRange), + fmt.Errorf("%w: -1 is not greater than or equal to 0", cfgerror.ErrNotInRange), "min_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 5, MinLimit: -1, MaxLimit: 99, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 5, MinLimit: -1, MaxLimit: 99, MaxQueueSize: 100}}.Validate(), ) require.Equal( t, @@ -2235,11 +2253,11 @@ func TestConcurrency_Validate(t *testing.T) { "max_limit", ), }, - Concurrency{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: -1, MaxQueueSize: 100}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{Adaptive: true, InitialLimit: 10, MinLimit: 5, MaxLimit: -1, MaxQueueSize: 100}}.Validate(), ) - require.NoError(t, Concurrency{MaxQueueSize: 1}.Validate()) - require.NoError(t, Concurrency{MaxQueueSize: 100}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueSize: 1}}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueSize: 100}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2248,7 +2266,7 @@ func TestConcurrency_Validate(t *testing.T) { "max_queue_size", ), }, - Concurrency{MaxQueueSize: 0}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueSize: 0}}.Validate(), ) require.Equal( t, @@ -2258,10 +2276,10 @@ func TestConcurrency_Validate(t *testing.T) { "max_queue_size", ), }, - Concurrency{MaxQueueSize: -1}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueSize: -1}}.Validate(), ) - require.NoError(t, Concurrency{MaxQueueWait: duration.Duration(1), MaxQueueSize: 1}.Validate()) + require.NoError(t, Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueWait: duration.Duration(1), MaxQueueSize: 1}}.Validate()) require.Equal( t, cfgerror.ValidationErrors{ @@ -2270,7 +2288,7 @@ func TestConcurrency_Validate(t *testing.T) { "max_queue_wait", ), }, - Concurrency{MaxQueueWait: duration.Duration(-time.Minute), MaxQueueSize: 1}.Validate(), + Concurrency{ConcurrencyLimits: ConcurrencyLimits{MaxQueueWait: duration.Duration(-time.Minute), MaxQueueSize: 1}}.Validate(), ) } @@ -3209,8 +3227,12 @@ func TestLoadDefaults(t *testing.T) { Backpressure: true, }, PackObjectsLimiting: PackObjectsLimiting{ - MaxConcurrency: 200, - MaxQueueLength: 200, + ConcurrencyLimits: 
ConcurrencyLimits{ + MaxPerRepo: 200, + MaxConcurrency: 200, + MaxQueueSize: 200, + MaxQueueLength: 200, + }, }, Backup: BackupConfig{ WALWorkerCount: 1, diff --git a/internal/gitaly/server/auth/auth.go b/internal/gitaly/server/auth/auth.go index bf7843c64a38aafdbec9d62e2bb7bd57e55e40b8..fb7c6fdba8a2f0fb59b8d1a32f0494e0301c6111 100644 --- a/internal/gitaly/server/auth/auth.go +++ b/internal/gitaly/server/auth/auth.go @@ -23,6 +23,20 @@ var authCount = promauto.NewCounterVec( []string{"enforced", "status"}, ) +type authenticatedKey struct{} + +// IsAuthenticated returns true if the request has been validated by the auth interceptor. +// This is stronger than merely having an auth token present in the metadata: it confirms +// the token was cryptographically validated. +func IsAuthenticated(ctx context.Context) bool { + authenticated, ok := ctx.Value(authenticatedKey{}).(bool) + return ok && authenticated +} + +func setAuthenticated(ctx context.Context) context.Context { + return context.WithValue(ctx, authenticatedKey{}, true) +} + // UnauthenticatedHealthService wraps the health server and disables authentication for all of its methods. type UnauthenticatedHealthService struct{ grpc_health_v1.HealthServer } @@ -52,6 +66,8 @@ func checkFunc(conf gitalycfgauth.Config) func(ctx context.Context) (context.Con switch status.Code(err) { case codes.OK: countStatus(okLabel(conf.Transitioning), conf.Transitioning).Inc() + // Mark the context as authenticated only when validation succeeds + ctx = setAuthenticated(ctx) case codes.Unauthenticated: countStatus("unauthenticated", conf.Transitioning).Inc() case codes.PermissionDenied: diff --git a/internal/gitaly/server/auth_test.go b/internal/gitaly/server/auth_test.go index a50083334c00e398da1b2923a66388b716f79b8f..a425af47d937dd5452f4734da9bc9ca0fc3c644e 100644 --- a/internal/gitaly/server/auth_test.go +++ b/internal/gitaly/server/auth_test.go @@ -360,8 +360,10 @@ func TestAuthBeforeLimit(t *testing.T) { cfg := testcfg.Build(t, testcfg.WithBase(config.Cfg{ Auth: auth.Config{Token: "abc123"}, Concurrency: []config.Concurrency{{ - RPC: "/gitaly.OperationService/UserCreateTag", - MaxPerRepo: 1, + RPC: "/gitaly.OperationService/UserCreateTag", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + }, }}, }, )) diff --git a/internal/gitaly/service/dependencies.go b/internal/gitaly/service/dependencies.go index a20a0b015efe94321b21871526144abe03cf40ea..404136b9b80d31b0941f320712fe59cb1eacec7e 100644 --- a/internal/gitaly/service/dependencies.go +++ b/internal/gitaly/service/dependencies.go @@ -27,32 +27,32 @@ import ( // Dependencies assembles set of components required by different kinds of services.
type Dependencies struct { - Logger log.Logger - Cfg config.Cfg - GitalyHookManager gitalyhook.Manager - TransactionManager transaction.Manager - StorageLocator storage.Locator - ClientPool *client.Pool - GitCmdFactory gitcmd.CommandFactory - BackchannelRegistry *backchannel.Registry - GitlabClient gitlab.Client - CatfileCache catfile.Cache - DiskCache cache.Cache - PackObjectsCache streamcache.Cache - PackObjectsLimiter limiter.Limiter - LimitHandler *limithandler.LimiterMiddleware - RepositoryCounter *counter.RepositoryCounter - UpdaterWithHooks *updateref.UpdaterWithHooks - HousekeepingManager housekeepingmgr.Manager - TransactionRegistry *storagemgr.TransactionRegistry - Node storage.Node - BackupSink *backup.Sink - BackupLocator backup.Locator - ProcReceiveRegistry *gitalyhook.ProcReceiveRegistry - BundleURIManager *bundleuri.GenerationManager - LocalRepositoryFactory localrepo.Factory - MigrationStateManager migration.StateManager - ArchiveCache streamcache.Cache + Logger log.Logger + Cfg config.Cfg + GitalyHookManager gitalyhook.Manager + TransactionManager transaction.Manager + StorageLocator storage.Locator + ClientPool *client.Pool + GitCmdFactory gitcmd.CommandFactory + BackchannelRegistry *backchannel.Registry + GitlabClient gitlab.Client + CatfileCache catfile.Cache + DiskCache cache.Cache + PackObjectsCache streamcache.Cache + PackObjectsLimiter, PackObjectsLimiterUnauthenticated limiter.Limiter + LimitHandler *limithandler.LimiterMiddleware + RepositoryCounter *counter.RepositoryCounter + UpdaterWithHooks *updateref.UpdaterWithHooks + HousekeepingManager housekeepingmgr.Manager + TransactionRegistry *storagemgr.TransactionRegistry + Node storage.Node + BackupSink *backup.Sink + BackupLocator backup.Locator + ProcReceiveRegistry *gitalyhook.ProcReceiveRegistry + BundleURIManager *bundleuri.GenerationManager + LocalRepositoryFactory localrepo.Factory + MigrationStateManager migration.StateManager + ArchiveCache streamcache.Cache } // GetLogger returns the logger. @@ -140,6 +140,12 @@ func (dc *Dependencies) GetPackObjectsLimiter() limiter.Limiter { return dc.PackObjectsLimiter } +// GetPackObjectsLimiterUnauthenticated returns the pack-objects limiter for +// unauthenticated requests. +func (dc *Dependencies) GetPackObjectsLimiterUnauthenticated() limiter.Limiter { + return dc.PackObjectsLimiterUnauthenticated +} + // GetTransactionRegistry returns the TransactionRegistry. 
func (dc *Dependencies) GetTransactionRegistry() *storagemgr.TransactionRegistry { return dc.TransactionRegistry diff --git a/internal/gitaly/service/hook/pack_objects.go b/internal/gitaly/service/hook/pack_objects.go index 9672890ab923b47069c47a96a78bc596cdab4cbb..efcadbddb758ff8698cac02a6f9878c6da98c4b9 100644 --- a/internal/gitaly/service/hook/pack_objects.go +++ b/internal/gitaly/service/hook/pack_objects.go @@ -17,9 +17,11 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "gitlab.com/gitlab-org/gitaly/v16/internal/featureflag" "gitlab.com/gitlab-org/gitaly/v16/internal/git/gitcmd" "gitlab.com/gitlab-org/gitaly/v16/internal/git/pktline" gitalyhook "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/hook" + "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/server/auth" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/storage" "gitlab.com/gitlab-org/gitaly/v16/internal/helper" "gitlab.com/gitlab-org/gitaly/v16/internal/log" @@ -189,7 +191,13 @@ func (s *server) runPackObjectsLimited( defer stdin.Close() - if _, err := s.packObjectsLimiter.Limit( + limiter := s.packObjectsLimiter + + if featureflag.LimitUnauthenticated.IsEnabled(ctx) && !auth.IsAuthenticated(ctx) { + limiter = s.packObjectsLimiterUnauthenticated + } + + if _, err := limiter.Limit( ctx, limitkey, func() (interface{}, error) { diff --git a/internal/gitaly/service/hook/pack_objects_test.go b/internal/gitaly/service/hook/pack_objects_test.go index 694cf5707f9f44ee541ca01b2c16c84986a180d4..185a56d5911e0660d37d73baa84664197e3d0863 100644 --- a/internal/gitaly/service/hook/pack_objects_test.go +++ b/internal/gitaly/service/hook/pack_objects_test.go @@ -958,6 +958,10 @@ func TestPackObjects_concurrencyLimit(t *testing.T) { }), }, testserver.WithPackObjectsLimiter(limiter), + // Since the LimitUnauthenticated feature flag is enabled by default in tests, + // and these test requests are unauthenticated, we need to set both limiters + // to the same instance so the test assertions work correctly.
+ testserver.WithPackObjectsLimiterUnauthenticated(limiter), ) requests := tc.setup(t, cfg) @@ -1054,3 +1058,137 @@ }) } } + +func TestPackObjects_authentication_based_concurrency_limit(t *testing.T) { + t.Parallel() + + ctx := testhelper.Context(t) + cfg := cfgWithCache(t, 0) + + args := []string{"pack-objects", "--revs", "--thin", "--stdout", "--progress", "--delta-base-offset"} + + // Test that unauthenticated requests use the unauthenticated limiter when the feature flag is enabled + t.Run("unauthenticated requests use unauthenticated limiter", func(t *testing.T) { + t.Parallel() + + limiterCtx, cancel, simulateTimeout := testhelper.ContextWithSimulatedTimeout(ctx) + defer cancel() + + authMonitor := limiter.NewPackObjectsConcurrencyMonitor( + cfg.Prometheus.GRPCLatencyBuckets, + ) + authLimiter := limiter.NewConcurrencyLimiter( + limiter.NewAdaptiveLimit("authenticatedLimit", limiter.AdaptiveSetting{Initial: 1}), + 0, + 1*time.Millisecond, + authMonitor, + ) + authLimiter.SetWaitTimeoutContext = func() context.Context { return limiterCtx } + + unauthMonitor := limiter.NewPackObjectsConcurrencyMonitor( + cfg.Prometheus.GRPCLatencyBuckets, + ) + unauthLimiter := limiter.NewConcurrencyLimiter( + limiter.NewAdaptiveLimit("unauthenticatedLimit", limiter.AdaptiveSetting{Initial: 1}), + 0, + 1*time.Millisecond, + unauthMonitor, + ) + unauthLimiter.SetWaitTimeoutContext = func() context.Context { return limiterCtx } + + authRegistry := prometheus.NewRegistry() + authRegistry.MustRegister(authMonitor) + + unauthRegistry := prometheus.NewRegistry() + unauthRegistry.MustRegister(unauthMonitor) + + receivedCh, blockCh := make(chan struct{}), make(chan struct{}) + cfg.SocketPath = runHooksServer(t, cfg, []serverOption{ + withRunPackObjectsFn(func( + context.Context, + gitcmd.CommandFactory, + io.Writer, + *gitalypb.PackObjectsHookWithSidechannelRequest, + *packObjectsArgs, + io.Reader, + string, + ) error { + receivedCh <- struct{}{} + <-blockCh + return nil + }), + }, + testserver.WithPackObjectsLimiter(authLimiter), + testserver.WithPackObjectsLimiterUnauthenticated(unauthLimiter), + ) + + repo, _ := gittest.CreateRepository(t, ctx, cfg) + hooksPayloadEnv := hooksPayloadEnvForRepository(t, ctx, cfg, repo) + + // Unauthenticated request (no auth token) + request := &gitalypb.PackObjectsHookWithSidechannelRequest{ + GlId: "user-123", + RemoteIp: "1.2.3.4", + Repository: repo, + Args: args, + EnvironmentVariables: hooksPayloadEnv, + } + + ctx1, wt1, err := setupSidechannel(t, ctx, "1dd08961455abf80ef9115f4afdc1c6f968b503c") + require.NoError(t, err) + + ctx2, wt2, err := setupSidechannel(t, ctx, "2dd08961455abf80ef9115f4afdc1c6f968b503c") + require.NoError(t, err) + + client, conn := newHooksClient(t, cfg.SocketPath) + defer testhelper.MustClose(t, conn) + + var wg sync.WaitGroup + wg.Add(2) + + errChan := make(chan error) + + // Fire off two requests with the same IP address. + // Both should use the unauthenticated limiter. + for _, c := range []context.Context{ctx1, ctx2} { + go func(c context.Context) { + defer wg.Done() + _, err := client.PackObjectsHookWithSidechannel(c, request) + if err != nil { + errChan <- err + } + }(c) + } + + // Wait for the first request to be processed + <-receivedCh + + // Verify the unauthenticated limiter is being used + require.NoError(t, + testutil.GatherAndCompare(unauthRegistry, + bytes.NewBufferString(`# HELP gitaly_pack_objects_in_progress Gauge of number of concurrent in-progress calls +# TYPE
gitaly_pack_objects_in_progress gauge +gitaly_pack_objects_in_progress 1 +`), "gitaly_pack_objects_in_progress")) + + // Verify the authenticated limiter is NOT being used + require.NoError(t, + testutil.GatherAndCompare(authRegistry, + bytes.NewBufferString(`# HELP gitaly_pack_objects_in_progress Gauge of number of concurrent in-progress calls +# TYPE gitaly_pack_objects_in_progress gauge +gitaly_pack_objects_in_progress 0 +`), "gitaly_pack_objects_in_progress")) + + // Trigger timeout for the queued request + simulateTimeout() + + err = <-errChan + testhelper.RequireGrpcCode(t, err, codes.ResourceExhausted) + + close(blockCh) + + wg.Wait() + require.NoError(t, wt1.Wait()) + require.NoError(t, wt2.Wait()) + }) +} diff --git a/internal/gitaly/service/hook/server.go b/internal/gitaly/service/hook/server.go index 26f3ee5f02eb29150943ffa39ae1825546b0bfa1..46d1f181882a148cbfaaed3dfc600a42c982546f 100644 --- a/internal/gitaly/service/hook/server.go +++ b/internal/gitaly/service/hook/server.go @@ -17,14 +17,14 @@ import ( type server struct { gitalypb.UnimplementedHookServiceServer - logger log.Logger - manager gitalyhook.Manager - locator storage.Locator - gitCmdFactory gitcmd.CommandFactory - packObjectsCache streamcache.Cache - packObjectsLimiter limiter.Limiter - txRegistry *storagemgr.TransactionRegistry - runPackObjectsFn func( + logger log.Logger + manager gitalyhook.Manager + locator storage.Locator + gitCmdFactory gitcmd.CommandFactory + packObjectsCache streamcache.Cache + packObjectsLimiter, packObjectsLimiterUnauthenticated limiter.Limiter + txRegistry *storagemgr.TransactionRegistry + runPackObjectsFn func( context.Context, gitcmd.CommandFactory, io.Writer, @@ -38,14 +38,15 @@ type server struct { // NewServer creates a new instance of a gRPC namespace server func NewServer(deps *service.Dependencies) gitalypb.HookServiceServer { srv := &server{ - logger: deps.GetLogger(), - manager: deps.GetHookManager(), - locator: deps.GetLocator(), - gitCmdFactory: deps.GetGitCmdFactory(), - packObjectsCache: deps.GetPackObjectsCache(), - packObjectsLimiter: deps.GetPackObjectsLimiter(), - txRegistry: deps.GetTransactionRegistry(), - runPackObjectsFn: runPackObjects, + logger: deps.GetLogger(), + manager: deps.GetHookManager(), + locator: deps.GetLocator(), + gitCmdFactory: deps.GetGitCmdFactory(), + packObjectsCache: deps.GetPackObjectsCache(), + packObjectsLimiter: deps.GetPackObjectsLimiter(), + packObjectsLimiterUnauthenticated: deps.GetPackObjectsLimiterUnauthenticated(), + txRegistry: deps.GetTransactionRegistry(), + runPackObjectsFn: runPackObjects, } return srv diff --git a/internal/grpc/middleware/limithandler/middleware.go b/internal/grpc/middleware/limithandler/middleware.go index d83e31fc5acf1d8b4df9ee8647fd1da147e67685..9f9abbc772a9e19b0a6458840aace6691c029d32 100644 --- a/internal/grpc/middleware/limithandler/middleware.go +++ b/internal/grpc/middleware/limithandler/middleware.go @@ -5,7 +5,9 @@ import ( "fmt" "github.com/prometheus/client_golang/prometheus" + "gitlab.com/gitlab-org/gitaly/v16/internal/featureflag" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config" + "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/server/auth" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/requestinfohandler" "gitlab.com/gitlab-org/gitaly/v16/internal/limiter" "google.golang.org/grpc" @@ -25,10 +27,11 @@ func LimitConcurrencyByRepo(ctx context.Context) string { // LimiterMiddleware contains rate limiter state type LimiterMiddleware struct { - methodLimiters 
map[string]limiter.Limiter - getLockKey GetLockKey - requestsDroppedMetric *prometheus.CounterVec - collect func(metrics chan<- prometheus.Metric) + methodLimiters map[string]limiter.Limiter + methodLimitersUnauthenticated map[string]limiter.Limiter + getLockKey GetLockKey + requestsDroppedMetric *prometheus.CounterVec + collect func(metrics chan<- prometheus.Metric) } // New creates a new middleware that limits requests. SetupFunc sets up the @@ -76,7 +79,19 @@ func (c *LimiterMiddleware) UnaryInterceptor() grpc.UnaryServerInterceptor { return handler(ctx, req) } + // Check if request is authenticated limiter := c.methodLimiters[info.FullMethod] + + if featureflag.LimitUnauthenticated.IsEnabled(ctx) { + unauthLimiter, ok := c.methodLimitersUnauthenticated[info.FullMethod] + // Use auth.IsAuthenticated to check if the token was cryptographically validated, + // not just whether a token was present in metadata. This prevents spoofed tokens + // from bypassing unauthenticated rate limits. + if !auth.IsAuthenticated(ctx) && ok { + limiter = unauthLimiter + } + } + if limiter == nil { // No concurrency limiting return handler(ctx, req) @@ -125,7 +140,20 @@ func (w *wrappedStream) RecvMsg(m interface{}) error { return nil } + // Check if request is authenticated limiter := w.limiterMiddleware.methodLimiters[w.info.FullMethod] + + if featureflag.LimitUnauthenticated.IsEnabled(ctx) { + unauthLimiter, ok := w.limiterMiddleware.methodLimitersUnauthenticated[w.info.FullMethod] + // Use auth.IsAuthenticated to check if the token was cryptographically validated, + // not just whether a token was present in metadata. This prevents spoofed tokens + // from bypassing unauthenticated rate limits. + if !auth.IsAuthenticated(ctx) && ok { + // Unauthenticated request + limiter = unauthLimiter + } + } + if limiter == nil { // No concurrency limiting return nil @@ -158,7 +186,10 @@ func (w *wrappedStream) RecvMsg(m interface{}) error { // requests based on RPC and repository func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, SetupFunc) { perRPCLimits := map[string]*limiter.AdaptiveLimit{} + perRPCLimitsUnauthenticated := map[string]*limiter.AdaptiveLimit{} + for _, concurrency := range cfg.Concurrency { + // Create authenticated limiter limitName := fmt.Sprintf("perRPC%s", concurrency.RPC) if concurrency.Adaptive { perRPCLimits[concurrency.RPC] = limiter.NewAdaptiveLimit(limitName, limiter.AdaptiveSetting{ @@ -172,6 +203,25 @@ func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, Initial: concurrency.MaxPerRepo, }) } + + // Create unauthenticated limiter if configured + unauthLimits := concurrency.Unauthenticated + if unauthLimits.Adaptive || unauthLimits.MaxPerRepo > 0 || + unauthLimits.InitialLimit > 0 || unauthLimits.MaxLimit > 0 || unauthLimits.MinLimit > 0 { + limitNameUnauth := fmt.Sprintf("perRPC%s-unauthenticated", concurrency.RPC) + if unauthLimits.Adaptive { + perRPCLimitsUnauthenticated[concurrency.RPC] = limiter.NewAdaptiveLimit(limitNameUnauth, limiter.AdaptiveSetting{ + Initial: unauthLimits.InitialLimit, + Max: unauthLimits.MaxLimit, + Min: unauthLimits.MinLimit, + BackoffFactor: limiter.DefaultBackoffFactor, + }) + } else if unauthLimits.MaxPerRepo > 0 { + perRPCLimitsUnauthenticated[concurrency.RPC] = limiter.NewAdaptiveLimit(limitNameUnauth, limiter.AdaptiveSetting{ + Initial: unauthLimits.MaxPerRepo, + }) + } + } } return perRPCLimits, func(cfg config.Cfg, middleware *LimiterMiddleware) { acquiringSecondsMetric := 
prometheus.NewHistogramVec( @@ -210,7 +260,10 @@ func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, } result := make(map[string]limiter.Limiter) + resultUnauthenticated := make(map[string]limiter.Limiter) + for _, concurrency := range cfg.Concurrency { + // Create authenticated limiter result[concurrency.RPC] = limiter.NewConcurrencyLimiter( perRPCLimits[concurrency.RPC], concurrency.MaxQueueSize, @@ -220,6 +273,20 @@ func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, queuedMetric, inProgressMetric, acquiringSecondsMetric, middleware.requestsDroppedMetric, ), ) + + // Create unauthenticated limiter if configured + if adaptiveLimit, ok := perRPCLimitsUnauthenticated[concurrency.RPC]; ok { + unauthLimits := concurrency.Unauthenticated + resultUnauthenticated[concurrency.RPC] = limiter.NewConcurrencyLimiter( + adaptiveLimit, + unauthLimits.MaxQueueSize, + unauthLimits.MaxQueueWait.Duration(), + limiter.NewPerRPCPromMonitor( + "gitaly", concurrency.RPC+"-unauthenticated", + queuedMetric, inProgressMetric, acquiringSecondsMetric, middleware.requestsDroppedMetric, + ), + ) + } } // Set default for ReplicateRepository. @@ -237,5 +304,6 @@ func WithConcurrencyLimiters(cfg config.Cfg) (map[string]*limiter.AdaptiveLimit, } middleware.methodLimiters = result + middleware.methodLimitersUnauthenticated = resultUnauthenticated } } diff --git a/internal/grpc/middleware/limithandler/middleware_test.go b/internal/grpc/middleware/limithandler/middleware_test.go index eb944b88086b2c0bf9b44467027b9825c658ca22..fd598aa36442fa4f4f4378851ded316fc88c9602 100644 --- a/internal/grpc/middleware/limithandler/middleware_test.go +++ b/internal/grpc/middleware/limithandler/middleware_test.go @@ -11,9 +11,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + gitalyauth "gitlab.com/gitlab-org/gitaly/v16/auth" "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config" + gitalycfgauth "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/config/auth" + "gitlab.com/gitlab-org/gitaly/v16/internal/gitaly/server/auth" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/client" "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/limithandler" + "gitlab.com/gitlab-org/gitaly/v16/internal/grpc/middleware/requestinfohandler" "gitlab.com/gitlab-org/gitaly/v16/internal/helper/duration" "gitlab.com/gitlab-org/gitaly/v16/internal/limiter" "gitlab.com/gitlab-org/gitaly/v16/internal/structerr" @@ -35,19 +39,25 @@ func TestWithConcurrencyLimiters(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ { - RPC: "/grpc.testing.TestService/UnaryCall", - MaxPerRepo: 1, + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + }, }, { - RPC: "/grpc.testing.TestService/FullDuplexCall", - MaxPerRepo: 99, + RPC: "/grpc.testing.TestService/FullDuplexCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 99, + }, }, { - RPC: "/grpc.testing.TestService/AnotherUnaryCall", - Adaptive: true, - MinLimit: 5, - InitialLimit: 10, - MaxLimit: 15, + RPC: "/grpc.testing.TestService/AnotherUnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + Adaptive: true, + MinLimit: 5, + InitialLimit: 10, + MaxLimit: 15, + }, }, }, } @@ -81,7 +91,12 @@ func TestUnaryLimitHandler(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ - {RPC: "/grpc.testing.TestService/UnaryCall", MaxPerRepo: 2}, + { + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: 
config.ConcurrencyLimits{ + MaxPerRepo: 2, + }, + }, }, } @@ -141,23 +156,25 @@ func TestUnaryLimitHandler_queueing(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ { - RPC: "/grpc.testing.TestService/UnaryCall", - MaxPerRepo: 1, - MaxQueueSize: 1, - // This test setups two requests: - // - The first one is eligible. It enters the handler and blocks the queue. - // - The second request is blocked until timeout. - // Both of them shares this timeout. Internally, the limiter creates a context - // deadline to reject timed out requests. If it's set too low, there's a tiny - // possibility that the context reaches the deadline when the limiter checks the - // request. Thus, setting a reasonable timeout here and adding some retry - // attempts below make the test stable. - // Another approach is to implement a hooking mechanism that allows us to - // override context deadline setup. However, that approach exposes the internal - // implementation of the limiter. It also adds unnecessarily logics. - // Congiuring the timeout is more straight-forward and close to the expected - // behavior. - MaxQueueWait: duration.Duration(100 * time.Millisecond), + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + MaxQueueSize: 1, + // This test sets up two requests: + // - The first one is eligible. It enters the handler and blocks the queue. + // - The second request is blocked until timeout. + // Both of them share this timeout. Internally, the limiter creates a context + // deadline to reject timed out requests. If it's set too low, there's a tiny + // possibility that the context reaches the deadline when the limiter checks the + // request. Thus, setting a reasonable timeout here and adding some retry + // attempts below make the test stable. + // Another approach is to implement a hooking mechanism that allows us to + // override context deadline setup. However, that approach exposes the internal + // implementation of the limiter. It also adds unnecessary logic. + // Configuring the timeout is more straightforward and closer to the expected + // behavior. + MaxQueueWait: duration.Duration(100 * time.Millisecond), + }, }, }, } @@ -221,13 +238,17 @@ func TestUnaryLimitHandler_queueing(t *testing.T) { // that has no wait limit. We of course expect that the actual // config should not have any maximum queueing time.
{ - RPC: "dummy", - MaxPerRepo: 1, - MaxQueueWait: duration.Duration(1 * time.Nanosecond), + RPC: "dummy", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + MaxQueueWait: duration.Duration(1 * time.Nanosecond), + }, }, { - RPC: "/grpc.testing.TestService/UnaryCall", - MaxPerRepo: 1, + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + }, }, }, } @@ -487,9 +508,11 @@ func TestStreamLimitHandler(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ { - RPC: tc.fullname, - MaxPerRepo: tc.maxConcurrency, - MaxQueueSize: maxQueueSize, + RPC: tc.fullname, + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: tc.maxConcurrency, + MaxQueueSize: maxQueueSize, + }, }, }, } @@ -540,7 +563,13 @@ func TestStreamLimitHandler_error(t *testing.T) { cfg := config.Cfg{ Concurrency: []config.Concurrency{ - {RPC: "/grpc.testing.TestService/FullDuplexCall", MaxPerRepo: 1, MaxQueueSize: 1}, + { + RPC: "/grpc.testing.TestService/FullDuplexCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + MaxQueueSize: 1, + }, + }, }, } @@ -660,7 +689,13 @@ func TestConcurrencyLimitHandlerMetrics(t *testing.T) { methodName := "/grpc.testing.TestService/UnaryCall" cfg := config.Cfg{ Concurrency: []config.Concurrency{ - {RPC: methodName, MaxPerRepo: 1, MaxQueueSize: 1}, + { + RPC: methodName, + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 1, + MaxQueueSize: 1, + }, + }, }, } @@ -737,6 +772,237 @@ func TestConcurrencyLimitHandlerMetrics(t *testing.T) { <-respCh } +func TestAuthenticatedVsUnauthenticatedLimiting(t *testing.T) { + t.Parallel() + + t.Run("unary: authenticated and unauthenticated requests use separate limiters", func(t *testing.T) { + t.Parallel() + + s := &queueTestServer{ + server: server{ + blockCh: make(chan struct{}), + }, + reqArrivedCh: make(chan struct{}), + } + + cfg := config.Cfg{ + Concurrency: []config.Concurrency{ + { + RPC: "/grpc.testing.TestService/UnaryCall", + ConcurrencyLimits: config.ConcurrencyLimits{ + MaxPerRepo: 2, // Authenticated: 2 concurrent + MaxQueueSize: 10, + }, + Unauthenticated: config.ConcurrencyLimits{ + MaxPerRepo: 1, // Unauthenticated: 1 concurrent + MaxQueueSize: 10, + }, + }, + }, + } + + _, setupPerRPCConcurrencyLimiters := limithandler.WithConcurrencyLimiters(cfg) + lh := limithandler.New(cfg, fixedLockKey, setupPerRPCConcurrencyLimiters) + srv, serverSocketPath := runServerWithAuth(t, s, lh.UnaryInterceptor(), nil) + defer srv.Stop() + + client, conn := newClient(t, serverSocketPath) + defer conn.Close() + + authClient, authConn := newAuthenticatedClient(t, serverSocketPath, "test-secret") + defer authConn.Close() + + ctx := testhelper.Context(t) + + // First, send 2 authenticated requests - both should be accepted (limit is 2) + var wg sync.WaitGroup + wg.Add(2) + for i := 0; i < 2; i++ { + go func() { + defer wg.Done() + _, err := authClient.UnaryCall(ctx, &grpc_testing.SimpleRequest{}) + require.NoError(t, err) + }() + } + + // Wait for both authenticated requests to arrive + <-s.reqArrivedCh + <-s.reqArrivedCh + + // Now send an unauthenticated request - it should also be accepted + // because it uses a separate limiter + wg.Add(1) + go func() { + defer wg.Done() + _, err := client.UnaryCall(ctx, &grpc_testing.SimpleRequest{}) + require.NoError(t, err) + }() + + // Wait for the unauthenticated request to arrive + <-s.reqArrivedCh + + // Verify no more requests can get through (both limiters saturated) + select { + case <-s.reqArrivedCh: + 
+			require.FailNow(t, "received unexpected fourth request")
+		case <-time.After(100 * time.Millisecond):
+		}
+
+		// Unblock all requests
+		close(s.blockCh)
+		wg.Wait()
+	})
+
+	t.Run("unary: unauthenticated falls back to authenticated limiter when not configured", func(t *testing.T) {
+		t.Parallel()
+
+		s := &queueTestServer{
+			server: server{
+				blockCh: make(chan struct{}),
+			},
+			reqArrivedCh: make(chan struct{}),
+		}
+
+		cfg := config.Cfg{
+			Concurrency: []config.Concurrency{
+				{
+					RPC: "/grpc.testing.TestService/UnaryCall",
+					ConcurrencyLimits: config.ConcurrencyLimits{
+						MaxPerRepo:   2, // Only authenticated limiter configured
+						MaxQueueSize: 10,
+					},
+					// No unauthenticated limiter configured
+				},
+			},
+		}
+
+		_, setupPerRPCConcurrencyLimiters := limithandler.WithConcurrencyLimiters(cfg)
+		lh := limithandler.New(cfg, fixedLockKey, setupPerRPCConcurrencyLimiters)
+		srv, serverSocketPath := runServerWithAuth(t, s, lh.UnaryInterceptor(), nil)
+		defer srv.Stop()
+
+		client, conn := newClient(t, serverSocketPath)
+		defer conn.Close()
+
+		authClient, authConn := newAuthenticatedClient(t, serverSocketPath, "test-secret")
+		defer authConn.Close()
+
+		ctx := testhelper.Context(t)
+
+		var wg sync.WaitGroup
+
+		// Send 1 authenticated and 1 unauthenticated request.
+		// Both should be accepted (they share the same limiter with limit 2).
+		wg.Add(2)
+		go func() {
+			defer wg.Done()
+			_, err := authClient.UnaryCall(ctx, &grpc_testing.SimpleRequest{})
+			require.NoError(t, err)
+		}()
+		go func() {
+			defer wg.Done()
+			_, err := client.UnaryCall(ctx, &grpc_testing.SimpleRequest{})
+			require.NoError(t, err)
+		}()
+
+		// Wait for both requests to arrive
+		<-s.reqArrivedCh
+		<-s.reqArrivedCh
+
+		// Verify no more requests can get through (shared limiter saturated)
+		select {
+		case <-s.reqArrivedCh:
+			require.FailNow(t, "received unexpected third request")
+		case <-time.After(100 * time.Millisecond):
+		}
+
+		// Unblock all requests
+		close(s.blockCh)
+		wg.Wait()
+	})
+
+	t.Run("stream: authenticated and unauthenticated requests use separate limiters", func(t *testing.T) {
+		t.Parallel()
+
+		s := &queueTestServer{
+			server: server{
+				blockCh: make(chan struct{}),
+			},
+			reqArrivedCh: make(chan struct{}),
+		}
+
+		cfg := config.Cfg{
+			Concurrency: []config.Concurrency{
+				{
+					RPC: "/grpc.testing.TestService/FullDuplexCall",
+					ConcurrencyLimits: config.ConcurrencyLimits{
+						MaxPerRepo:   2, // Authenticated: 2 concurrent
+						MaxQueueSize: 10,
+					},
+					Unauthenticated: config.ConcurrencyLimits{
+						MaxPerRepo:   1, // Unauthenticated: 1 concurrent
+						MaxQueueSize: 10,
+					},
+				},
+			},
+		}
+
+		_, setupPerRPCConcurrencyLimiters := limithandler.WithConcurrencyLimiters(cfg)
+		lh := limithandler.New(cfg, fixedLockKey, setupPerRPCConcurrencyLimiters)
+		srv, serverSocketPath := runServerWithAuth(t, s, nil, lh.StreamInterceptor())
+		defer srv.Stop()
+
+		client, conn := newClient(t, serverSocketPath)
+		defer conn.Close()
+
+		authClient, authConn := newAuthenticatedClient(t, serverSocketPath, "test-secret")
+		defer authConn.Close()
+
+		ctx := testhelper.Context(t)
+
+		respChan := make(chan *grpc_testing.StreamingOutputCallResponse)
+
+		// Send 2 authenticated streams
+		for i := 0; i < 2; i++ {
+			go func() {
+				stream, err := authClient.FullDuplexCall(ctx)
+				require.NoError(t, err)
+				require.NoError(t, stream.Send(&grpc_testing.StreamingOutputCallRequest{}))
+				require.NoError(t, stream.CloseSend())
+				resp, err := stream.Recv()
+				require.NoError(t, err)
+				respChan <- resp
+			}()
+		}
+
+		// Wait for both authenticated streams to arrive
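+		// (the queueTestServer signals reqArrivedCh once per call that
+		// enters the handler, so two receives mean both were admitted)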
+		<-s.reqArrivedCh
+		<-s.reqArrivedCh
+
+		// Send 1 unauthenticated stream - should be accepted with separate limiter
+		go func() {
+			stream, err := client.FullDuplexCall(ctx)
+			require.NoError(t, err)
+			require.NoError(t, stream.Send(&grpc_testing.StreamingOutputCallRequest{}))
+			require.NoError(t, stream.CloseSend())
+			resp, err := stream.Recv()
+			require.NoError(t, err)
+			respChan <- resp
+		}()
+
+		// Wait for the unauthenticated stream to arrive
+		<-s.reqArrivedCh
+
+		// Unblock all streams
+		close(s.blockCh)
+
+		// Collect all responses
+		for i := 0; i < 3; i++ {
+			<-respChan
+		}
+	})
+}
+
 func runServer(tb testing.TB, s grpc_testing.TestServiceServer, opt ...grpc.ServerOption) (*grpc.Server, string) {
 	serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(tb)
 	grpcServer := grpc.NewServer(opt...)
@@ -750,6 +1016,47 @@ func runServer(tb testing.TB, s grpc_testing.TestServiceServer, opt ...grpc.Serv
 	return grpcServer, "unix://" + serverSocketPath
 }
 
+func runServerWithAuth(tb testing.TB, s grpc_testing.TestServiceServer, unaryInt grpc.UnaryServerInterceptor, streamInt grpc.StreamServerInterceptor) (*grpc.Server, string) {
+	serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(tb)
+
+	var unaryInterceptors []grpc.UnaryServerInterceptor
+	var streamInterceptors []grpc.StreamServerInterceptor
+
+	// Add requestinfohandler first to extract authentication info
+	unaryInterceptors = append(unaryInterceptors, requestinfohandler.UnaryInterceptor)
+	streamInterceptors = append(streamInterceptors, requestinfohandler.StreamInterceptor)
+
+	// Add auth interceptor to validate tokens and set the authenticated flag.
+	// Use transitioning mode so invalid tokens don't block requests (for testing the unauthenticated flow).
+	authCfg := gitalycfgauth.Config{
+		Token:         "test-secret",
+		Transitioning: true,
+	}
+	unaryInterceptors = append(unaryInterceptors, auth.UnaryServerInterceptor(authCfg))
+	streamInterceptors = append(streamInterceptors, auth.StreamServerInterceptor(authCfg))
+
+	// Then add the limiter interceptor
+	if unaryInt != nil {
+		unaryInterceptors = append(unaryInterceptors, unaryInt)
+	}
+	if streamInt != nil {
+		streamInterceptors = append(streamInterceptors, streamInt)
+	}
+
+	grpcServer := grpc.NewServer(
+		grpc.ChainUnaryInterceptor(unaryInterceptors...),
+		grpc.ChainStreamInterceptor(streamInterceptors...),
+	)
+	grpc_testing.RegisterTestServiceServer(grpcServer, s)
+
+	lis, err := net.Listen("unix", serverSocketPath)
+	require.NoError(tb, err)
+
+	go testhelper.MustServe(tb, grpcServer, lis)
+
+	return grpcServer, "unix://" + serverSocketPath
+}
+
 func newClient(tb testing.TB, serverSocketPath string) (grpc_testing.TestServiceClient, *grpc.ClientConn) {
 	conn, err := client.New(testhelper.Context(tb), serverSocketPath)
 	if err != nil {
@@ -758,3 +1065,18 @@ func newClient(tb testing.TB, serverSocketPath string) (grpc_testing.TestService
 
 	return grpc_testing.NewTestServiceClient(conn), conn
 }
+
+func newAuthenticatedClient(tb testing.TB, serverSocketPath, secret string) (grpc_testing.TestServiceClient, *grpc.ClientConn) {
+	conn, err := client.New(
+		testhelper.Context(tb),
+		serverSocketPath,
+		client.WithGrpcOptions([]grpc.DialOption{
+			grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(secret)),
+		}),
+	)
+	if err != nil {
+		tb.Fatal(err)
+	}
+
+	return grpc_testing.NewTestServiceClient(conn), conn
+}
diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go
index da9048031f8545efedcfcee7ffe876c2e1769c7d..fa2e324f60afd5de2bdc9bbf12ba6178f0add40e 100644
--- a/internal/testhelper/testhelper.go
+++ b/internal/testhelper/testhelper.go
@@ -337,6 +337,9 @@ func ContextWithoutCancel(opts ...ContextOpt) context.Context {
 	// Enable trace2 logs for receive pack
 	ctx = featureflag.ContextWithFeatureFlag(ctx, featureflag.ReceivePackTrace2Hook, true)
 
+	// Enable unauthenticated limiter
+	ctx = featureflag.ContextWithFeatureFlag(ctx, featureflag.LimitUnauthenticated, true)
+
 	for _, opt := range opts {
 		ctx = opt(ctx)
 	}
diff --git a/internal/testhelper/testserver/gitaly.go b/internal/testhelper/testserver/gitaly.go
index d6350cca9757321cfb5a093ab8861113ecde3337..efedc9230bf39232f1b504471d78bf4e6e23051b 100644
--- a/internal/testhelper/testserver/gitaly.go
+++ b/internal/testhelper/testserver/gitaly.go
@@ -290,36 +290,37 @@ func registerHealthServerIfNotRegistered(srv *grpc.Server) {
 }
 
 type gitalyServerDeps struct {
-	disablePraefect           bool
-	logger                    log.Logger
-	conns                     *client.Pool
-	locator                   storage.Locator
-	txMgr                     transaction.Manager
-	hookMgr                   hook.Manager
-	gitlabClient              gitlab.Client
-	gitCmdFactory             gitcmd.CommandFactory
-	backchannelReg            *backchannel.Registry
-	catfileCache              catfile.Cache
-	diskCache                 cache.Cache
-	packObjectsCache          streamcache.Cache
-	packObjectsLimiter        limiter.Limiter
-	limitHandler              *limithandler.LimiterMiddleware
-	repositoryCounter         *counter.RepositoryCounter
-	updaterWithHooks          *updateref.UpdaterWithHooks
-	housekeepingManager       housekeepingmgr.Manager
-	backupSink                *backup.Sink
-	backupLocator             backup.Locator
-	signingKey                string
-	transactionRegistry       *storagemgr.TransactionRegistry
-	procReceiveRegistry       *hook.ProcReceiveRegistry
-	bundleURIManager          *bundleuri.GenerationManager
-	bundleURISink             *bundleuri.Sink
-	bundleURIStrategy         bundleuri.GenerationStrategy
-	localRepoFactory          localrepo.Factory
-	migrations                *[]migration.Migration
-	archiveCache              streamcache.Cache
-	MigrationStateManager     migration.StateManager
-	transactionInterceptorsFn func(log.Logger, storage.Node, localrepo.Factory) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor)
+	disablePraefect                   bool
+	logger                            log.Logger
+	conns                             *client.Pool
+	locator                           storage.Locator
+	txMgr                             transaction.Manager
+	hookMgr                           hook.Manager
+	gitlabClient                      gitlab.Client
+	gitCmdFactory                     gitcmd.CommandFactory
+	backchannelReg                    *backchannel.Registry
+	catfileCache                      catfile.Cache
+	diskCache                         cache.Cache
+	packObjectsCache                  streamcache.Cache
+	packObjectsLimiter                limiter.Limiter
+	packObjectsLimiterUnauthenticated limiter.Limiter
+	limitHandler                      *limithandler.LimiterMiddleware
+	repositoryCounter                 *counter.RepositoryCounter
+	updaterWithHooks                  *updateref.UpdaterWithHooks
+	housekeepingManager               housekeepingmgr.Manager
+	backupSink                        *backup.Sink
+	backupLocator                     backup.Locator
+	signingKey                        string
+	transactionRegistry               *storagemgr.TransactionRegistry
+	procReceiveRegistry               *hook.ProcReceiveRegistry
+	bundleURIManager                  *bundleuri.GenerationManager
+	bundleURISink                     *bundleuri.Sink
+	bundleURIStrategy                 bundleuri.GenerationStrategy
+	localRepoFactory                  localrepo.Factory
+	migrations                        *[]migration.Migration
+	archiveCache                      streamcache.Cache
+	MigrationStateManager             migration.StateManager
+	transactionInterceptorsFn         func(log.Logger, storage.Node, localrepo.Factory) ([]grpc.UnaryServerInterceptor, []grpc.StreamServerInterceptor)
 }
 
 func (gsd *gitalyServerDeps) createDependencies(tb testing.TB, ctx context.Context, cfg config.Cfg) *service.Dependencies {
@@ -480,6 +481,15 @@ func (gsd *gitalyServerDeps) createDependencies(tb testing.TB, ctx context.Conte
 		)
 	}
 
+	if gsd.packObjectsLimiterUnauthenticated == nil {
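+		// Tests that don't supply an unauthenticated limiter get a default
+		// static one, like the other optional dependencies in this helper.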
+		gsd.packObjectsLimiterUnauthenticated = limiter.NewConcurrencyLimiter(
+			limiter.NewAdaptiveLimit("staticLimit", limiter.AdaptiveSetting{Initial: 0}),
+			0,
+			0,
+			limiter.NewNoopConcurrencyMonitor(),
+		)
+	}
 
 	if gsd.archiveCache == nil {
 		gsd.archiveCache = streamcache.New(cfg.ArchiveCache, gsd.logger)
 		tb.Cleanup(gsd.archiveCache.Stop)
@@ -509,32 +519,33 @@ func (gsd *gitalyServerDeps) createDependencies(tb testing.TB, ctx context.Conte
 	gsd.localRepoFactory = localrepo.NewFactory(gsd.logger, gsd.locator, gsd.gitCmdFactory, gsd.catfileCache)
 
 	return &service.Dependencies{
-		Logger:                 gsd.logger,
-		Cfg:                    cfg,
-		ClientPool:             gsd.conns,
-		StorageLocator:         gsd.locator,
-		TransactionManager:     gsd.txMgr,
-		GitalyHookManager:      gsd.hookMgr,
-		GitCmdFactory:          gsd.gitCmdFactory,
-		BackchannelRegistry:    gsd.backchannelReg,
-		GitlabClient:           gsd.gitlabClient,
-		CatfileCache:           gsd.catfileCache,
-		DiskCache:              gsd.diskCache,
-		PackObjectsCache:       gsd.packObjectsCache,
-		PackObjectsLimiter:     gsd.packObjectsLimiter,
-		LimitHandler:           gsd.limitHandler,
-		RepositoryCounter:      gsd.repositoryCounter,
-		UpdaterWithHooks:       gsd.updaterWithHooks,
-		HousekeepingManager:    gsd.housekeepingManager,
-		TransactionRegistry:    gsd.transactionRegistry,
-		Node:                   node,
-		BackupSink:             gsd.backupSink,
-		BackupLocator:          gsd.backupLocator,
-		ProcReceiveRegistry:    gsd.procReceiveRegistry,
-		BundleURIManager:       gsd.bundleURIManager,
-		LocalRepositoryFactory: gsd.localRepoFactory,
-		MigrationStateManager:  gsd.MigrationStateManager,
-		ArchiveCache:           gsd.archiveCache,
+		Logger:                            gsd.logger,
+		Cfg:                               cfg,
+		ClientPool:                        gsd.conns,
+		StorageLocator:                    gsd.locator,
+		TransactionManager:                gsd.txMgr,
+		GitalyHookManager:                 gsd.hookMgr,
+		GitCmdFactory:                     gsd.gitCmdFactory,
+		BackchannelRegistry:               gsd.backchannelReg,
+		GitlabClient:                      gsd.gitlabClient,
+		CatfileCache:                      gsd.catfileCache,
+		DiskCache:                         gsd.diskCache,
+		PackObjectsCache:                  gsd.packObjectsCache,
+		PackObjectsLimiter:                gsd.packObjectsLimiter,
+		PackObjectsLimiterUnauthenticated: gsd.packObjectsLimiterUnauthenticated,
+		LimitHandler:                      gsd.limitHandler,
+		RepositoryCounter:                 gsd.repositoryCounter,
+		UpdaterWithHooks:                  gsd.updaterWithHooks,
+		HousekeepingManager:               gsd.housekeepingManager,
+		TransactionRegistry:               gsd.transactionRegistry,
+		Node:                              node,
+		BackupSink:                        gsd.backupSink,
+		BackupLocator:                     gsd.backupLocator,
+		ProcReceiveRegistry:               gsd.procReceiveRegistry,
+		BundleURIManager:                  gsd.bundleURIManager,
+		LocalRepositoryFactory:            gsd.localRepoFactory,
+		MigrationStateManager:             gsd.MigrationStateManager,
+		ArchiveCache:                      gsd.archiveCache,
 	}
 }
 
@@ -631,6 +642,15 @@ func WithPackObjectsLimiter(limiter *limiter.ConcurrencyLimiter) GitalyServerOpt
 	}
 }
 
+// WithPackObjectsLimiterUnauthenticated sets the PackObjectsLimiterUnauthenticated that will be
+// used for gitaly services initialization.
+func WithPackObjectsLimiterUnauthenticated(limiter *limiter.ConcurrencyLimiter) GitalyServerOpt {
+	return func(deps gitalyServerDeps) gitalyServerDeps {
+		deps.packObjectsLimiterUnauthenticated = limiter
+		return deps
+	}
+}
+
 // WithPackObjectsCache sets the PackObjectsCache that will be
 // used for gitaly services initialization.
 func WithPackObjectsCache(cache streamcache.Cache) GitalyServerOpt {
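
The limithandler changes that actually consume the new Unauthenticated config are outside this excerpt. As a rough, self-contained sketch of the selection rule the tests above pin down (limiterFor, authenticated, and the zero-value "not configured" check are assumptions of the sketch, not Gitaly's implementation):

package main

import (
	"context"
	"fmt"
)

// Minimal stand-ins for config.ConcurrencyLimits / config.Concurrency; the
// real types live in internal/gitaly/config.
type ConcurrencyLimits struct {
	MaxPerRepo   int
	MaxQueueSize int
}

type Concurrency struct {
	ConcurrencyLimits
	RPC             string
	Unauthenticated ConcurrencyLimits
}

type limiter string

type authKey struct{}

// authenticated is a hypothetical helper; in Gitaly the auth interceptor
// records this information on the request context (see runServerWithAuth).
func authenticated(ctx context.Context) bool {
	ok, _ := ctx.Value(authKey{}).(bool)
	return ok
}

// limiterFor sketches the rule the tests exercise: unauthenticated calls use
// the dedicated limiter when one is configured, and otherwise fall back to
// the shared limiter. Treating the zero value as "not configured" is an
// assumption of this sketch.
func limiterFor(ctx context.Context, cfg Concurrency, shared, unauth limiter) limiter {
	if !authenticated(ctx) && cfg.Unauthenticated != (ConcurrencyLimits{}) {
		return unauth
	}
	return shared
}

func main() {
	cfg := Concurrency{
		RPC:               "/grpc.testing.TestService/UnaryCall",
		ConcurrencyLimits: ConcurrencyLimits{MaxPerRepo: 2, MaxQueueSize: 10},
		Unauthenticated:   ConcurrencyLimits{MaxPerRepo: 1, MaxQueueSize: 10},
	}

	anon := context.Background()
	authed := context.WithValue(anon, authKey{}, true)

	fmt.Println(limiterFor(authed, cfg, "shared", "unauthenticated")) // shared
	fmt.Println(limiterFor(anon, cfg, "shared", "unauthenticated"))   // unauthenticated
}

With the Unauthenticated block removed from cfg, both calls would print "shared", which is the fallback behaviour the second unary subtest asserts.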