From 49bd5e3958457950c5fe169f81c29915e0ebf262 Mon Sep 17 00:00:00 2001 From: Ash McKenzie Date: Tue, 20 Jun 2023 14:53:52 +1000 Subject: [PATCH 1/8] Pass ctx where needed --- cmd/check/main.go | 2 +- .../main.go | 2 +- .../main.go | 2 +- cmd/gitlab-shell/main.go | 2 +- .../command/authorizedkeys/authorized_keys.go | 8 ++--- .../authorizedkeys/authorized_keys_test.go | 2 +- .../authorized_principals.go | 6 ++-- .../authorized_principals_test.go | 2 +- internal/command/command.go | 2 +- internal/command/discover/discover.go | 6 ++-- internal/command/discover/discover_test.go | 4 +-- internal/command/healthcheck/healthcheck.go | 8 ++--- .../command/healthcheck/healthcheck_test.go | 6 ++-- .../lfsauthenticate/lfsauthenticate.go | 12 ++++---- .../lfsauthenticate/lfsauthenticate_test.go | 4 +-- .../personalaccesstoken.go | 9 +++--- .../personalaccesstoken_test.go | 2 +- .../command/receivepack/gitalycall_test.go | 2 +- internal/command/receivepack/receivepack.go | 12 ++++---- .../command/receivepack/receivepack_test.go | 5 ++-- .../twofactorrecover/twofactorrecover.go | 4 +-- .../twofactorrecover/twofactorrecover_test.go | 2 +- .../twofactorverify/twofactorverify.go | 6 ++-- .../twofactorverify/twofactorverify_test.go | 7 +++-- .../command/uploadarchive/gitalycall_test.go | 2 +- .../command/uploadarchive/uploadarchive.go | 8 ++--- .../uploadarchive/uploadarchive_test.go | 2 +- .../command/uploadpack/gitalycall_test.go | 2 +- internal/command/uploadpack/uploadpack.go | 10 +++---- .../command/uploadpack/uploadpack_test.go | 2 +- internal/sshd/connection.go | 8 +++-- internal/sshd/connection_test.go | 29 +++++++++++-------- internal/sshd/session.go | 28 +++++++++--------- internal/sshd/session_test.go | 4 +-- internal/sshd/sshd.go | 2 +- 35 files changed, 114 insertions(+), 100 deletions(-) diff --git a/cmd/check/main.go b/cmd/check/main.go index 578dfdf84..76f217b78 100644 --- a/cmd/check/main.go +++ b/cmd/check/main.go @@ -43,7 +43,7 @@ func main() { ctx, finished := command.Setup(executable.Name, config) defer finished() - if err = cmd.Execute(ctx); err != nil { + if ctx, err = cmd.Execute(ctx); err != nil { fmt.Fprintf(readWriter.ErrOut, "%v\n", err) os.Exit(1) } diff --git a/cmd/gitlab-shell-authorized-keys-check/main.go b/cmd/gitlab-shell-authorized-keys-check/main.go index e272e68b0..707d4cc68 100644 --- a/cmd/gitlab-shell-authorized-keys-check/main.go +++ b/cmd/gitlab-shell-authorized-keys-check/main.go @@ -46,7 +46,7 @@ func main() { ctx, finished := command.Setup(executable.Name, config) defer finished() - if err = cmd.Execute(ctx); err != nil { + if ctx, err = cmd.Execute(ctx); err != nil { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) os.Exit(1) } diff --git a/cmd/gitlab-shell-authorized-principals-check/main.go b/cmd/gitlab-shell-authorized-principals-check/main.go index 10d3daab0..09380fb35 100644 --- a/cmd/gitlab-shell-authorized-principals-check/main.go +++ b/cmd/gitlab-shell-authorized-principals-check/main.go @@ -46,7 +46,7 @@ func main() { ctx, finished := command.Setup(executable.Name, config) defer finished() - if err = cmd.Execute(ctx); err != nil { + if ctx, err = cmd.Execute(ctx); err != nil { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) os.Exit(1) } diff --git a/cmd/gitlab-shell/main.go b/cmd/gitlab-shell/main.go index b789b774d..679d4593e 100644 --- a/cmd/gitlab-shell/main.go +++ b/cmd/gitlab-shell/main.go @@ -76,7 +76,7 @@ func main() { ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("gitlab-shell: main: executing command") fips.Check() - if err := cmd.Execute(ctx); err != nil { + if _, err := cmd.Execute(ctx); err != nil { ctxlog.WithError(err).Warn("gitlab-shell: main: command execution failed") if grpcstatus.Convert(err).Code() != grpccodes.Internal { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go index 46ab5c467..92547cd9d 100644 --- a/internal/command/authorizedkeys/authorized_keys.go +++ b/internal/command/authorizedkeys/authorized_keys.go @@ -18,21 +18,21 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { // Do and return nothing when the expected and actual user don't match. // This can happen when the user in sshd_config doesn't match the user // trying to login. When nothing is printed, the user will be denied access. if c.Args.ExpectedUser != c.Args.ActualUser { // TODO: Log this event once we have a consistent way to log in Go. // See https://gitlab.com/gitlab-org/gitlab-shell/issues/192 for more info. - return nil + return ctx, nil } if err := c.printKeyLine(ctx); err != nil { - return err + return ctx, err } - return nil + return ctx, nil } func (c *Command) printKeyLine(ctx context.Context) error { diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go index b91e460f4..e54634099 100644 --- a/internal/command/authorizedkeys/authorized_keys_test.go +++ b/internal/command/authorizedkeys/authorized_keys_test.go @@ -84,7 +84,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go index a7cfe1a19..38267f454 100644 --- a/internal/command/authorizedprincipals/authorized_principals.go +++ b/internal/command/authorizedprincipals/authorized_principals.go @@ -16,12 +16,12 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { if err := c.printPrincipalLines(); err != nil { - return err + return ctx, err } - return nil + return ctx, nil } func (c *Command) printPrincipalLines() error { diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go index ba4d066ff..b84d8e88a 100644 --- a/internal/command/authorizedprincipals/authorized_principals_test.go +++ b/internal/command/authorizedprincipals/authorized_principals_test.go @@ -42,7 +42,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/command.go b/internal/command/command.go index d9706b5d0..552678f7e 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -9,7 +9,7 @@ import ( ) type Command interface { - Execute(ctx context.Context) error + Execute(ctx context.Context) (context.Context, error) } // Setup() initializes tracing from the configuration file and generates a diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go index 2f81a7844..e0a98ebdc 100644 --- a/internal/command/discover/discover.go +++ b/internal/command/discover/discover.go @@ -16,10 +16,10 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { response, err := c.getUserInfo(ctx) if err != nil { - return fmt.Errorf("Failed to get username: %v", err) + return ctx, fmt.Errorf("Failed to get username: %v", err) } if response.IsAnonymous() { @@ -28,7 +28,7 @@ func (c *Command) Execute(ctx context.Context) error { fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) } - return nil + return ctx, nil } func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) { diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go index df9ca47c3..2683cee65 100644 --- a/internal/command/discover/discover_test.go +++ b/internal/command/discover/discover_test.go @@ -81,7 +81,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) @@ -123,7 +123,7 @@ func TestFailingExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, tc.expectedError) diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go index e80fe2a1c..206a97f7c 100644 --- a/internal/command/healthcheck/healthcheck.go +++ b/internal/command/healthcheck/healthcheck.go @@ -19,20 +19,20 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { response, err := c.runCheck(ctx) if err != nil { - return fmt.Errorf("%v: FAILED - %v", apiMessage, err) + return ctx, fmt.Errorf("%v: FAILED - %v", apiMessage, err) } fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", apiMessage) if !response.Redis { - return fmt.Errorf("%v: FAILED", redisMessage) + return ctx, fmt.Errorf("%v: FAILED", redisMessage) } fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", redisMessage) - return nil + return ctx, nil } func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) { diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go index 12a8444c4..d1c2a6ba5 100644 --- a/internal/command/healthcheck/healthcheck_test.go +++ b/internal/command/healthcheck/healthcheck_test.go @@ -53,7 +53,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String()) @@ -68,7 +68,7 @@ func TestFailingRedisExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Error(t, err, "Redis available via internal API: FAILED") require.Equal(t, "Internal API available: OK\n", buffer.String()) } @@ -82,7 +82,7 @@ func TestFailingAPIExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, "Internal API available: FAILED - Internal API unreachable") } diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go index a06ac93a4..9ca97a9da 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate.go +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -37,10 +37,10 @@ type Payload struct { ExpiresIn int `json:"expires_in,omitempty"` } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { args := c.Args.SshArgs if len(args) < 3 { - return disallowedcommand.Error + return ctx, disallowedcommand.Error } // e.g. git-lfs-authenticate user/repo.git download @@ -49,12 +49,12 @@ func (c *Command) Execute(ctx context.Context) error { action, err := actionFromOperation(operation) if err != nil { - return err + return ctx, err } accessResponse, err := c.verifyAccess(ctx, action, repo) if err != nil { - return err + return ctx, err } payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId) @@ -65,12 +65,12 @@ func (c *Command) Execute(ctx context.Context) error { log.Fields{"operation": operation, "repo": repo, "user_id": accessResponse.UserId}, ).WithError(err).Debug("lfsauthenticate: execute: LFS authentication failed") - return nil + return ctx, nil } fmt.Fprintf(c.ReadWriter.Out, "%s\n", payload) - return nil + return ctx, nil } func actionFromOperation(operation string) (commandargs.CommandType, error) { diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go index 709608c3c..10e5b0dbf 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate_test.go +++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go @@ -54,7 +54,7 @@ func TestFailedRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Error(t, err) require.Equal(t, tc.expectedOutput, err.Error()) @@ -145,7 +145,7 @@ func TestLfsAuthenticateRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go index fcf7dda1c..c4f3deec3 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken.go +++ b/internal/command/personalaccesstoken/personalaccesstoken.go @@ -34,10 +34,10 @@ type tokenArgs struct { ExpiresDate string // Calculated, a TTL is passed from command-line. } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { err := c.parseTokenArgs() if err != nil { - return err + return ctx, err } log.WithContextFields(ctx, log.Fields{ @@ -46,13 +46,14 @@ func (c *Command) Execute(ctx context.Context) error { response, err := c.getPersonalAccessToken(ctx) if err != nil { - return err + return ctx, err } fmt.Fprint(c.ReadWriter.Out, "Token: "+response.Token+"\n") fmt.Fprint(c.ReadWriter.Out, "Scopes: "+strings.Join(response.Scopes, ",")+"\n") fmt.Fprint(c.ReadWriter.Out, "Expires: "+response.ExpiresAt+"\n") - return nil + + return ctx, nil } func (c *Command) parseTokenArgs() error { diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go index c3434ce4f..711f7dac2 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken_test.go +++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go @@ -167,7 +167,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) if tc.expectedError == "" { require.NoError(t, err) diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go index 9f70189fb..c9321821e 100644 --- a/internal/command/receivepack/gitalycall_test.go +++ b/internal/command/receivepack/gitalycall_test.go @@ -72,7 +72,7 @@ func TestReceivePack(t *testing.T) { ctx := correlation.ContextWithCorrelation(context.Background(), "a-correlation-id") ctx = correlation.ContextWithClientName(ctx, "gitlab-shell-tests") - err := cmd.Execute(ctx) + _, err := cmd.Execute(ctx) require.NoError(t, err) if tc.username != "" { diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go index c9ef7cdc3..63301d1f1 100644 --- a/internal/command/receivepack/receivepack.go +++ b/internal/command/receivepack/receivepack.go @@ -18,16 +18,16 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { args := c.Args.SshArgs if len(args) != 2 { - return disallowedcommand.Error + return ctx, disallowedcommand.Error } repo := args[1] response, err := c.verifyAccess(ctx, repo) if err != nil { - return err + return ctx, err } if response.IsCustomAction() { @@ -42,7 +42,7 @@ func (c *Command) Execute(ctx context.Context) error { Response: response, } - return cmd.Execute(ctx) + return ctx, cmd.Execute(ctx) } customAction := customaction.Command{ @@ -50,10 +50,10 @@ func (c *Command) Execute(ctx context.Context) error { ReadWriter: c.ReadWriter, EOFSent: true, } - return customAction.Execute(ctx, response) + return ctx, customAction.Execute(ctx, response) } - return c.performGitalyCall(ctx, response) + return ctx, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go index 17622bb1a..c987daae5 100644 --- a/internal/command/receivepack/receivepack_test.go +++ b/internal/command/receivepack/receivepack_test.go @@ -18,14 +18,15 @@ func TestForbiddenAccess(t *testing.T) { requests := requesthandlers.BuildDisallowedByApiHandlers(t) cmd, _ := setup(t, "disallowed", requests) - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } func TestCustomReceivePack(t *testing.T) { cmd, output := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t)) - require.NoError(t, cmd.Execute(context.Background())) + _, err := cmd.Execute(context.Background()) + require.NoError(t, err) require.Equal(t, "customoutput", output.String()) } diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go index 7496396ab..8828c71d0 100644 --- a/internal/command/twofactorrecover/twofactorrecover.go +++ b/internal/command/twofactorrecover/twofactorrecover.go @@ -22,7 +22,7 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { ctxlog := log.ContextLogger(ctx) ctxlog.Debug("twofactorrecover: execute: Waiting for user input") @@ -34,7 +34,7 @@ func (c *Command) Execute(ctx context.Context) error { fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") } - return nil + return ctx, nil } func (c *Command) getUserAnswer(ctx context.Context) string { diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go index 7e20a0652..8f86777ec 100644 --- a/internal/command/twofactorrecover/twofactorrecover_test.go +++ b/internal/command/twofactorrecover/twofactorrecover_test.go @@ -132,7 +132,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/twofactorverify/twofactorverify.go b/internal/command/twofactorverify/twofactorverify.go index 4041de0c8..cbe68e66c 100644 --- a/internal/command/twofactorverify/twofactorverify.go +++ b/internal/command/twofactorverify/twofactorverify.go @@ -25,10 +25,10 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { client, err := twofactorverify.NewClient(c.Config) if err != nil { - return err + return ctx, err } ctx, cancel := context.WithTimeout(ctx, timeout) @@ -67,7 +67,7 @@ func (c *Command) Execute(ctx context.Context) error { log.WithContextFields(ctx, log.Fields{"message": message}).Info("Two factor verify command finished") fmt.Fprintf(c.ReadWriter.Out, "\n%v\n", message) - return nil + return ctx, nil } func (c *Command) getOTP(ctx context.Context) (string, error) { diff --git a/internal/command/twofactorverify/twofactorverify_test.go b/internal/command/twofactorverify/twofactorverify_test.go index 213c02532..4629be937 100644 --- a/internal/command/twofactorverify/twofactorverify_test.go +++ b/internal/command/twofactorverify/twofactorverify_test.go @@ -160,7 +160,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, prompt+"\n"+tc.expectedOutput, output.String()) @@ -183,7 +183,10 @@ func TestCanceledContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error) - go func() { errCh <- cmd.Execute(ctx) }() + go func() { + _, err := cmd.Execute(ctx) + errCh <- err + }() cancel() require.NoError(t, <-errCh) diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go index 0479e3054..8004e20d1 100644 --- a/internal/command/uploadarchive/gitalycall_test.go +++ b/internal/command/uploadarchive/gitalycall_test.go @@ -56,7 +56,7 @@ func TestUploadArchive(t *testing.T) { ctx := correlation.ContextWithCorrelation(context.Background(), "a-correlation-id") ctx = correlation.ContextWithClientName(ctx, "gitlab-shell-tests") - err := cmd.Execute(ctx) + _, err := cmd.Execute(ctx) require.NoError(t, err) require.Equal(t, "UploadArchive: "+repo, output.String()) diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go index dcdd14407..3974b1752 100644 --- a/internal/command/uploadarchive/uploadarchive.go +++ b/internal/command/uploadarchive/uploadarchive.go @@ -16,19 +16,19 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { args := c.Args.SshArgs if len(args) != 2 { - return disallowedcommand.Error + return ctx, disallowedcommand.Error } repo := args[1] response, err := c.verifyAccess(ctx, repo) if err != nil { - return err + return ctx, err } - return c.performGitalyCall(ctx, response) + return ctx, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go index 86a40315e..506a74482 100644 --- a/internal/command/uploadarchive/uploadarchive_test.go +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -26,6 +26,6 @@ func TestForbiddenAccess(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go index 874d12e6a..dfa189d58 100644 --- a/internal/command/uploadpack/gitalycall_test.go +++ b/internal/command/uploadpack/gitalycall_test.go @@ -57,7 +57,7 @@ func TestUploadPack(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, } - err := cmd.Execute(ctx) + _, err := cmd.Execute(ctx) require.NoError(t, err) require.Equal(t, "SSHUploadPackWithSidechannel: "+repo, output.String()) diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go index 725093a11..35b818709 100644 --- a/internal/command/uploadpack/uploadpack.go +++ b/internal/command/uploadpack/uploadpack.go @@ -17,16 +17,16 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { args := c.Args.SshArgs if len(args) != 2 { - return disallowedcommand.Error + return ctx, disallowedcommand.Error } repo := args[1] response, err := c.verifyAccess(ctx, repo) if err != nil { - return err + return ctx, err } if response.IsCustomAction() { @@ -35,10 +35,10 @@ func (c *Command) Execute(ctx context.Context) error { ReadWriter: c.ReadWriter, EOFSent: false, } - return customAction.Execute(ctx, response) + return ctx, customAction.Execute(ctx, response) } - return c.performGitalyCall(ctx, response) + return ctx, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go index 5456cae14..b27c8b839 100644 --- a/internal/command/uploadpack/uploadpack_test.go +++ b/internal/command/uploadpack/uploadpack_test.go @@ -26,6 +26,6 @@ func TestForbiddenAccess(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index e691d331d..e5483981c 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -36,7 +36,7 @@ type connection struct { remoteAddr string } -type channelHandler func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error +type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) func newConnection(cfg *config.Config, nconn net.Conn) *connection { maxSessions := cfg.Server.ConcurrentSessionsLimit @@ -94,7 +94,7 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi return sconn, chans, err } -func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) { +func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) context.Context { ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { @@ -134,7 +134,7 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, }() metrics.SliSshdSessionsTotal.Inc() - err := handler(sconn, channel, requests) + ctx, err = handler(ctx, sconn, channel, requests) if err != nil { c.trackError(ctxlog, err) } @@ -148,6 +148,8 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, ctx, cancel := context.WithTimeout(ctx, EOFTimeout) defer cancel() c.concurrentSessions.Acquire(ctx, c.maxSessions) + + return ctx } func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) { diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index 5438935f0..e9aa9eda6 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -97,7 +97,7 @@ func TestPanicDuringSessionIsRecovered(t *testing.T) { numSessions := 0 require.NotPanics(t, func() { - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { numSessions += 1 close(chans) panic("This is a panic") @@ -135,9 +135,9 @@ func TestTooManySessions(t *testing.T) { defer cancel() go func() { - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { <-ctx.Done() // Keep the accepted channel open until the end of the test - return nil + return ctx, nil }) }() @@ -148,12 +148,13 @@ func TestTooManySessions(t *testing.T) { func TestAcceptSessionSucceeds(t *testing.T) { newChannel := &fakeNewChannel{channelType: "session"} conn, chans := setup(1, newChannel) + ctx := context.Background() channelHandled := false - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { channelHandled = true close(chans) - return nil + return ctx, nil }) require.True(t, channelHandled) @@ -166,12 +167,13 @@ func TestAcceptSessionFails(t *testing.T) { acceptErr := errors.New("some failure") newChannel := &fakeNewChannel{channelType: "session", acceptCh: acceptCh, acceptErr: acceptErr} conn, chans := setup(1, newChannel) + ctx := context.Background() channelHandled := false go func() { - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { channelHandled = true - return nil + return ctx, nil }) }() @@ -203,11 +205,12 @@ func TestSessionsMetrics(t *testing.T) { initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal) newChannel := &fakeNewChannel{channelType: "session"} - conn, chans := setup(1, newChannel) - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + ctx := context.Background() + + conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { close(chans) - return errors.New("custom error") + return ctx, errors.New("custom error") }) eventuallyInDelta(t, initialSessionsTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1) @@ -226,9 +229,11 @@ func TestSessionsMetrics(t *testing.T) { t.Run(ignoredError.desc, func(t *testing.T) { conn, chans = setup(1, newChannel) ignored := ignoredError.err - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + ctx := context.Background() + + conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { close(chans) - return ignored + return ctx, ignored }) eventuallyInDelta(t, initialSessionsTotal+2+float64(i), testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1) diff --git a/internal/sshd/session.go b/internal/sshd/session.go index 3394b2a55..930b8ade6 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -49,7 +49,7 @@ type exitStatusReq struct { ExitStatus uint32 } -func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error { +func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (context.Context, error) { ctxlog := log.ContextLogger(ctx) ctxlog.Debug("session: handle: entering request loop") @@ -70,13 +70,13 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) erro case "exec": // The command has been executed as `ssh user@host command` or `exec` channel has been used // in the app implementation - shouldContinue, err = s.handleExec(ctx, req) + ctx, shouldContinue, err = s.handleExec(ctx, req) case "shell": // The command has been entered into the shell or `shell` channel has been used // in the app implementation shouldContinue = false var status uint32 - status, err = s.handleShell(ctx, req) + ctx, status, err = s.handleShell(ctx, req) s.exit(ctx, status) default: // Ignore unknown requests but don't terminate the session @@ -99,7 +99,7 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) erro ctxlog.Debug("session: handle: exiting request loop") - return err + return ctx, err } func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) { @@ -132,21 +132,22 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) return true, nil } -func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) { +func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Context, bool, error) { var execRequest execRequest + if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil { - return false, err + return ctx, false, err } s.execCmd = execRequest.Command - status, err := s.handleShell(ctx, req) + ctx, status, err := s.handleShell(ctx, req) s.exit(ctx, status) - return false, err + return ctx, false, err } -func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) { +func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Context, uint32, error) { ctxlog := log.ContextLogger(ctx) if req.WantReply { @@ -183,7 +184,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, er s.toStderr(ctx, "ERROR: Failed to parse command: %v\n", err.Error()) } - return 128, err + return ctx, 128, err } cmdName := reflect.TypeOf(cmd).String() @@ -194,18 +195,19 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, er }).Info("session: handleShell: executing command") metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) - if err := cmd.Execute(ctx); err != nil { + ctx, err = cmd.Execute(ctx) + if err != nil { grpcStatus := grpcstatus.Convert(err) if grpcStatus.Code() != grpccodes.Internal { s.toStderr(ctx, "ERROR: %v\n", grpcStatus.Message()) } - return 1, err + return ctx, 1, err } ctxlog.Info("session: handleShell: command executed successfully") - return 0, nil + return ctx, 0, nil } func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go index d1bff7e82..d0f4f51f1 100644 --- a/internal/sshd/session_test.go +++ b/internal/sshd/session_test.go @@ -146,7 +146,7 @@ func TestHandleExec(t *testing.T) { r := &ssh.Request{Payload: tc.payload} s.channel = f - shouldContinue, err := s.handleExec(context.Background(), r) + _, shouldContinue, err := s.handleExec(context.Background(), r) require.Equal(t, tc.expectedErr, err) require.Equal(t, false, shouldContinue) @@ -210,7 +210,7 @@ func TestHandleShell(t *testing.T) { } r := &ssh.Request{} - exitCode, err := s.handleShell(context.Background(), r) + _, exitCode, err := s.handleShell(context.Background(), r) if tc.expectedErrString != "" { require.Equal(t, tc.expectedErrString, err.Error()) diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index f26458228..168dd2201 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -193,7 +193,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { started := time.Now() conn := newConnection(s.Config, nconn) - conn.handle(ctx, s.serverConfig.get(ctx), func(sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error { + conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) { session := &session{ cfg: s.Config, channel: channel, -- GitLab From c44e4e4ec45d204713cb175b7e4250f7c56bed19 Mon Sep 17 00:00:00 2001 From: Ash McKenzie Date: Wed, 21 Jun 2023 14:21:55 +1000 Subject: [PATCH 2/8] New config.MetaData struct and NewMetaData() --- internal/config/config.go | 27 +++++++++++++++++++++++++ internal/config/config_test.go | 37 ++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/internal/config/config.go b/internal/config/config.go index c34f1d4ea..67cac1267 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,7 @@ import ( "os" "path" "path/filepath" + "strings" "sync" "time" @@ -82,6 +83,12 @@ type Config struct { GitalyClient gitaly.Client } +type MetaData struct { + Username string `json:"username"` + Project string `json:"project,omitempty"` + RootNamespace string `json:"root_namespace,omitempty"` +} + // The defaults to apply before parsing the config file(s). var ( DefaultConfig = Config{ @@ -110,6 +117,26 @@ var ( } ) +func NewMetaData(project, username string) MetaData { + rootNameSpace := "" + + if len(project) > 0 { + splitFn := func(c rune) bool { + return c == '/' + } + m := strings.FieldsFunc(project, splitFn) + if len(m) > 0 { + rootNameSpace = m[0] + } + } + + return MetaData{ + Username: username, + Project: project, + RootNamespace: rootNameSpace, + } +} + func (d *YamlDuration) UnmarshalYAML(unmarshal func(interface{}) error) error { var intDuration int if err := unmarshal(&intDuration); err != nil { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index d88a0cd91..bd61691fd 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -98,3 +98,40 @@ func TestYAMLDuration(t *testing.T) { }) } } + +func TestNewMetaData(t *testing.T) { + testCases := []struct { + desc string + project string + username string + expectedRootNamespace string + }{ + { + desc: "Project under single namespace", + project: "flightjs/Flight", + username: "@alex-doe", + expectedRootNamespace: "flightjs", + }, + { + desc: "Project under single odd namespace", + project: "flightjs///Flight", + username: "@alex-doe", + expectedRootNamespace: "flightjs", + }, + { + desc: "Project under deeper namespace", + project: "flightjs/one/Flight", + username: "@alex-doe", + expectedRootNamespace: "flightjs", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + metaData := NewMetaData(tc.project, tc.username) + require.Equal(t, tc.project, metaData.Project) + require.Equal(t, tc.username, metaData.Username) + require.Equal(t, tc.expectedRootNamespace, metaData.RootNamespace) + }) + } +} -- GitLab From 9b82224134280d401c850d619cb39f3959c26c37 Mon Sep 17 00:00:00 2001 From: Ash McKenzie Date: Tue, 20 Jun 2023 15:30:09 +1000 Subject: [PATCH 3/8] Log 'access: finish' line with metadata --- internal/command/discover/discover.go | 7 +++- internal/command/discover/discover_test.go | 39 +++++++++++-------- internal/command/receivepack/receivepack.go | 11 ++++-- .../command/receivepack/receivepack_test.go | 15 +++++++ .../command/uploadarchive/uploadarchive.go | 8 +++- .../uploadarchive/uploadarchive_test.go | 31 +++++++++++++-- internal/command/uploadpack/uploadpack.go | 10 ++++- .../command/uploadpack/uploadpack_test.go | 31 +++++++++++++-- internal/sshd/connection.go | 14 ++++--- internal/sshd/connection_test.go | 3 +- internal/sshd/session.go | 17 ++++---- internal/sshd/sshd.go | 14 ++++++- internal/sshd/sshd_test.go | 19 +++++++++ .../requesthandlers/requesthandlers.go | 1 + 14 files changed, 172 insertions(+), 48 deletions(-) diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go index e0a98ebdc..cf3b05fcf 100644 --- a/internal/command/discover/discover.go +++ b/internal/command/discover/discover.go @@ -22,13 +22,18 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, fmt.Errorf("Failed to get username: %v", err) } + metaData := config.MetaData{} if response.IsAnonymous() { + metaData.Username = "Anonymous" fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, Anonymous!\n") } else { + metaData.Username = response.Username fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) } - return ctx, nil + ctxWithMetaData := context.WithValue(ctx, "metaData", metaData) + + return ctxWithMetaData, nil } func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) { diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go index 2683cee65..f8b54ce1f 100644 --- a/internal/command/discover/discover_test.go +++ b/internal/command/discover/discover_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "testing" "github.com/stretchr/testify/require" @@ -46,29 +47,29 @@ func TestExecute(t *testing.T) { url := testserver.StartSocketHttpServer(t, requests) testCases := []struct { - desc string - arguments *commandargs.Shell - expectedOutput string + desc string + arguments *commandargs.Shell + expectedUsername string }{ { - desc: "With a known username", - arguments: &commandargs.Shell{GitlabUsername: "alex-doe"}, - expectedOutput: "Welcome to GitLab, @alex-doe!\n", + desc: "With a known username", + arguments: &commandargs.Shell{GitlabUsername: "alex-doe"}, + expectedUsername: "@alex-doe", }, { - desc: "With a known key id", - arguments: &commandargs.Shell{GitlabKeyId: "1"}, - expectedOutput: "Welcome to GitLab, @alex-doe!\n", + desc: "With a known key id", + arguments: &commandargs.Shell{GitlabKeyId: "1"}, + expectedUsername: "@alex-doe", }, { - desc: "With an unknown key", - arguments: &commandargs.Shell{GitlabKeyId: "-1"}, - expectedOutput: "Welcome to GitLab, Anonymous!\n", + desc: "With an unknown key", + arguments: &commandargs.Shell{GitlabKeyId: "-1"}, + expectedUsername: "Anonymous", }, { - desc: "With an unknown username", - arguments: &commandargs.Shell{GitlabUsername: "unknown"}, - expectedOutput: "Welcome to GitLab, Anonymous!\n", + desc: "With an unknown username", + arguments: &commandargs.Shell{GitlabUsername: "unknown"}, + expectedUsername: "Anonymous", }, } @@ -81,10 +82,14 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - _, err := cmd.Execute(context.Background()) + ctxWithMetaData, err := cmd.Execute(context.Background()) + + expectedOutput := fmt.Sprintf("Welcome to GitLab, %s!\n", tc.expectedUsername) + expectedUsername := strings.TrimLeft(tc.expectedUsername, "@") require.NoError(t, err) - require.Equal(t, tc.expectedOutput, buffer.String()) + require.Equal(t, expectedOutput, buffer.String()) + require.Equal(t, expectedUsername, ctxWithMetaData.Value("metaData").(config.MetaData).Username) }) } } diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go index 63301d1f1..8bdb6fb02 100644 --- a/internal/command/receivepack/receivepack.go +++ b/internal/command/receivepack/receivepack.go @@ -30,6 +30,11 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } + ctxWithMetaData := context.WithValue(ctx, "metaData", config.NewMetaData( + response.Gitaly.Repo.GlProjectPath, + response.Username, + )) + if response.IsCustomAction() { // When `geo_proxy_direct_to_primary` feature flag is enabled, a Git over HTTP direct request // to primary repo is performed instead of proxying the request through Gitlab Rails. @@ -42,7 +47,7 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { Response: response, } - return ctx, cmd.Execute(ctx) + return ctxWithMetaData, cmd.Execute(ctx) } customAction := customaction.Command{ @@ -50,10 +55,10 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { ReadWriter: c.ReadWriter, EOFSent: true, } - return ctx, customAction.Execute(ctx, response) + return ctxWithMetaData, customAction.Execute(ctx, response) } - return ctx, c.performGitalyCall(ctx, response) + return ctxWithMetaData, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go index c987daae5..ddd11f03d 100644 --- a/internal/command/receivepack/receivepack_test.go +++ b/internal/command/receivepack/receivepack_test.go @@ -14,6 +14,21 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper/requesthandlers" ) +func TestAllowedAccess(t *testing.T) { + gitalyAddress, _ := testserver.StartGitalyServer(t, "unix") + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + cmd, _ := setup(t, "1", requests) + cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) + + ctxWithMetaData, err := cmd.Execute(context.Background()) + + require.NoError(t, err) + metaData := ctxWithMetaData.Value("metaData").(config.MetaData) + require.Equal(t, "alex-doe", metaData.Username) + require.Equal(t, "group/project-path", metaData.Project) + require.Equal(t, "group", metaData.RootNamespace) +} + func TestForbiddenAccess(t *testing.T) { requests := requesthandlers.BuildDisallowedByApiHandlers(t) cmd, _ := setup(t, "disallowed", requests) diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go index 3974b1752..5938ba2ec 100644 --- a/internal/command/uploadarchive/uploadarchive.go +++ b/internal/command/uploadarchive/uploadarchive.go @@ -28,7 +28,13 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } - return ctx, c.performGitalyCall(ctx, response) + metaData := config.NewMetaData( + response.Gitaly.Repo.GlProjectPath, + response.Username, + ) + ctxWithMetaData := context.WithValue(ctx, "metaData", metaData) + + return ctxWithMetaData, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go index 506a74482..c63069671 100644 --- a/internal/command/uploadarchive/uploadarchive_test.go +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -14,18 +14,41 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper/requesthandlers" ) +func TestAllowedAccess(t *testing.T) { + gitalyAddress, _ := testserver.StartGitalyServer(t, "unix") + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + cmd, _ := setup(t, "1", requests) + cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) + + ctxWithMetaData, err := cmd.Execute(context.Background()) + + require.NoError(t, err) + metaData := ctxWithMetaData.Value("metaData").(config.MetaData) + require.Equal(t, "alex-doe", metaData.Username) + require.Equal(t, "group/project-path", metaData.Project) + require.Equal(t, "group", metaData.RootNamespace) +} + func TestForbiddenAccess(t *testing.T) { requests := requesthandlers.BuildDisallowedByApiHandlers(t) + + cmd, _ := setup(t, "disallowed", requests) + + _, err := cmd.Execute(context.Background()) + require.Equal(t, "Disallowed by API call", err.Error()) +} + +func setup(t *testing.T, keyId string, requests []testserver.TestRequestHandler) (*Command, *bytes.Buffer) { url := testserver.StartHttpServer(t, requests) output := &bytes.Buffer{} + input := bytes.NewBufferString("input") cmd := &Command{ Config: &config.Config{GitlabUrl: url}, - Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-upload-archive", "group/repo"}}, - ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, + Args: &commandargs.Shell{GitlabKeyId: keyId, SshArgs: []string{"git-upload-archive", "group/repo"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, } - _, err := cmd.Execute(context.Background()) - require.Equal(t, "Disallowed by API call", err.Error()) + return cmd, output } diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go index 35b818709..1015245bc 100644 --- a/internal/command/uploadpack/uploadpack.go +++ b/internal/command/uploadpack/uploadpack.go @@ -29,16 +29,22 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } + metaData := config.NewMetaData( + response.Gitaly.Repo.GlProjectPath, + response.Username, + ) + ctxWithMetaData := context.WithValue(ctx, "metaData", metaData) + if response.IsCustomAction() { customAction := customaction.Command{ Config: c.Config, ReadWriter: c.ReadWriter, EOFSent: false, } - return ctx, customAction.Execute(ctx, response) + return ctxWithMetaData, customAction.Execute(ctx, response) } - return ctx, c.performGitalyCall(ctx, response) + return ctxWithMetaData, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go index b27c8b839..50c966224 100644 --- a/internal/command/uploadpack/uploadpack_test.go +++ b/internal/command/uploadpack/uploadpack_test.go @@ -14,18 +14,41 @@ import ( "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper/requesthandlers" ) +func TestAllowedAccess(t *testing.T) { + gitalyAddress, _ := testserver.StartGitalyServer(t, "unix") + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + cmd, _ := setup(t, "1", requests) + cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) + + ctxWithMetaData, err := cmd.Execute(context.Background()) + + require.NoError(t, err) + metaData := ctxWithMetaData.Value("metaData").(config.MetaData) + require.Equal(t, "alex-doe", metaData.Username) + require.Equal(t, "group/project-path", metaData.Project) + require.Equal(t, "group", metaData.RootNamespace) +} + func TestForbiddenAccess(t *testing.T) { requests := requesthandlers.BuildDisallowedByApiHandlers(t) + + cmd, _ := setup(t, "disallowed", requests) + + _, err := cmd.Execute(context.Background()) + require.Equal(t, "Disallowed by API call", err.Error()) +} + +func setup(t *testing.T, keyId string, requests []testserver.TestRequestHandler) (*Command, *bytes.Buffer) { url := testserver.StartHttpServer(t, requests) output := &bytes.Buffer{} + input := bytes.NewBufferString("input") cmd := &Command{ Config: &config.Config{GitlabUrl: url}, - Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-upload-pack", "group/repo"}}, - ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, + Args: &commandargs.Shell{GitlabKeyId: keyId, SshArgs: []string{"git-upload-pack", "group/repo"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, } - _, err := cmd.Execute(context.Background()) - require.Equal(t, "Disallowed by API call", err.Error()) + return cmd, output } diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index e5483981c..202b7c6b4 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -50,12 +50,12 @@ func newConnection(cfg *config.Config, nconn net.Conn) *connection { } } -func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) { +func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) context.Context { log.WithContextFields(ctx, log.Fields{}).Info("server: handleConn: start") sconn, chans, err := c.initServerConn(ctx, srvCfg) if err != nil { - return + return ctx } if c.cfg.Server.ClientAliveInterval > 0 { @@ -64,10 +64,12 @@ func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handl go c.sendKeepAliveMsg(ctx, sconn, ticker) } - c.handleRequests(ctx, sconn, chans, handler) + ctxWithMetaData := c.handleRequests(ctx, sconn, chans, handler) reason := sconn.Wait() log.WithContextFields(ctx, log.Fields{"reason": reason}).Info("server: handleConn: done") + + return ctxWithMetaData } func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) { @@ -95,6 +97,7 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi } func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) context.Context { + ctxWithMetaData := ctx ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { @@ -134,7 +137,8 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, }() metrics.SliSshdSessionsTotal.Inc() - ctx, err = handler(ctx, sconn, channel, requests) + ctxWithMetaData, err = handler(ctx, sconn, channel, requests) + if err != nil { c.trackError(ctxlog, err) } @@ -149,7 +153,7 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, defer cancel() c.concurrentSessions.Acquire(ctx, c.maxSessions) - return ctx + return ctxWithMetaData } func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) { diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index e9aa9eda6..29ea81e9a 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -151,13 +151,14 @@ func TestAcceptSessionSucceeds(t *testing.T) { ctx := context.Background() channelHandled := false - conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { + returnedCtx := conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { channelHandled = true close(chans) return ctx, nil }) require.True(t, channelHandled) + require.NotNil(t, returnedCtx) } func TestAcceptSessionFails(t *testing.T) { diff --git a/internal/sshd/session.go b/internal/sshd/session.go index 930b8ade6..a1f22ee95 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -50,6 +50,7 @@ type exitStatusReq struct { } func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (context.Context, error) { + ctxWithMetaData := ctx ctxlog := log.ContextLogger(ctx) ctxlog.Debug("session: handle: entering request loop") @@ -70,13 +71,13 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (con case "exec": // The command has been executed as `ssh user@host command` or `exec` channel has been used // in the app implementation - ctx, shouldContinue, err = s.handleExec(ctx, req) + ctxWithMetaData, shouldContinue, err = s.handleExec(ctx, req) case "shell": // The command has been entered into the shell or `shell` channel has been used // in the app implementation shouldContinue = false var status uint32 - ctx, status, err = s.handleShell(ctx, req) + ctxWithMetaData, status, err = s.handleShell(ctx, req) s.exit(ctx, status) default: // Ignore unknown requests but don't terminate the session @@ -99,7 +100,7 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (con ctxlog.Debug("session: handle: exiting request loop") - return ctx, err + return ctxWithMetaData, err } func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) { @@ -141,10 +142,10 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Con s.execCmd = execRequest.Command - ctx, status, err := s.handleShell(ctx, req) - s.exit(ctx, status) + ctxWithMetaData, status, err := s.handleShell(ctx, req) + s.exit(ctxWithMetaData, status) - return ctx, false, err + return ctxWithMetaData, false, err } func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Context, uint32, error) { @@ -195,7 +196,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co }).Info("session: handleShell: executing command") metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) - ctx, err = cmd.Execute(ctx) + ctxWithMetaData, err := cmd.Execute(ctx) if err != nil { grpcStatus := grpcstatus.Convert(err) if grpcStatus.Code() != grpccodes.Internal { @@ -207,7 +208,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co ctxlog.Info("session: handleShell: command executed successfully") - return ctx, 0, nil + return ctxWithMetaData, 0, nil } func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 168dd2201..9fa0e6c49 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -193,7 +193,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { started := time.Now() conn := newConnection(s.Config, nconn) - conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) { + ctxWithMetaData := conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) { session := &session{ cfg: s.Config, channel: channel, @@ -206,7 +206,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { return session.handle(ctx, requests) }) - ctxlog.WithFields(log.Fields{"duration_s": time.Since(started).Seconds()}).Info("access: finish") + ctxlog.WithFields(log.Fields{"duration_s": time.Since(started).Seconds(), "meta": extractMetaDataFromContext(ctxWithMetaData)}).Info("access: finish") } func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { @@ -228,6 +228,16 @@ func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { } } +func extractMetaDataFromContext(ctx context.Context) config.MetaData { + metaData := config.MetaData{} + + if ctx.Value("metaData") != nil { + metaData = ctx.Value("metaData").(config.MetaData) + } + + return metaData +} + func staticProxyPolicy(policy proxyproto.Policy) proxyproto.PolicyFunc { return func(_ net.Addr) (proxyproto.Policy, error) { return policy, nil diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index c14a9f5f8..fc63cab46 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -348,6 +348,25 @@ func TestLoginGraceTime(t *testing.T) { verifyStatus(t, s, StatusClosed) } +func TestExtractMetaDataFromContext(t *testing.T) { + rootNameSpace := "flightjs" + project := fmt.Sprintf("%s/Flight", rootNameSpace) + username := "alex-doe" + ctxWithMetaData := context.WithValue(context.Background(), "metaData", config.NewMetaData(project, username)) + + metaData := extractMetaDataFromContext(ctxWithMetaData) + + require.Equal(t, config.MetaData{Project: project, Username: username, RootNamespace: rootNameSpace}, metaData) +} + +func TestExtractMetaDataFromContextWithoutMetaData(t *testing.T) { + ctxWithMetaData := context.Background() + + metaData := extractMetaDataFromContext(ctxWithMetaData) + + require.Equal(t, config.MetaData{}, metaData) +} + func setupServer(t *testing.T) *Server { t.Helper() diff --git a/internal/testhelper/requesthandlers/requesthandlers.go b/internal/testhelper/requesthandlers/requesthandlers.go index de1fdf942..a58b67ae0 100644 --- a/internal/testhelper/requesthandlers/requesthandlers.go +++ b/internal/testhelper/requesthandlers/requesthandlers.go @@ -38,6 +38,7 @@ func BuildAllowedWithGitalyHandlers(t *testing.T, gitalyAddress string) []testse "gl_id": "1", "gl_key_type": "key", "gl_key_id": 123, + "gl_username": "alex-doe", "gitaly": map[string]interface{}{ "repository": map[string]interface{}{ "storage_name": "storage_name", -- GitLab From 75343abf284ce678bab4f72d5ec3fddb6d6061c6 Mon Sep 17 00:00:00 2001 From: Ash McKenzie Date: Fri, 30 Jun 2023 13:06:34 +1000 Subject: [PATCH 4/8] Move MetaData into command package Also rename to Metadata --- internal/command/command.go | 27 ++++++++++++++ internal/command/command_test.go | 37 +++++++++++++++++++ internal/command/discover/discover.go | 7 ++-- internal/command/discover/discover_test.go | 5 ++- internal/command/receivepack/receivepack.go | 9 +++-- .../command/receivepack/receivepack_test.go | 5 ++- .../command/uploadarchive/uploadarchive.go | 7 ++-- .../uploadarchive/uploadarchive_test.go | 5 ++- internal/command/uploadpack/uploadpack.go | 9 +++-- .../command/uploadpack/uploadpack_test.go | 5 ++- internal/config/config.go | 27 -------------- internal/config/config_test.go | 37 ------------------- internal/sshd/connection.go | 10 ++--- internal/sshd/session.go | 18 ++++----- internal/sshd/sshd.go | 11 +++--- internal/sshd/sshd_test.go | 13 ++++--- 16 files changed, 121 insertions(+), 111 deletions(-) diff --git a/internal/command/command.go b/internal/command/command.go index 552678f7e..30b2bfcd2 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -2,6 +2,7 @@ package command import ( "context" + "strings" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/labkit/correlation" @@ -12,6 +13,12 @@ type Command interface { Execute(ctx context.Context) (context.Context, error) } +type LogMetadata struct { + Username string `json:"username"` + Project string `json:"project,omitempty"` + RootNamespace string `json:"root_namespace,omitempty"` +} + // Setup() initializes tracing from the configuration file and generates a // background context from which all other contexts in the process should derive // from, as it has a service name and initial correlation ID set. @@ -47,3 +54,23 @@ func Setup(serviceName string, config *config.Config) (context.Context, func()) closer.Close() } } + +func NewLogMetadata(project, username string) LogMetadata { + rootNameSpace := "" + + if len(project) > 0 { + splitFn := func(c rune) bool { + return c == '/' + } + m := strings.FieldsFunc(project, splitFn) + if len(m) > 0 { + rootNameSpace = m[0] + } + } + + return LogMetadata{ + Username: username, + Project: project, + RootNamespace: rootNameSpace, + } +} diff --git a/internal/command/command_test.go b/internal/command/command_test.go index c95e83884..c5dc3c91b 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -77,3 +77,40 @@ func addAdditionalEnv(envMap map[string]string) func() { } } + +func TestNewLogMetadata(t *testing.T) { + testCases := []struct { + desc string + project string + username string + expectedRootNamespace string + }{ + { + desc: "Project under single namespace", + project: "flightjs/Flight", + username: "@alex-doe", + expectedRootNamespace: "flightjs", + }, + { + desc: "Project under single odd namespace", + project: "flightjs///Flight", + username: "@alex-doe", + expectedRootNamespace: "flightjs", + }, + { + desc: "Project under deeper namespace", + project: "flightjs/one/Flight", + username: "@alex-doe", + expectedRootNamespace: "flightjs", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + metaData := NewLogMetadata(tc.project, tc.username) + require.Equal(t, tc.project, metaData.Project) + require.Equal(t, tc.username, metaData.Username) + require.Equal(t, tc.expectedRootNamespace, metaData.RootNamespace) + }) + } +} diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go index cf3b05fcf..228db468c 100644 --- a/internal/command/discover/discover.go +++ b/internal/command/discover/discover.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -22,7 +23,7 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, fmt.Errorf("Failed to get username: %v", err) } - metaData := config.MetaData{} + metaData := command.LogMetadata{} if response.IsAnonymous() { metaData.Username = "Anonymous" fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, Anonymous!\n") @@ -31,9 +32,9 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) } - ctxWithMetaData := context.WithValue(ctx, "metaData", metaData) + ctxWithLogMetadata := context.WithValue(ctx, "metaData", metaData) - return ctxWithMetaData, nil + return ctxWithLogMetadata, nil } func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) { diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go index f8b54ce1f..4b05acc76 100644 --- a/internal/command/discover/discover_test.go +++ b/internal/command/discover/discover_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -82,14 +83,14 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - ctxWithMetaData, err := cmd.Execute(context.Background()) + ctxWithLogMetadata, err := cmd.Execute(context.Background()) expectedOutput := fmt.Sprintf("Welcome to GitLab, %s!\n", tc.expectedUsername) expectedUsername := strings.TrimLeft(tc.expectedUsername, "@") require.NoError(t, err) require.Equal(t, expectedOutput, buffer.String()) - require.Equal(t, expectedUsername, ctxWithMetaData.Value("metaData").(config.MetaData).Username) + require.Equal(t, expectedUsername, ctxWithLogMetadata.Value("metaData").(command.LogMetadata).Username) }) } } diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go index 8bdb6fb02..4d2cdcae9 100644 --- a/internal/command/receivepack/receivepack.go +++ b/internal/command/receivepack/receivepack.go @@ -3,6 +3,7 @@ package receivepack import ( "context" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/githttp" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" @@ -30,7 +31,7 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } - ctxWithMetaData := context.WithValue(ctx, "metaData", config.NewMetaData( + ctxWithLogMetadata := context.WithValue(ctx, "metaData", command.NewLogMetadata( response.Gitaly.Repo.GlProjectPath, response.Username, )) @@ -47,7 +48,7 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { Response: response, } - return ctxWithMetaData, cmd.Execute(ctx) + return ctxWithLogMetadata, cmd.Execute(ctx) } customAction := customaction.Command{ @@ -55,10 +56,10 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { ReadWriter: c.ReadWriter, EOFSent: true, } - return ctxWithMetaData, customAction.Execute(ctx, response) + return ctxWithLogMetadata, customAction.Execute(ctx, response) } - return ctxWithMetaData, c.performGitalyCall(ctx, response) + return ctxWithLogMetadata, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go index ddd11f03d..862400380 100644 --- a/internal/command/receivepack/receivepack_test.go +++ b/internal/command/receivepack/receivepack_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -20,10 +21,10 @@ func TestAllowedAccess(t *testing.T) { cmd, _ := setup(t, "1", requests) cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) - ctxWithMetaData, err := cmd.Execute(context.Background()) + ctxWithLogMetadata, err := cmd.Execute(context.Background()) require.NoError(t, err) - metaData := ctxWithMetaData.Value("metaData").(config.MetaData) + metaData := ctxWithLogMetadata.Value("metaData").(command.LogMetadata) require.Equal(t, "alex-doe", metaData.Username) require.Equal(t, "group/project-path", metaData.Project) require.Equal(t, "group", metaData.RootNamespace) diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go index 5938ba2ec..2442659db 100644 --- a/internal/command/uploadarchive/uploadarchive.go +++ b/internal/command/uploadarchive/uploadarchive.go @@ -3,6 +3,7 @@ package uploadarchive import ( "context" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/shared/accessverifier" @@ -28,13 +29,13 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } - metaData := config.NewMetaData( + metaData := command.NewLogMetadata( response.Gitaly.Repo.GlProjectPath, response.Username, ) - ctxWithMetaData := context.WithValue(ctx, "metaData", metaData) + ctxWithLogMetadata := context.WithValue(ctx, "metaData", metaData) - return ctxWithMetaData, c.performGitalyCall(ctx, response) + return ctxWithLogMetadata, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go index c63069671..a25b8abf1 100644 --- a/internal/command/uploadarchive/uploadarchive_test.go +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -20,10 +21,10 @@ func TestAllowedAccess(t *testing.T) { cmd, _ := setup(t, "1", requests) cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) - ctxWithMetaData, err := cmd.Execute(context.Background()) + ctxWithLogMetadata, err := cmd.Execute(context.Background()) require.NoError(t, err) - metaData := ctxWithMetaData.Value("metaData").(config.MetaData) + metaData := ctxWithLogMetadata.Value("metaData").(command.LogMetadata) require.Equal(t, "alex-doe", metaData.Username) require.Equal(t, "group/project-path", metaData.Project) require.Equal(t, "group", metaData.RootNamespace) diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go index 1015245bc..996782354 100644 --- a/internal/command/uploadpack/uploadpack.go +++ b/internal/command/uploadpack/uploadpack.go @@ -3,6 +3,7 @@ package uploadpack import ( "context" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/shared/accessverifier" @@ -29,11 +30,11 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } - metaData := config.NewMetaData( + metaData := command.NewLogMetadata( response.Gitaly.Repo.GlProjectPath, response.Username, ) - ctxWithMetaData := context.WithValue(ctx, "metaData", metaData) + ctxWithLogMetadata := context.WithValue(ctx, "metaData", metaData) if response.IsCustomAction() { customAction := customaction.Command{ @@ -41,10 +42,10 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { ReadWriter: c.ReadWriter, EOFSent: false, } - return ctxWithMetaData, customAction.Execute(ctx, response) + return ctxWithLogMetadata, customAction.Execute(ctx, response) } - return ctxWithMetaData, c.performGitalyCall(ctx, response) + return ctxWithLogMetadata, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go index 50c966224..bb6113860 100644 --- a/internal/command/uploadpack/uploadpack_test.go +++ b/internal/command/uploadpack/uploadpack_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -20,10 +21,10 @@ func TestAllowedAccess(t *testing.T) { cmd, _ := setup(t, "1", requests) cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) - ctxWithMetaData, err := cmd.Execute(context.Background()) + ctxWithLogMetadata, err := cmd.Execute(context.Background()) require.NoError(t, err) - metaData := ctxWithMetaData.Value("metaData").(config.MetaData) + metaData := ctxWithLogMetadata.Value("metaData").(command.LogMetadata) require.Equal(t, "alex-doe", metaData.Username) require.Equal(t, "group/project-path", metaData.Project) require.Equal(t, "group", metaData.RootNamespace) diff --git a/internal/config/config.go b/internal/config/config.go index 67cac1267..c34f1d4ea 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,7 +6,6 @@ import ( "os" "path" "path/filepath" - "strings" "sync" "time" @@ -83,12 +82,6 @@ type Config struct { GitalyClient gitaly.Client } -type MetaData struct { - Username string `json:"username"` - Project string `json:"project,omitempty"` - RootNamespace string `json:"root_namespace,omitempty"` -} - // The defaults to apply before parsing the config file(s). var ( DefaultConfig = Config{ @@ -117,26 +110,6 @@ var ( } ) -func NewMetaData(project, username string) MetaData { - rootNameSpace := "" - - if len(project) > 0 { - splitFn := func(c rune) bool { - return c == '/' - } - m := strings.FieldsFunc(project, splitFn) - if len(m) > 0 { - rootNameSpace = m[0] - } - } - - return MetaData{ - Username: username, - Project: project, - RootNamespace: rootNameSpace, - } -} - func (d *YamlDuration) UnmarshalYAML(unmarshal func(interface{}) error) error { var intDuration int if err := unmarshal(&intDuration); err != nil { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index bd61691fd..d88a0cd91 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -98,40 +98,3 @@ func TestYAMLDuration(t *testing.T) { }) } } - -func TestNewMetaData(t *testing.T) { - testCases := []struct { - desc string - project string - username string - expectedRootNamespace string - }{ - { - desc: "Project under single namespace", - project: "flightjs/Flight", - username: "@alex-doe", - expectedRootNamespace: "flightjs", - }, - { - desc: "Project under single odd namespace", - project: "flightjs///Flight", - username: "@alex-doe", - expectedRootNamespace: "flightjs", - }, - { - desc: "Project under deeper namespace", - project: "flightjs/one/Flight", - username: "@alex-doe", - expectedRootNamespace: "flightjs", - }, - } - - for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { - metaData := NewMetaData(tc.project, tc.username) - require.Equal(t, tc.project, metaData.Project) - require.Equal(t, tc.username, metaData.Username) - require.Equal(t, tc.expectedRootNamespace, metaData.RootNamespace) - }) - } -} diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 202b7c6b4..7bc845a66 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -64,12 +64,12 @@ func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handl go c.sendKeepAliveMsg(ctx, sconn, ticker) } - ctxWithMetaData := c.handleRequests(ctx, sconn, chans, handler) + ctxWithLogMetadata := c.handleRequests(ctx, sconn, chans, handler) reason := sconn.Wait() log.WithContextFields(ctx, log.Fields{"reason": reason}).Info("server: handleConn: done") - return ctxWithMetaData + return ctxWithLogMetadata } func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) { @@ -97,7 +97,7 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi } func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) context.Context { - ctxWithMetaData := ctx + ctxWithLogMetadata := ctx ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { @@ -137,7 +137,7 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, }() metrics.SliSshdSessionsTotal.Inc() - ctxWithMetaData, err = handler(ctx, sconn, channel, requests) + ctxWithLogMetadata, err = handler(ctx, sconn, channel, requests) if err != nil { c.trackError(ctxlog, err) @@ -153,7 +153,7 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, defer cancel() c.concurrentSessions.Acquire(ctx, c.maxSessions) - return ctxWithMetaData + return ctxWithLogMetadata } func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) { diff --git a/internal/sshd/session.go b/internal/sshd/session.go index a1f22ee95..df0cfc165 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -50,7 +50,7 @@ type exitStatusReq struct { } func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (context.Context, error) { - ctxWithMetaData := ctx + ctxWithLogMetadata := ctx ctxlog := log.ContextLogger(ctx) ctxlog.Debug("session: handle: entering request loop") @@ -71,13 +71,13 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (con case "exec": // The command has been executed as `ssh user@host command` or `exec` channel has been used // in the app implementation - ctxWithMetaData, shouldContinue, err = s.handleExec(ctx, req) + ctxWithLogMetadata, shouldContinue, err = s.handleExec(ctx, req) case "shell": // The command has been entered into the shell or `shell` channel has been used // in the app implementation shouldContinue = false var status uint32 - ctxWithMetaData, status, err = s.handleShell(ctx, req) + ctxWithLogMetadata, status, err = s.handleShell(ctx, req) s.exit(ctx, status) default: // Ignore unknown requests but don't terminate the session @@ -100,7 +100,7 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (con ctxlog.Debug("session: handle: exiting request loop") - return ctxWithMetaData, err + return ctxWithLogMetadata, err } func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) { @@ -142,10 +142,10 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Con s.execCmd = execRequest.Command - ctxWithMetaData, status, err := s.handleShell(ctx, req) - s.exit(ctxWithMetaData, status) + ctxWithLogMetadata, status, err := s.handleShell(ctx, req) + s.exit(ctxWithLogMetadata, status) - return ctxWithMetaData, false, err + return ctxWithLogMetadata, false, err } func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Context, uint32, error) { @@ -196,7 +196,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co }).Info("session: handleShell: executing command") metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) - ctxWithMetaData, err := cmd.Execute(ctx) + ctxWithLogMetadata, err := cmd.Execute(ctx) if err != nil { grpcStatus := grpcstatus.Convert(err) if grpcStatus.Code() != grpccodes.Internal { @@ -208,7 +208,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co ctxlog.Info("session: handleShell: command executed successfully") - return ctxWithMetaData, 0, nil + return ctxWithLogMetadata, 0, nil } func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 9fa0e6c49..08f093564 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/v14/client" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/metrics" @@ -193,7 +194,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { started := time.Now() conn := newConnection(s.Config, nconn) - ctxWithMetaData := conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) { + ctxWithLogMetadata := conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) { session := &session{ cfg: s.Config, channel: channel, @@ -206,7 +207,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { return session.handle(ctx, requests) }) - ctxlog.WithFields(log.Fields{"duration_s": time.Since(started).Seconds(), "meta": extractMetaDataFromContext(ctxWithMetaData)}).Info("access: finish") + ctxlog.WithFields(log.Fields{"duration_s": time.Since(started).Seconds(), "meta": extractMetaDataFromContext(ctxWithLogMetadata)}).Info("access: finish") } func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { @@ -228,11 +229,11 @@ func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { } } -func extractMetaDataFromContext(ctx context.Context) config.MetaData { - metaData := config.MetaData{} +func extractMetaDataFromContext(ctx context.Context) command.LogMetadata { + metaData := command.LogMetadata{} if ctx.Value("metaData") != nil { - metaData = ctx.Value("metaData").(config.MetaData) + metaData = ctx.Value("metaData").(command.LogMetadata) } return metaData diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index fc63cab46..bdb33b023 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper" ) @@ -352,19 +353,19 @@ func TestExtractMetaDataFromContext(t *testing.T) { rootNameSpace := "flightjs" project := fmt.Sprintf("%s/Flight", rootNameSpace) username := "alex-doe" - ctxWithMetaData := context.WithValue(context.Background(), "metaData", config.NewMetaData(project, username)) + ctxWithLogMetadata := context.WithValue(context.Background(), "metaData", command.NewLogMetadata(project, username)) - metaData := extractMetaDataFromContext(ctxWithMetaData) + metaData := extractMetaDataFromContext(ctxWithLogMetadata) - require.Equal(t, config.MetaData{Project: project, Username: username, RootNamespace: rootNameSpace}, metaData) + require.Equal(t, command.LogMetadata{Project: project, Username: username, RootNamespace: rootNameSpace}, metaData) } func TestExtractMetaDataFromContextWithoutMetaData(t *testing.T) { - ctxWithMetaData := context.Background() + ctxWithLogMetadata := context.Background() - metaData := extractMetaDataFromContext(ctxWithMetaData) + metaData := extractMetaDataFromContext(ctxWithLogMetadata) - require.Equal(t, config.MetaData{}, metaData) + require.Equal(t, command.LogMetadata{}, metaData) } func setupServer(t *testing.T) *Server { -- GitLab From 5e521fe0294df0cd25e3ed6841334bdbd6849c5c Mon Sep 17 00:00:00 2001 From: Ash McKenzie Date: Fri, 30 Jun 2023 13:20:39 +1000 Subject: [PATCH 5/8] Make NewLogMetadata test data more realistic --- internal/command/command_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/command/command_test.go b/internal/command/command_test.go index c5dc3c91b..5d76772d0 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -88,19 +88,19 @@ func TestNewLogMetadata(t *testing.T) { { desc: "Project under single namespace", project: "flightjs/Flight", - username: "@alex-doe", + username: "alex-doe", expectedRootNamespace: "flightjs", }, { desc: "Project under single odd namespace", project: "flightjs///Flight", - username: "@alex-doe", + username: "alex-doe", expectedRootNamespace: "flightjs", }, { desc: "Project under deeper namespace", project: "flightjs/one/Flight", - username: "@alex-doe", + username: "alex-doe", expectedRootNamespace: "flightjs", }, } -- GitLab From 1c216d0279b1a47a329644ac8a3897699f76cb09 Mon Sep 17 00:00:00 2001 From: Ash McKenzie Date: Fri, 30 Jun 2023 15:47:11 +1000 Subject: [PATCH 6/8] Fix data race with ctxWithLogMetadata --- internal/sshd/connection.go | 20 +++++++++++------- internal/sshd/connection_test.go | 35 ++++++++++++++++---------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index 7bc845a66..850f91efd 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -64,7 +64,12 @@ func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handl go c.sendKeepAliveMsg(ctx, sconn, ticker) } - ctxWithLogMetadata := c.handleRequests(ctx, sconn, chans, handler) + ctxWithLogMetadataChan := make(chan context.Context) + defer close(ctxWithLogMetadataChan) + + go c.handleRequests(ctx, sconn, chans, ctxWithLogMetadataChan, handler) + + ctxWithLogMetadata := <-ctxWithLogMetadataChan reason := sconn.Wait() log.WithContextFields(ctx, log.Fields{"reason": reason}).Info("server: handleConn: done") @@ -96,23 +101,25 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi return sconn, chans, err } -func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) context.Context { - ctxWithLogMetadata := ctx +func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, ctxWithLogMetadataChan chan<- context.Context, handler channelHandler) { ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested") + if newChannel.ChannelType() != "session" { ctxlog.Info("connection: handleRequests: unknown channel type") newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } + if !c.concurrentSessions.TryAcquire(1) { ctxlog.Info("connection: handleRequests: too many concurrent sessions") newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") metrics.SshdHitMaxSessions.Inc() continue } + channel, requests, err := newChannel.Accept() if err != nil { ctxlog.WithError(err).Error("connection: handleRequests: accepting channel failed") @@ -137,11 +144,12 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, }() metrics.SliSshdSessionsTotal.Inc() - ctxWithLogMetadata, err = handler(ctx, sconn, channel, requests) - + ctxWithLogMetadata, err := handler(ctx, sconn, channel, requests) if err != nil { c.trackError(ctxlog, err) } + + ctxWithLogMetadataChan <- ctxWithLogMetadata }() } @@ -152,8 +160,6 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, ctx, cancel := context.WithTimeout(ctx, EOFTimeout) defer cancel() c.concurrentSessions.Acquire(ctx, c.maxSessions) - - return ctxWithLogMetadata } func (c *connection) sendKeepAliveMsg(ctx context.Context, sconn *ssh.ServerConn, ticker *time.Ticker) { diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index 29ea81e9a..88e4e426e 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -81,23 +81,25 @@ func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (boo return true, nil, nil } -func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) { +func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel, chan<- context.Context) { cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum}} conn := &connection{cfg: cfg, concurrentSessions: semaphore.NewWeighted(sessionsNum)} chans := make(chan ssh.NewChannel, 1) chans <- newChannel - return conn, chans + ctxWithLogMetadataChan := make(chan context.Context) + + return conn, chans, ctxWithLogMetadataChan } func TestPanicDuringSessionIsRecovered(t *testing.T) { newChannel := &fakeNewChannel{channelType: "session"} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) numSessions := 0 require.NotPanics(t, func() { - conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { + conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { numSessions += 1 close(chans) panic("This is a panic") @@ -112,10 +114,10 @@ func TestUnknownChannelType(t *testing.T) { defer close(rejectCh) newChannel := &fakeNewChannel{channelType: "unknown session", rejectCh: rejectCh} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) go func() { - conn.handleRequests(context.Background(), nil, chans, nil) + conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, nil) }() rejectionData := <-rejectCh @@ -129,13 +131,13 @@ func TestTooManySessions(t *testing.T) { defer close(rejectCh) newChannel := &fakeNewChannel{channelType: "session", rejectCh: rejectCh} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { - conn.handleRequests(context.Background(), nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { + conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { <-ctx.Done() // Keep the accepted channel open until the end of the test return ctx, nil }) @@ -147,18 +149,17 @@ func TestTooManySessions(t *testing.T) { func TestAcceptSessionSucceeds(t *testing.T) { newChannel := &fakeNewChannel{channelType: "session"} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) ctx := context.Background() channelHandled := false - returnedCtx := conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { + conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { channelHandled = true close(chans) return ctx, nil }) require.True(t, channelHandled) - require.NotNil(t, returnedCtx) } func TestAcceptSessionFails(t *testing.T) { @@ -167,12 +168,12 @@ func TestAcceptSessionFails(t *testing.T) { acceptErr := errors.New("some failure") newChannel := &fakeNewChannel{channelType: "session", acceptCh: acceptCh, acceptErr: acceptErr} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) ctx := context.Background() channelHandled := false go func() { - conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { + conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { channelHandled = true return ctx, nil }) @@ -206,10 +207,10 @@ func TestSessionsMetrics(t *testing.T) { initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal) newChannel := &fakeNewChannel{channelType: "session"} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) ctx := context.Background() - conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { + conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { close(chans) return ctx, errors.New("custom error") }) @@ -228,11 +229,11 @@ func TestSessionsMetrics(t *testing.T) { {"not our ref", grpcstatus.Error(grpccodes.Internal, `rpc error: code = Internal desc = cmd wait: exit status 128, stderr: "fatal: git upload-pack: not our ref 9106d18f6a1b8022f6517f479696f3e3ea5e68c1"`)}, } { t.Run(ignoredError.desc, func(t *testing.T) { - conn, chans = setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) ignored := ignoredError.err ctx := context.Background() - conn.handleRequests(ctx, nil, chans, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { + conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { close(chans) return ctx, ignored }) -- GitLab From e0903318a285fc55b41211e1434fa722010a12ce Mon Sep 17 00:00:00 2001 From: Ash McKenzie Date: Tue, 4 Jul 2023 18:18:03 +1000 Subject: [PATCH 7/8] Add missing ctxWithLogMetadata for lfsauthenticate --- .../lfsauthenticate/lfsauthenticate.go | 11 ++++++++-- .../lfsauthenticate/lfsauthenticate_test.go | 20 +++++++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go index 9ca97a9da..4211c8fe8 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate.go +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -8,6 +8,7 @@ import ( "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/shared/accessverifier" @@ -57,6 +58,12 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } + metaData := command.NewLogMetadata( + accessResponse.Gitaly.Repo.GlProjectPath, + accessResponse.Username, + ) + ctxWithLogMetadata := context.WithValue(ctx, "metaData", metaData) + payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId) if err != nil { // return nothing just like Ruby's GitlabShell#lfs_authenticate does @@ -65,12 +72,12 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { log.Fields{"operation": operation, "repo": repo, "user_id": accessResponse.UserId}, ).WithError(err).Debug("lfsauthenticate: execute: LFS authentication failed") - return ctx, nil + return ctxWithLogMetadata, nil } fmt.Fprintf(c.ReadWriter.Out, "%s\n", payload) - return ctx, nil + return ctxWithLogMetadata, nil } func actionFromOperation(operation string) (commandargs.CommandType, error) { diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go index 10e5b0dbf..14167b797 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate_test.go +++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -109,8 +110,14 @@ func TestLfsAuthenticateRequests(t *testing.T) { } body := map[string]interface{}{ - "gl_id": glId, - "status": true, + "gl_id": glId, + "status": true, + "gl_username": "alex-doe", + "gitaly": map[string]interface{}{ + "repository": map[string]interface{}{ + "gl_project_path": "group/project-path", + }, + }, } require.NoError(t, json.NewEncoder(w).Encode(body)) }, @@ -145,10 +152,15 @@ func TestLfsAuthenticateRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - _, err := cmd.Execute(context.Background()) - require.NoError(t, err) + ctxWithLogMetadata, err := cmd.Execute(context.Background()) + require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) + + metaData := ctxWithLogMetadata.Value("metaData").(command.LogMetadata) + require.Equal(t, "alex-doe", metaData.Username) + require.Equal(t, "group/project-path", metaData.Project) + require.Equal(t, "group", metaData.RootNamespace) }) } } -- GitLab From 20c4e46586cedf17d346a9bfbad2b48b7accca7d Mon Sep 17 00:00:00 2001 From: Ash McKenzie Date: Tue, 4 Jul 2023 19:10:22 +1000 Subject: [PATCH 8/8] Rename metaData to just metadata --- internal/command/command_test.go | 8 ++++---- internal/command/discover/discover.go | 8 ++++---- internal/command/discover/discover_test.go | 2 +- internal/command/lfsauthenticate/lfsauthenticate.go | 4 ++-- .../command/lfsauthenticate/lfsauthenticate_test.go | 8 ++++---- internal/command/receivepack/receivepack.go | 2 +- internal/command/receivepack/receivepack_test.go | 8 ++++---- internal/command/uploadarchive/uploadarchive.go | 4 ++-- internal/command/uploadarchive/uploadarchive_test.go | 8 ++++---- internal/command/uploadpack/uploadpack.go | 4 ++-- internal/command/uploadpack/uploadpack_test.go | 8 ++++---- internal/sshd/sshd.go | 8 ++++---- internal/sshd/sshd_test.go | 10 +++++----- 13 files changed, 41 insertions(+), 41 deletions(-) diff --git a/internal/command/command_test.go b/internal/command/command_test.go index 5d76772d0..91934e60b 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -107,10 +107,10 @@ func TestNewLogMetadata(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - metaData := NewLogMetadata(tc.project, tc.username) - require.Equal(t, tc.project, metaData.Project) - require.Equal(t, tc.username, metaData.Username) - require.Equal(t, tc.expectedRootNamespace, metaData.RootNamespace) + metadata := NewLogMetadata(tc.project, tc.username) + require.Equal(t, tc.project, metadata.Project) + require.Equal(t, tc.username, metadata.Username) + require.Equal(t, tc.expectedRootNamespace, metadata.RootNamespace) }) } } diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go index 228db468c..a8e3aad7b 100644 --- a/internal/command/discover/discover.go +++ b/internal/command/discover/discover.go @@ -23,16 +23,16 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, fmt.Errorf("Failed to get username: %v", err) } - metaData := command.LogMetadata{} + metadata := command.LogMetadata{} if response.IsAnonymous() { - metaData.Username = "Anonymous" + metadata.Username = "Anonymous" fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, Anonymous!\n") } else { - metaData.Username = response.Username + metadata.Username = response.Username fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) } - ctxWithLogMetadata := context.WithValue(ctx, "metaData", metaData) + ctxWithLogMetadata := context.WithValue(ctx, "metadata", metadata) return ctxWithLogMetadata, nil } diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go index 4b05acc76..430b0ef45 100644 --- a/internal/command/discover/discover_test.go +++ b/internal/command/discover/discover_test.go @@ -90,7 +90,7 @@ func TestExecute(t *testing.T) { require.NoError(t, err) require.Equal(t, expectedOutput, buffer.String()) - require.Equal(t, expectedUsername, ctxWithLogMetadata.Value("metaData").(command.LogMetadata).Username) + require.Equal(t, expectedUsername, ctxWithLogMetadata.Value("metadata").(command.LogMetadata).Username) }) } } diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go index 4211c8fe8..07d530f99 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate.go +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -58,11 +58,11 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } - metaData := command.NewLogMetadata( + metadata := command.NewLogMetadata( accessResponse.Gitaly.Repo.GlProjectPath, accessResponse.Username, ) - ctxWithLogMetadata := context.WithValue(ctx, "metaData", metaData) + ctxWithLogMetadata := context.WithValue(ctx, "metadata", metadata) payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId) if err != nil { diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go index 14167b797..1279a789b 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate_test.go +++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go @@ -157,10 +157,10 @@ func TestLfsAuthenticateRequests(t *testing.T) { require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) - metaData := ctxWithLogMetadata.Value("metaData").(command.LogMetadata) - require.Equal(t, "alex-doe", metaData.Username) - require.Equal(t, "group/project-path", metaData.Project) - require.Equal(t, "group", metaData.RootNamespace) + metadata := ctxWithLogMetadata.Value("metadata").(command.LogMetadata) + require.Equal(t, "alex-doe", metadata.Username) + require.Equal(t, "group/project-path", metadata.Project) + require.Equal(t, "group", metadata.RootNamespace) }) } } diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go index 4d2cdcae9..3c3ac277b 100644 --- a/internal/command/receivepack/receivepack.go +++ b/internal/command/receivepack/receivepack.go @@ -31,7 +31,7 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } - ctxWithLogMetadata := context.WithValue(ctx, "metaData", command.NewLogMetadata( + ctxWithLogMetadata := context.WithValue(ctx, "metadata", command.NewLogMetadata( response.Gitaly.Repo.GlProjectPath, response.Username, )) diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go index 862400380..581903e0c 100644 --- a/internal/command/receivepack/receivepack_test.go +++ b/internal/command/receivepack/receivepack_test.go @@ -24,10 +24,10 @@ func TestAllowedAccess(t *testing.T) { ctxWithLogMetadata, err := cmd.Execute(context.Background()) require.NoError(t, err) - metaData := ctxWithLogMetadata.Value("metaData").(command.LogMetadata) - require.Equal(t, "alex-doe", metaData.Username) - require.Equal(t, "group/project-path", metaData.Project) - require.Equal(t, "group", metaData.RootNamespace) + metadata := ctxWithLogMetadata.Value("metadata").(command.LogMetadata) + require.Equal(t, "alex-doe", metadata.Username) + require.Equal(t, "group/project-path", metadata.Project) + require.Equal(t, "group", metadata.RootNamespace) } func TestForbiddenAccess(t *testing.T) { diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go index 2442659db..4d1b2641f 100644 --- a/internal/command/uploadarchive/uploadarchive.go +++ b/internal/command/uploadarchive/uploadarchive.go @@ -29,11 +29,11 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } - metaData := command.NewLogMetadata( + metadata := command.NewLogMetadata( response.Gitaly.Repo.GlProjectPath, response.Username, ) - ctxWithLogMetadata := context.WithValue(ctx, "metaData", metaData) + ctxWithLogMetadata := context.WithValue(ctx, "metadata", metadata) return ctxWithLogMetadata, c.performGitalyCall(ctx, response) } diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go index a25b8abf1..245f083f6 100644 --- a/internal/command/uploadarchive/uploadarchive_test.go +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -24,10 +24,10 @@ func TestAllowedAccess(t *testing.T) { ctxWithLogMetadata, err := cmd.Execute(context.Background()) require.NoError(t, err) - metaData := ctxWithLogMetadata.Value("metaData").(command.LogMetadata) - require.Equal(t, "alex-doe", metaData.Username) - require.Equal(t, "group/project-path", metaData.Project) - require.Equal(t, "group", metaData.RootNamespace) + metadata := ctxWithLogMetadata.Value("metadata").(command.LogMetadata) + require.Equal(t, "alex-doe", metadata.Username) + require.Equal(t, "group/project-path", metadata.Project) + require.Equal(t, "group", metadata.RootNamespace) } func TestForbiddenAccess(t *testing.T) { diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go index 996782354..3d402e0f4 100644 --- a/internal/command/uploadpack/uploadpack.go +++ b/internal/command/uploadpack/uploadpack.go @@ -30,11 +30,11 @@ func (c *Command) Execute(ctx context.Context) (context.Context, error) { return ctx, err } - metaData := command.NewLogMetadata( + metadata := command.NewLogMetadata( response.Gitaly.Repo.GlProjectPath, response.Username, ) - ctxWithLogMetadata := context.WithValue(ctx, "metaData", metaData) + ctxWithLogMetadata := context.WithValue(ctx, "metadata", metadata) if response.IsCustomAction() { customAction := customaction.Command{ diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go index bb6113860..dd96c3a34 100644 --- a/internal/command/uploadpack/uploadpack_test.go +++ b/internal/command/uploadpack/uploadpack_test.go @@ -24,10 +24,10 @@ func TestAllowedAccess(t *testing.T) { ctxWithLogMetadata, err := cmd.Execute(context.Background()) require.NoError(t, err) - metaData := ctxWithLogMetadata.Value("metaData").(command.LogMetadata) - require.Equal(t, "alex-doe", metaData.Username) - require.Equal(t, "group/project-path", metaData.Project) - require.Equal(t, "group", metaData.RootNamespace) + metadata := ctxWithLogMetadata.Value("metadata").(command.LogMetadata) + require.Equal(t, "alex-doe", metadata.Username) + require.Equal(t, "group/project-path", metadata.Project) + require.Equal(t, "group", metadata.RootNamespace) } func TestForbiddenAccess(t *testing.T) { diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index 08f093564..0328b6db4 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -230,13 +230,13 @@ func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { } func extractMetaDataFromContext(ctx context.Context) command.LogMetadata { - metaData := command.LogMetadata{} + metadata := command.LogMetadata{} - if ctx.Value("metaData") != nil { - metaData = ctx.Value("metaData").(command.LogMetadata) + if ctx.Value("metadata") != nil { + metadata = ctx.Value("metadata").(command.LogMetadata) } - return metaData + return metadata } func staticProxyPolicy(policy proxyproto.Policy) proxyproto.PolicyFunc { diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index bdb33b023..bdbb3c8cf 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -353,19 +353,19 @@ func TestExtractMetaDataFromContext(t *testing.T) { rootNameSpace := "flightjs" project := fmt.Sprintf("%s/Flight", rootNameSpace) username := "alex-doe" - ctxWithLogMetadata := context.WithValue(context.Background(), "metaData", command.NewLogMetadata(project, username)) + ctxWithLogMetadata := context.WithValue(context.Background(), "metadata", command.NewLogMetadata(project, username)) - metaData := extractMetaDataFromContext(ctxWithLogMetadata) + metadata := extractMetaDataFromContext(ctxWithLogMetadata) - require.Equal(t, command.LogMetadata{Project: project, Username: username, RootNamespace: rootNameSpace}, metaData) + require.Equal(t, command.LogMetadata{Project: project, Username: username, RootNamespace: rootNameSpace}, metadata) } func TestExtractMetaDataFromContextWithoutMetaData(t *testing.T) { ctxWithLogMetadata := context.Background() - metaData := extractMetaDataFromContext(ctxWithLogMetadata) + metadata := extractMetaDataFromContext(ctxWithLogMetadata) - require.Equal(t, command.LogMetadata{}, metaData) + require.Equal(t, command.LogMetadata{}, metadata) } func setupServer(t *testing.T) *Server { -- GitLab