diff --git a/cmd/check/main.go b/cmd/check/main.go index 578dfdf8434a716be197761cff8aafbc09da1d1a..76f217b782328792626ff968533a0b834bdaea28 100644 --- a/cmd/check/main.go +++ b/cmd/check/main.go @@ -43,7 +43,7 @@ func main() { ctx, finished := command.Setup(executable.Name, config) defer finished() - if err = cmd.Execute(ctx); err != nil { + if ctx, err = cmd.Execute(ctx); err != nil { fmt.Fprintf(readWriter.ErrOut, "%v\n", err) os.Exit(1) } diff --git a/cmd/gitlab-shell-authorized-keys-check/main.go b/cmd/gitlab-shell-authorized-keys-check/main.go index e272e68b0a8635a6e57cde621d4cbfa2180d629e..707d4cc68b9b77fd5d912a8e61aaebb679e18c1d 100644 --- a/cmd/gitlab-shell-authorized-keys-check/main.go +++ b/cmd/gitlab-shell-authorized-keys-check/main.go @@ -46,7 +46,7 @@ func main() { ctx, finished := command.Setup(executable.Name, config) defer finished() - if err = cmd.Execute(ctx); err != nil { + if ctx, err = cmd.Execute(ctx); err != nil { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) os.Exit(1) } diff --git a/cmd/gitlab-shell-authorized-principals-check/main.go b/cmd/gitlab-shell-authorized-principals-check/main.go index 10d3daab09f8b2bfc1836068a31fb92ced8a02fb..09380fb35bcc62dd7dd03f6aaea185c719454ddf 100644 --- a/cmd/gitlab-shell-authorized-principals-check/main.go +++ b/cmd/gitlab-shell-authorized-principals-check/main.go @@ -46,7 +46,7 @@ func main() { ctx, finished := command.Setup(executable.Name, config) defer finished() - if err = cmd.Execute(ctx); err != nil { + if ctx, err = cmd.Execute(ctx); err != nil { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) os.Exit(1) } diff --git a/cmd/gitlab-shell/main.go b/cmd/gitlab-shell/main.go index b789b774d89ed9d8fd7b62fb42c2cfe6b5c650a9..679d4593ee8118a602d4cdcca81bde4999fff14a 100644 --- a/cmd/gitlab-shell/main.go +++ b/cmd/gitlab-shell/main.go @@ -76,7 +76,7 @@ func main() { ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("gitlab-shell: main: executing command") fips.Check() - if err := cmd.Execute(ctx); err != nil { + if _, err := cmd.Execute(ctx); err != nil { ctxlog.WithError(err).Warn("gitlab-shell: main: command execution failed") if grpcstatus.Convert(err).Code() != grpccodes.Internal { console.DisplayWarningMessage(err.Error(), readWriter.ErrOut) diff --git a/internal/command/authorizedkeys/authorized_keys.go b/internal/command/authorizedkeys/authorized_keys.go index 46ab5c467f98fdd5a265eb8e056833cefa13ca19..92547cd9dad1052aab056574dd56b8c0549cbc6a 100644 --- a/internal/command/authorizedkeys/authorized_keys.go +++ b/internal/command/authorizedkeys/authorized_keys.go @@ -18,21 +18,21 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { // Do and return nothing when the expected and actual user don't match. // This can happen when the user in sshd_config doesn't match the user // trying to login. When nothing is printed, the user will be denied access. if c.Args.ExpectedUser != c.Args.ActualUser { // TODO: Log this event once we have a consistent way to log in Go. // See https://gitlab.com/gitlab-org/gitlab-shell/issues/192 for more info. - return nil + return ctx, nil } if err := c.printKeyLine(ctx); err != nil { - return err + return ctx, err } - return nil + return ctx, nil } func (c *Command) printKeyLine(ctx context.Context) error { diff --git a/internal/command/authorizedkeys/authorized_keys_test.go b/internal/command/authorizedkeys/authorized_keys_test.go index b91e460f4626a90f9c9dd39fd4e4494f9c9640ef..e54634099a2ad5fa1e88c98302c60cdece08a3dd 100644 --- a/internal/command/authorizedkeys/authorized_keys_test.go +++ b/internal/command/authorizedkeys/authorized_keys_test.go @@ -84,7 +84,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/authorizedprincipals/authorized_principals.go b/internal/command/authorizedprincipals/authorized_principals.go index a7cfe1a192a2184845e37c1e4e623901d8a83c80..38267f454ac3ac23bb4f012fd948d49c7bab5070 100644 --- a/internal/command/authorizedprincipals/authorized_principals.go +++ b/internal/command/authorizedprincipals/authorized_principals.go @@ -16,12 +16,12 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { if err := c.printPrincipalLines(); err != nil { - return err + return ctx, err } - return nil + return ctx, nil } func (c *Command) printPrincipalLines() error { diff --git a/internal/command/authorizedprincipals/authorized_principals_test.go b/internal/command/authorizedprincipals/authorized_principals_test.go index ba4d066ffcbdd06f1748c04533c4f9131cf22aca..b84d8e88ac2a3bd549fa4f527f7564d888ec3de9 100644 --- a/internal/command/authorizedprincipals/authorized_principals_test.go +++ b/internal/command/authorizedprincipals/authorized_principals_test.go @@ -42,7 +42,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, buffer.String()) diff --git a/internal/command/command.go b/internal/command/command.go index d9706b5d05a76b9cb6f18732d20137b9c5f97774..30b2bfcd2d0d1f0e710e22580168c79fe8dd87e2 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -2,6 +2,7 @@ package command import ( "context" + "strings" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/labkit/correlation" @@ -9,7 +10,13 @@ import ( ) type Command interface { - Execute(ctx context.Context) error + Execute(ctx context.Context) (context.Context, error) +} + +type LogMetadata struct { + Username string `json:"username"` + Project string `json:"project,omitempty"` + RootNamespace string `json:"root_namespace,omitempty"` } // Setup() initializes tracing from the configuration file and generates a @@ -47,3 +54,23 @@ func Setup(serviceName string, config *config.Config) (context.Context, func()) closer.Close() } } + +func NewLogMetadata(project, username string) LogMetadata { + rootNameSpace := "" + + if len(project) > 0 { + splitFn := func(c rune) bool { + return c == '/' + } + m := strings.FieldsFunc(project, splitFn) + if len(m) > 0 { + rootNameSpace = m[0] + } + } + + return LogMetadata{ + Username: username, + Project: project, + RootNamespace: rootNameSpace, + } +} diff --git a/internal/command/command_test.go b/internal/command/command_test.go index c95e838842e3d130ae0b3f453b6090d16a164002..91934e60bad2f2c5cb822692023e0b1653fd2959 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -77,3 +77,40 @@ func addAdditionalEnv(envMap map[string]string) func() { } } + +func TestNewLogMetadata(t *testing.T) { + testCases := []struct { + desc string + project string + username string + expectedRootNamespace string + }{ + { + desc: "Project under single namespace", + project: "flightjs/Flight", + username: "alex-doe", + expectedRootNamespace: "flightjs", + }, + { + desc: "Project under single odd namespace", + project: "flightjs///Flight", + username: "alex-doe", + expectedRootNamespace: "flightjs", + }, + { + desc: "Project under deeper namespace", + project: "flightjs/one/Flight", + username: "alex-doe", + expectedRootNamespace: "flightjs", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + metadata := NewLogMetadata(tc.project, tc.username) + require.Equal(t, tc.project, metadata.Project) + require.Equal(t, tc.username, metadata.Username) + require.Equal(t, tc.expectedRootNamespace, metadata.RootNamespace) + }) + } +} diff --git a/internal/command/discover/discover.go b/internal/command/discover/discover.go index 2f81a78445820badaa2ac4a86aebf616cd5fdea2..a8e3aad7b8472c4c4b30755b58b2c3bd20124927 100644 --- a/internal/command/discover/discover.go +++ b/internal/command/discover/discover.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -16,19 +17,24 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { response, err := c.getUserInfo(ctx) if err != nil { - return fmt.Errorf("Failed to get username: %v", err) + return ctx, fmt.Errorf("Failed to get username: %v", err) } + metadata := command.LogMetadata{} if response.IsAnonymous() { + metadata.Username = "Anonymous" fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, Anonymous!\n") } else { + metadata.Username = response.Username fmt.Fprintf(c.ReadWriter.Out, "Welcome to GitLab, @%s!\n", response.Username) } - return nil + ctxWithLogMetadata := context.WithValue(ctx, "metadata", metadata) + + return ctxWithLogMetadata, nil } func (c *Command) getUserInfo(ctx context.Context) (*discover.Response, error) { diff --git a/internal/command/discover/discover_test.go b/internal/command/discover/discover_test.go index df9ca47c317cca2423538490b0d5071f9b3cf7fa..430b0ef4544ce9876ba1fa06b8ca52169001ece5 100644 --- a/internal/command/discover/discover_test.go +++ b/internal/command/discover/discover_test.go @@ -6,11 +6,13 @@ import ( "encoding/json" "fmt" "net/http" + "strings" "testing" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -46,29 +48,29 @@ func TestExecute(t *testing.T) { url := testserver.StartSocketHttpServer(t, requests) testCases := []struct { - desc string - arguments *commandargs.Shell - expectedOutput string + desc string + arguments *commandargs.Shell + expectedUsername string }{ { - desc: "With a known username", - arguments: &commandargs.Shell{GitlabUsername: "alex-doe"}, - expectedOutput: "Welcome to GitLab, @alex-doe!\n", + desc: "With a known username", + arguments: &commandargs.Shell{GitlabUsername: "alex-doe"}, + expectedUsername: "@alex-doe", }, { - desc: "With a known key id", - arguments: &commandargs.Shell{GitlabKeyId: "1"}, - expectedOutput: "Welcome to GitLab, @alex-doe!\n", + desc: "With a known key id", + arguments: &commandargs.Shell{GitlabKeyId: "1"}, + expectedUsername: "@alex-doe", }, { - desc: "With an unknown key", - arguments: &commandargs.Shell{GitlabKeyId: "-1"}, - expectedOutput: "Welcome to GitLab, Anonymous!\n", + desc: "With an unknown key", + arguments: &commandargs.Shell{GitlabKeyId: "-1"}, + expectedUsername: "Anonymous", }, { - desc: "With an unknown username", - arguments: &commandargs.Shell{GitlabUsername: "unknown"}, - expectedOutput: "Welcome to GitLab, Anonymous!\n", + desc: "With an unknown username", + arguments: &commandargs.Shell{GitlabUsername: "unknown"}, + expectedUsername: "Anonymous", }, } @@ -81,10 +83,14 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + ctxWithLogMetadata, err := cmd.Execute(context.Background()) + + expectedOutput := fmt.Sprintf("Welcome to GitLab, %s!\n", tc.expectedUsername) + expectedUsername := strings.TrimLeft(tc.expectedUsername, "@") require.NoError(t, err) - require.Equal(t, tc.expectedOutput, buffer.String()) + require.Equal(t, expectedOutput, buffer.String()) + require.Equal(t, expectedUsername, ctxWithLogMetadata.Value("metadata").(command.LogMetadata).Username) }) } } @@ -123,7 +129,7 @@ func TestFailingExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, tc.expectedError) diff --git a/internal/command/healthcheck/healthcheck.go b/internal/command/healthcheck/healthcheck.go index e80fe2a1c6df22d047367d7191158efdf5bff0a5..206a97f7c39b706e240de8b7fa2520035d658d4b 100644 --- a/internal/command/healthcheck/healthcheck.go +++ b/internal/command/healthcheck/healthcheck.go @@ -19,20 +19,20 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { response, err := c.runCheck(ctx) if err != nil { - return fmt.Errorf("%v: FAILED - %v", apiMessage, err) + return ctx, fmt.Errorf("%v: FAILED - %v", apiMessage, err) } fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", apiMessage) if !response.Redis { - return fmt.Errorf("%v: FAILED", redisMessage) + return ctx, fmt.Errorf("%v: FAILED", redisMessage) } fmt.Fprintf(c.ReadWriter.Out, "%v: OK\n", redisMessage) - return nil + return ctx, nil } func (c *Command) runCheck(ctx context.Context) (*healthcheck.Response, error) { diff --git a/internal/command/healthcheck/healthcheck_test.go b/internal/command/healthcheck/healthcheck_test.go index 12a8444c490cffc96d43d5345e53eff9592637a1..d1c2a6ba54a9ff3ef28a0c33d875c717af2bac8b 100644 --- a/internal/command/healthcheck/healthcheck_test.go +++ b/internal/command/healthcheck/healthcheck_test.go @@ -53,7 +53,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, "Internal API available: OK\nRedis available via internal API: OK\n", buffer.String()) @@ -68,7 +68,7 @@ func TestFailingRedisExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Error(t, err, "Redis available via internal API: FAILED") require.Equal(t, "Internal API available: OK\n", buffer.String()) } @@ -82,7 +82,7 @@ func TestFailingAPIExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: buffer}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Empty(t, buffer.String()) require.EqualError(t, err, "Internal API available: FAILED - Internal API unreachable") } diff --git a/internal/command/lfsauthenticate/lfsauthenticate.go b/internal/command/lfsauthenticate/lfsauthenticate.go index a06ac93a42af87931b620a373ba00c9cfd70a032..07d530f9942e49e9f2e6ef09d07fee8730de6d93 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate.go +++ b/internal/command/lfsauthenticate/lfsauthenticate.go @@ -8,6 +8,7 @@ import ( "gitlab.com/gitlab-org/labkit/log" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/shared/accessverifier" @@ -37,10 +38,10 @@ type Payload struct { ExpiresIn int `json:"expires_in,omitempty"` } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { args := c.Args.SshArgs if len(args) < 3 { - return disallowedcommand.Error + return ctx, disallowedcommand.Error } // e.g. git-lfs-authenticate user/repo.git download @@ -49,14 +50,20 @@ func (c *Command) Execute(ctx context.Context) error { action, err := actionFromOperation(operation) if err != nil { - return err + return ctx, err } accessResponse, err := c.verifyAccess(ctx, action, repo) if err != nil { - return err + return ctx, err } + metadata := command.NewLogMetadata( + accessResponse.Gitaly.Repo.GlProjectPath, + accessResponse.Username, + ) + ctxWithLogMetadata := context.WithValue(ctx, "metadata", metadata) + payload, err := c.authenticate(ctx, operation, repo, accessResponse.UserId) if err != nil { // return nothing just like Ruby's GitlabShell#lfs_authenticate does @@ -65,12 +72,12 @@ func (c *Command) Execute(ctx context.Context) error { log.Fields{"operation": operation, "repo": repo, "user_id": accessResponse.UserId}, ).WithError(err).Debug("lfsauthenticate: execute: LFS authentication failed") - return nil + return ctxWithLogMetadata, nil } fmt.Fprintf(c.ReadWriter.Out, "%s\n", payload) - return nil + return ctxWithLogMetadata, nil } func actionFromOperation(operation string) (commandargs.CommandType, error) { diff --git a/internal/command/lfsauthenticate/lfsauthenticate_test.go b/internal/command/lfsauthenticate/lfsauthenticate_test.go index 709608c3c44547d42ef1cebcd0264aca2b94e4ba..1279a789b3c66a7682c1e2908eef47e895bf1421 100644 --- a/internal/command/lfsauthenticate/lfsauthenticate_test.go +++ b/internal/command/lfsauthenticate/lfsauthenticate_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" @@ -54,7 +55,7 @@ func TestFailedRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Error(t, err) require.Equal(t, tc.expectedOutput, err.Error()) @@ -109,8 +110,14 @@ func TestLfsAuthenticateRequests(t *testing.T) { } body := map[string]interface{}{ - "gl_id": glId, - "status": true, + "gl_id": glId, + "status": true, + "gl_username": "alex-doe", + "gitaly": map[string]interface{}{ + "repository": map[string]interface{}{ + "gl_project_path": "group/project-path", + }, + }, } require.NoError(t, json.NewEncoder(w).Encode(body)) }, @@ -145,10 +152,15 @@ func TestLfsAuthenticateRequests(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, } - err := cmd.Execute(context.Background()) - require.NoError(t, err) + ctxWithLogMetadata, err := cmd.Execute(context.Background()) + require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) + + metadata := ctxWithLogMetadata.Value("metadata").(command.LogMetadata) + require.Equal(t, "alex-doe", metadata.Username) + require.Equal(t, "group/project-path", metadata.Project) + require.Equal(t, "group", metadata.RootNamespace) }) } } diff --git a/internal/command/personalaccesstoken/personalaccesstoken.go b/internal/command/personalaccesstoken/personalaccesstoken.go index fcf7dda1c8d066ea4b4af6ca77bbcba164ab5069..c4f3deec39457298cb4c7ed116c004dfbb0d8c00 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken.go +++ b/internal/command/personalaccesstoken/personalaccesstoken.go @@ -34,10 +34,10 @@ type tokenArgs struct { ExpiresDate string // Calculated, a TTL is passed from command-line. } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { err := c.parseTokenArgs() if err != nil { - return err + return ctx, err } log.WithContextFields(ctx, log.Fields{ @@ -46,13 +46,14 @@ func (c *Command) Execute(ctx context.Context) error { response, err := c.getPersonalAccessToken(ctx) if err != nil { - return err + return ctx, err } fmt.Fprint(c.ReadWriter.Out, "Token: "+response.Token+"\n") fmt.Fprint(c.ReadWriter.Out, "Scopes: "+strings.Join(response.Scopes, ",")+"\n") fmt.Fprint(c.ReadWriter.Out, "Expires: "+response.ExpiresAt+"\n") - return nil + + return ctx, nil } func (c *Command) parseTokenArgs() error { diff --git a/internal/command/personalaccesstoken/personalaccesstoken_test.go b/internal/command/personalaccesstoken/personalaccesstoken_test.go index c3434ce4f874eac1baf5c9729a3ee979eea0bf54..711f7dac2eb293002cf4f436ca09ff23230fc936 100644 --- a/internal/command/personalaccesstoken/personalaccesstoken_test.go +++ b/internal/command/personalaccesstoken/personalaccesstoken_test.go @@ -167,7 +167,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) if tc.expectedError == "" { require.NoError(t, err) diff --git a/internal/command/receivepack/gitalycall_test.go b/internal/command/receivepack/gitalycall_test.go index 9f70189fb491ffd4e36da90c7d69bef846e73dd9..c9321821e8e8ee7d29a2e9d5ba66ed1fd92a5967 100644 --- a/internal/command/receivepack/gitalycall_test.go +++ b/internal/command/receivepack/gitalycall_test.go @@ -72,7 +72,7 @@ func TestReceivePack(t *testing.T) { ctx := correlation.ContextWithCorrelation(context.Background(), "a-correlation-id") ctx = correlation.ContextWithClientName(ctx, "gitlab-shell-tests") - err := cmd.Execute(ctx) + _, err := cmd.Execute(ctx) require.NoError(t, err) if tc.username != "" { diff --git a/internal/command/receivepack/receivepack.go b/internal/command/receivepack/receivepack.go index c9ef7cdc386163cd4a347caa058a5d5bd57444f0..3c3ac277b153a3c3e16659c92cb2e59476a014e8 100644 --- a/internal/command/receivepack/receivepack.go +++ b/internal/command/receivepack/receivepack.go @@ -3,6 +3,7 @@ package receivepack import ( "context" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/githttp" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" @@ -18,18 +19,23 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { args := c.Args.SshArgs if len(args) != 2 { - return disallowedcommand.Error + return ctx, disallowedcommand.Error } repo := args[1] response, err := c.verifyAccess(ctx, repo) if err != nil { - return err + return ctx, err } + ctxWithLogMetadata := context.WithValue(ctx, "metadata", command.NewLogMetadata( + response.Gitaly.Repo.GlProjectPath, + response.Username, + )) + if response.IsCustomAction() { // When `geo_proxy_direct_to_primary` feature flag is enabled, a Git over HTTP direct request // to primary repo is performed instead of proxying the request through Gitlab Rails. @@ -42,7 +48,7 @@ func (c *Command) Execute(ctx context.Context) error { Response: response, } - return cmd.Execute(ctx) + return ctxWithLogMetadata, cmd.Execute(ctx) } customAction := customaction.Command{ @@ -50,10 +56,10 @@ func (c *Command) Execute(ctx context.Context) error { ReadWriter: c.ReadWriter, EOFSent: true, } - return customAction.Execute(ctx, response) + return ctxWithLogMetadata, customAction.Execute(ctx, response) } - return c.performGitalyCall(ctx, response) + return ctxWithLogMetadata, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/receivepack/receivepack_test.go b/internal/command/receivepack/receivepack_test.go index 17622bb1a0be3bf986fbdc60b99eb3606757207c..581903e0c1534c55152fab8683ae047fc4741f41 100644 --- a/internal/command/receivepack/receivepack_test.go +++ b/internal/command/receivepack/receivepack_test.go @@ -8,24 +8,41 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper/requesthandlers" ) +func TestAllowedAccess(t *testing.T) { + gitalyAddress, _ := testserver.StartGitalyServer(t, "unix") + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + cmd, _ := setup(t, "1", requests) + cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) + + ctxWithLogMetadata, err := cmd.Execute(context.Background()) + + require.NoError(t, err) + metadata := ctxWithLogMetadata.Value("metadata").(command.LogMetadata) + require.Equal(t, "alex-doe", metadata.Username) + require.Equal(t, "group/project-path", metadata.Project) + require.Equal(t, "group", metadata.RootNamespace) +} + func TestForbiddenAccess(t *testing.T) { requests := requesthandlers.BuildDisallowedByApiHandlers(t) cmd, _ := setup(t, "disallowed", requests) - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.Equal(t, "Disallowed by API call", err.Error()) } func TestCustomReceivePack(t *testing.T) { cmd, output := setup(t, "1", requesthandlers.BuildAllowedWithCustomActionsHandlers(t)) - require.NoError(t, cmd.Execute(context.Background())) + _, err := cmd.Execute(context.Background()) + require.NoError(t, err) require.Equal(t, "customoutput", output.String()) } diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go index 7496396ab2dc96e7b00191929fa24aa2caf74c46..8828c71d04063879572bbcf889ef6332a901d780 100644 --- a/internal/command/twofactorrecover/twofactorrecover.go +++ b/internal/command/twofactorrecover/twofactorrecover.go @@ -22,7 +22,7 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { ctxlog := log.ContextLogger(ctx) ctxlog.Debug("twofactorrecover: execute: Waiting for user input") @@ -34,7 +34,7 @@ func (c *Command) Execute(ctx context.Context) error { fmt.Fprintln(c.ReadWriter.Out, "\nNew recovery codes have *not* been generated. Existing codes will remain valid.") } - return nil + return ctx, nil } func (c *Command) getUserAnswer(ctx context.Context) string { diff --git a/internal/command/twofactorrecover/twofactorrecover_test.go b/internal/command/twofactorrecover/twofactorrecover_test.go index 7e20a06528917fd15a7526c250f4cad0c82f9b0f..8f86777ecaac199547cfb49c338c7604706df4c0 100644 --- a/internal/command/twofactorrecover/twofactorrecover_test.go +++ b/internal/command/twofactorrecover/twofactorrecover_test.go @@ -132,7 +132,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, tc.expectedOutput, output.String()) diff --git a/internal/command/twofactorverify/twofactorverify.go b/internal/command/twofactorverify/twofactorverify.go index 4041de0c855754deef662e0143a644e9ad091a24..cbe68e66caf851c6a3708dbee96e71c82baf3bdf 100644 --- a/internal/command/twofactorverify/twofactorverify.go +++ b/internal/command/twofactorverify/twofactorverify.go @@ -25,10 +25,10 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { client, err := twofactorverify.NewClient(c.Config) if err != nil { - return err + return ctx, err } ctx, cancel := context.WithTimeout(ctx, timeout) @@ -67,7 +67,7 @@ func (c *Command) Execute(ctx context.Context) error { log.WithContextFields(ctx, log.Fields{"message": message}).Info("Two factor verify command finished") fmt.Fprintf(c.ReadWriter.Out, "\n%v\n", message) - return nil + return ctx, nil } func (c *Command) getOTP(ctx context.Context) (string, error) { diff --git a/internal/command/twofactorverify/twofactorverify_test.go b/internal/command/twofactorverify/twofactorverify_test.go index 213c02532a2ed341a1dec183f43c9ee1b11f6390..4629be93709529be01d2d5a922a5215e5d393ab0 100644 --- a/internal/command/twofactorverify/twofactorverify_test.go +++ b/internal/command/twofactorverify/twofactorverify_test.go @@ -160,7 +160,7 @@ func TestExecute(t *testing.T) { ReadWriter: &readwriter.ReadWriter{Out: output, In: input}, } - err := cmd.Execute(context.Background()) + _, err := cmd.Execute(context.Background()) require.NoError(t, err) require.Equal(t, prompt+"\n"+tc.expectedOutput, output.String()) @@ -183,7 +183,10 @@ func TestCanceledContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error) - go func() { errCh <- cmd.Execute(ctx) }() + go func() { + _, err := cmd.Execute(ctx) + errCh <- err + }() cancel() require.NoError(t, <-errCh) diff --git a/internal/command/uploadarchive/gitalycall_test.go b/internal/command/uploadarchive/gitalycall_test.go index 0479e3054471cee8ca5cbcaa042f973c12f9f5c5..8004e20d114bc87ff86b9f120e883693e8269b22 100644 --- a/internal/command/uploadarchive/gitalycall_test.go +++ b/internal/command/uploadarchive/gitalycall_test.go @@ -56,7 +56,7 @@ func TestUploadArchive(t *testing.T) { ctx := correlation.ContextWithCorrelation(context.Background(), "a-correlation-id") ctx = correlation.ContextWithClientName(ctx, "gitlab-shell-tests") - err := cmd.Execute(ctx) + _, err := cmd.Execute(ctx) require.NoError(t, err) require.Equal(t, "UploadArchive: "+repo, output.String()) diff --git a/internal/command/uploadarchive/uploadarchive.go b/internal/command/uploadarchive/uploadarchive.go index dcdd144076dd9ccf5ad90342aabacf3373de76ed..4d1b2641f8bfce9884969c59205bbf9abf183e7f 100644 --- a/internal/command/uploadarchive/uploadarchive.go +++ b/internal/command/uploadarchive/uploadarchive.go @@ -3,6 +3,7 @@ package uploadarchive import ( "context" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/shared/accessverifier" @@ -16,19 +17,25 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { args := c.Args.SshArgs if len(args) != 2 { - return disallowedcommand.Error + return ctx, disallowedcommand.Error } repo := args[1] response, err := c.verifyAccess(ctx, repo) if err != nil { - return err + return ctx, err } - return c.performGitalyCall(ctx, response) + metadata := command.NewLogMetadata( + response.Gitaly.Repo.GlProjectPath, + response.Username, + ) + ctxWithLogMetadata := context.WithValue(ctx, "metadata", metadata) + + return ctxWithLogMetadata, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/uploadarchive/uploadarchive_test.go b/internal/command/uploadarchive/uploadarchive_test.go index 86a40315ef093d7799e2617f499eb702284f0459..245f083f6885a1aaaf96e4c2f0525d8f1628298e 100644 --- a/internal/command/uploadarchive/uploadarchive_test.go +++ b/internal/command/uploadarchive/uploadarchive_test.go @@ -8,24 +8,48 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper/requesthandlers" ) +func TestAllowedAccess(t *testing.T) { + gitalyAddress, _ := testserver.StartGitalyServer(t, "unix") + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + cmd, _ := setup(t, "1", requests) + cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) + + ctxWithLogMetadata, err := cmd.Execute(context.Background()) + + require.NoError(t, err) + metadata := ctxWithLogMetadata.Value("metadata").(command.LogMetadata) + require.Equal(t, "alex-doe", metadata.Username) + require.Equal(t, "group/project-path", metadata.Project) + require.Equal(t, "group", metadata.RootNamespace) +} + func TestForbiddenAccess(t *testing.T) { requests := requesthandlers.BuildDisallowedByApiHandlers(t) + + cmd, _ := setup(t, "disallowed", requests) + + _, err := cmd.Execute(context.Background()) + require.Equal(t, "Disallowed by API call", err.Error()) +} + +func setup(t *testing.T, keyId string, requests []testserver.TestRequestHandler) (*Command, *bytes.Buffer) { url := testserver.StartHttpServer(t, requests) output := &bytes.Buffer{} + input := bytes.NewBufferString("input") cmd := &Command{ Config: &config.Config{GitlabUrl: url}, - Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-upload-archive", "group/repo"}}, - ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, + Args: &commandargs.Shell{GitlabKeyId: keyId, SshArgs: []string{"git-upload-archive", "group/repo"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, } - err := cmd.Execute(context.Background()) - require.Equal(t, "Disallowed by API call", err.Error()) + return cmd, output } diff --git a/internal/command/uploadpack/gitalycall_test.go b/internal/command/uploadpack/gitalycall_test.go index 874d12e6a2779316e7e6dff69b8f01d392f8aa3c..dfa189d58a4e084af123213dbf25710c3cd7e47a 100644 --- a/internal/command/uploadpack/gitalycall_test.go +++ b/internal/command/uploadpack/gitalycall_test.go @@ -57,7 +57,7 @@ func TestUploadPack(t *testing.T) { ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, } - err := cmd.Execute(ctx) + _, err := cmd.Execute(ctx) require.NoError(t, err) require.Equal(t, "SSHUploadPackWithSidechannel: "+repo, output.String()) diff --git a/internal/command/uploadpack/uploadpack.go b/internal/command/uploadpack/uploadpack.go index 725093a11bb9b919e1e6fb3fc3bd211aa937180d..3d402e0f46ca0b6f20a1b1201dbaded6bd33fef3 100644 --- a/internal/command/uploadpack/uploadpack.go +++ b/internal/command/uploadpack/uploadpack.go @@ -3,6 +3,7 @@ package uploadpack import ( "context" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/shared/accessverifier" @@ -17,28 +18,34 @@ type Command struct { ReadWriter *readwriter.ReadWriter } -func (c *Command) Execute(ctx context.Context) error { +func (c *Command) Execute(ctx context.Context) (context.Context, error) { args := c.Args.SshArgs if len(args) != 2 { - return disallowedcommand.Error + return ctx, disallowedcommand.Error } repo := args[1] response, err := c.verifyAccess(ctx, repo) if err != nil { - return err + return ctx, err } + metadata := command.NewLogMetadata( + response.Gitaly.Repo.GlProjectPath, + response.Username, + ) + ctxWithLogMetadata := context.WithValue(ctx, "metadata", metadata) + if response.IsCustomAction() { customAction := customaction.Command{ Config: c.Config, ReadWriter: c.ReadWriter, EOFSent: false, } - return customAction.Execute(ctx, response) + return ctxWithLogMetadata, customAction.Execute(ctx, response) } - return c.performGitalyCall(ctx, response) + return ctxWithLogMetadata, c.performGitalyCall(ctx, response) } func (c *Command) verifyAccess(ctx context.Context, repo string) (*accessverifier.Response, error) { diff --git a/internal/command/uploadpack/uploadpack_test.go b/internal/command/uploadpack/uploadpack_test.go index 5456cae14acd27be7ce062177ca0e4cff6c31ff8..dd96c3a343a0eed10815c74a3f04c42ad8319b03 100644 --- a/internal/command/uploadpack/uploadpack_test.go +++ b/internal/command/uploadpack/uploadpack_test.go @@ -8,24 +8,48 @@ import ( "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/commandargs" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command/readwriter" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper/requesthandlers" ) +func TestAllowedAccess(t *testing.T) { + gitalyAddress, _ := testserver.StartGitalyServer(t, "unix") + requests := requesthandlers.BuildAllowedWithGitalyHandlers(t, gitalyAddress) + cmd, _ := setup(t, "1", requests) + cmd.Config.GitalyClient.InitSidechannelRegistry(context.Background()) + + ctxWithLogMetadata, err := cmd.Execute(context.Background()) + + require.NoError(t, err) + metadata := ctxWithLogMetadata.Value("metadata").(command.LogMetadata) + require.Equal(t, "alex-doe", metadata.Username) + require.Equal(t, "group/project-path", metadata.Project) + require.Equal(t, "group", metadata.RootNamespace) +} + func TestForbiddenAccess(t *testing.T) { requests := requesthandlers.BuildDisallowedByApiHandlers(t) + + cmd, _ := setup(t, "disallowed", requests) + + _, err := cmd.Execute(context.Background()) + require.Equal(t, "Disallowed by API call", err.Error()) +} + +func setup(t *testing.T, keyId string, requests []testserver.TestRequestHandler) (*Command, *bytes.Buffer) { url := testserver.StartHttpServer(t, requests) output := &bytes.Buffer{} + input := bytes.NewBufferString("input") cmd := &Command{ Config: &config.Config{GitlabUrl: url}, - Args: &commandargs.Shell{GitlabKeyId: "disallowed", SshArgs: []string{"git-upload-pack", "group/repo"}}, - ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output}, + Args: &commandargs.Shell{GitlabKeyId: keyId, SshArgs: []string{"git-upload-pack", "group/repo"}}, + ReadWriter: &readwriter.ReadWriter{ErrOut: output, Out: output, In: input}, } - err := cmd.Execute(context.Background()) - require.Equal(t, "Disallowed by API call", err.Error()) + return cmd, output } diff --git a/internal/sshd/connection.go b/internal/sshd/connection.go index e691d331d46ff82143d198918e123b9d8246d52a..850f91efdd75a64c3c61c6e9b3518e3ef0bcaec6 100644 --- a/internal/sshd/connection.go +++ b/internal/sshd/connection.go @@ -36,7 +36,7 @@ type connection struct { remoteAddr string } -type channelHandler func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error +type channelHandler func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) func newConnection(cfg *config.Config, nconn net.Conn) *connection { maxSessions := cfg.Server.ConcurrentSessionsLimit @@ -50,12 +50,12 @@ func newConnection(cfg *config.Config, nconn net.Conn) *connection { } } -func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) { +func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handler channelHandler) context.Context { log.WithContextFields(ctx, log.Fields{}).Info("server: handleConn: start") sconn, chans, err := c.initServerConn(ctx, srvCfg) if err != nil { - return + return ctx } if c.cfg.Server.ClientAliveInterval > 0 { @@ -64,10 +64,17 @@ func (c *connection) handle(ctx context.Context, srvCfg *ssh.ServerConfig, handl go c.sendKeepAliveMsg(ctx, sconn, ticker) } - c.handleRequests(ctx, sconn, chans, handler) + ctxWithLogMetadataChan := make(chan context.Context) + defer close(ctxWithLogMetadataChan) + + go c.handleRequests(ctx, sconn, chans, ctxWithLogMetadataChan, handler) + + ctxWithLogMetadata := <-ctxWithLogMetadataChan reason := sconn.Wait() log.WithContextFields(ctx, log.Fields{"reason": reason}).Info("server: handleConn: done") + + return ctxWithLogMetadata } func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfig) (*ssh.ServerConn, <-chan ssh.NewChannel, error) { @@ -94,22 +101,25 @@ func (c *connection) initServerConn(ctx context.Context, srvCfg *ssh.ServerConfi return sconn, chans, err } -func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, handler channelHandler) { +func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, ctxWithLogMetadataChan chan<- context.Context, handler channelHandler) { ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr}) for newChannel := range chans { ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested") + if newChannel.ChannelType() != "session" { ctxlog.Info("connection: handleRequests: unknown channel type") newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") continue } + if !c.concurrentSessions.TryAcquire(1) { ctxlog.Info("connection: handleRequests: too many concurrent sessions") newChannel.Reject(ssh.ResourceShortage, "too many concurrent sessions") metrics.SshdHitMaxSessions.Inc() continue } + channel, requests, err := newChannel.Accept() if err != nil { ctxlog.WithError(err).Error("connection: handleRequests: accepting channel failed") @@ -134,10 +144,12 @@ func (c *connection) handleRequests(ctx context.Context, sconn *ssh.ServerConn, }() metrics.SliSshdSessionsTotal.Inc() - err := handler(sconn, channel, requests) + ctxWithLogMetadata, err := handler(ctx, sconn, channel, requests) if err != nil { c.trackError(ctxlog, err) } + + ctxWithLogMetadataChan <- ctxWithLogMetadata }() } diff --git a/internal/sshd/connection_test.go b/internal/sshd/connection_test.go index 5438935f01b909900843220d8b6dea25406c587b..88e4e426e46a7ffb0b491dc53e614f86e1d32042 100644 --- a/internal/sshd/connection_test.go +++ b/internal/sshd/connection_test.go @@ -81,23 +81,25 @@ func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (boo return true, nil, nil } -func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) { +func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel, chan<- context.Context) { cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum}} conn := &connection{cfg: cfg, concurrentSessions: semaphore.NewWeighted(sessionsNum)} chans := make(chan ssh.NewChannel, 1) chans <- newChannel - return conn, chans + ctxWithLogMetadataChan := make(chan context.Context) + + return conn, chans, ctxWithLogMetadataChan } func TestPanicDuringSessionIsRecovered(t *testing.T) { newChannel := &fakeNewChannel{channelType: "session"} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) numSessions := 0 require.NotPanics(t, func() { - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { numSessions += 1 close(chans) panic("This is a panic") @@ -112,10 +114,10 @@ func TestUnknownChannelType(t *testing.T) { defer close(rejectCh) newChannel := &fakeNewChannel{channelType: "unknown session", rejectCh: rejectCh} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) go func() { - conn.handleRequests(context.Background(), nil, chans, nil) + conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, nil) }() rejectionData := <-rejectCh @@ -129,15 +131,15 @@ func TestTooManySessions(t *testing.T) { defer close(rejectCh) newChannel := &fakeNewChannel{channelType: "session", rejectCh: rejectCh} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(context.Background(), nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { <-ctx.Done() // Keep the accepted channel open until the end of the test - return nil + return ctx, nil }) }() @@ -147,13 +149,14 @@ func TestTooManySessions(t *testing.T) { func TestAcceptSessionSucceeds(t *testing.T) { newChannel := &fakeNewChannel{channelType: "session"} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) + ctx := context.Background() channelHandled := false - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { channelHandled = true close(chans) - return nil + return ctx, nil }) require.True(t, channelHandled) @@ -165,13 +168,14 @@ func TestAcceptSessionFails(t *testing.T) { acceptErr := errors.New("some failure") newChannel := &fakeNewChannel{channelType: "session", acceptCh: acceptCh, acceptErr: acceptErr} - conn, chans := setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) + ctx := context.Background() channelHandled := false go func() { - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { channelHandled = true - return nil + return ctx, nil }) }() @@ -203,11 +207,12 @@ func TestSessionsMetrics(t *testing.T) { initialSessionsErrorTotal := testutil.ToFloat64(metrics.SliSshdSessionsErrorsTotal) newChannel := &fakeNewChannel{channelType: "session"} + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) + ctx := context.Background() - conn, chans := setup(1, newChannel) - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { close(chans) - return errors.New("custom error") + return ctx, errors.New("custom error") }) eventuallyInDelta(t, initialSessionsTotal+1, testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1) @@ -224,11 +229,13 @@ func TestSessionsMetrics(t *testing.T) { {"not our ref", grpcstatus.Error(grpccodes.Internal, `rpc error: code = Internal desc = cmd wait: exit status 128, stderr: "fatal: git upload-pack: not our ref 9106d18f6a1b8022f6517f479696f3e3ea5e68c1"`)}, } { t.Run(ignoredError.desc, func(t *testing.T) { - conn, chans = setup(1, newChannel) + conn, chans, ctxWithLogMetadataChan := setup(1, newChannel) ignored := ignoredError.err - conn.handleRequests(context.Background(), nil, chans, func(*ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) error { + ctx := context.Background() + + conn.handleRequests(ctx, nil, chans, ctxWithLogMetadataChan, func(context.Context, *ssh.ServerConn, ssh.Channel, <-chan *ssh.Request) (context.Context, error) { close(chans) - return ignored + return ctx, ignored }) eventuallyInDelta(t, initialSessionsTotal+2+float64(i), testutil.ToFloat64(metrics.SliSshdSessionsTotal), 0.1) diff --git a/internal/sshd/session.go b/internal/sshd/session.go index 3394b2a55fd31d919f9db064f09955b25d6aa59b..df0cfc165d57267178dfca7ed2b4fccfb181da55 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -49,7 +49,8 @@ type exitStatusReq struct { ExitStatus uint32 } -func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) error { +func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) (context.Context, error) { + ctxWithLogMetadata := ctx ctxlog := log.ContextLogger(ctx) ctxlog.Debug("session: handle: entering request loop") @@ -70,13 +71,13 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) erro case "exec": // The command has been executed as `ssh user@host command` or `exec` channel has been used // in the app implementation - shouldContinue, err = s.handleExec(ctx, req) + ctxWithLogMetadata, shouldContinue, err = s.handleExec(ctx, req) case "shell": // The command has been entered into the shell or `shell` channel has been used // in the app implementation shouldContinue = false var status uint32 - status, err = s.handleShell(ctx, req) + ctxWithLogMetadata, status, err = s.handleShell(ctx, req) s.exit(ctx, status) default: // Ignore unknown requests but don't terminate the session @@ -99,7 +100,7 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) erro ctxlog.Debug("session: handle: exiting request loop") - return err + return ctxWithLogMetadata, err } func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) { @@ -132,21 +133,22 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) (bool, error) return true, nil } -func (s *session) handleExec(ctx context.Context, req *ssh.Request) (bool, error) { +func (s *session) handleExec(ctx context.Context, req *ssh.Request) (context.Context, bool, error) { var execRequest execRequest + if err := ssh.Unmarshal(req.Payload, &execRequest); err != nil { - return false, err + return ctx, false, err } s.execCmd = execRequest.Command - status, err := s.handleShell(ctx, req) - s.exit(ctx, status) + ctxWithLogMetadata, status, err := s.handleShell(ctx, req) + s.exit(ctxWithLogMetadata, status) - return false, err + return ctxWithLogMetadata, false, err } -func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, error) { +func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Context, uint32, error) { ctxlog := log.ContextLogger(ctx) if req.WantReply { @@ -183,7 +185,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, er s.toStderr(ctx, "ERROR: Failed to parse command: %v\n", err.Error()) } - return 128, err + return ctx, 128, err } cmdName := reflect.TypeOf(cmd).String() @@ -194,18 +196,19 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (uint32, er }).Info("session: handleShell: executing command") metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration) - if err := cmd.Execute(ctx); err != nil { + ctxWithLogMetadata, err := cmd.Execute(ctx) + if err != nil { grpcStatus := grpcstatus.Convert(err) if grpcStatus.Code() != grpccodes.Internal { s.toStderr(ctx, "ERROR: %v\n", grpcStatus.Message()) } - return 1, err + return ctx, 1, err } ctxlog.Info("session: handleShell: command executed successfully") - return 0, nil + return ctxWithLogMetadata, 0, nil } func (s *session) toStderr(ctx context.Context, format string, args ...interface{}) { diff --git a/internal/sshd/session_test.go b/internal/sshd/session_test.go index d1bff7e824395c3f315e023b432a04f022a5077e..d0f4f51f12687a474e68ecbcad65e2d7638cea3b 100644 --- a/internal/sshd/session_test.go +++ b/internal/sshd/session_test.go @@ -146,7 +146,7 @@ func TestHandleExec(t *testing.T) { r := &ssh.Request{Payload: tc.payload} s.channel = f - shouldContinue, err := s.handleExec(context.Background(), r) + _, shouldContinue, err := s.handleExec(context.Background(), r) require.Equal(t, tc.expectedErr, err) require.Equal(t, false, shouldContinue) @@ -210,7 +210,7 @@ func TestHandleShell(t *testing.T) { } r := &ssh.Request{} - exitCode, err := s.handleShell(context.Background(), r) + _, exitCode, err := s.handleShell(context.Background(), r) if tc.expectedErrString != "" { require.Equal(t, tc.expectedErrString, err.Error()) diff --git a/internal/sshd/sshd.go b/internal/sshd/sshd.go index f26458228eee787cdf500ed4c3ad62a8c5e22da8..0328b6db4bddb7a5dc1c372242d346b8a995b574 100644 --- a/internal/sshd/sshd.go +++ b/internal/sshd/sshd.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/v14/client" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/gitlabnet" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/metrics" @@ -193,7 +194,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { started := time.Now() conn := newConnection(s.Config, nconn) - conn.handle(ctx, s.serverConfig.get(ctx), func(sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) error { + ctxWithLogMetadata := conn.handle(ctx, s.serverConfig.get(ctx), func(ctx context.Context, sconn *ssh.ServerConn, channel ssh.Channel, requests <-chan *ssh.Request) (context.Context, error) { session := &session{ cfg: s.Config, channel: channel, @@ -206,7 +207,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) { return session.handle(ctx, requests) }) - ctxlog.WithFields(log.Fields{"duration_s": time.Since(started).Seconds()}).Info("access: finish") + ctxlog.WithFields(log.Fields{"duration_s": time.Since(started).Seconds(), "meta": extractMetaDataFromContext(ctxWithLogMetadata)}).Info("access: finish") } func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { @@ -228,6 +229,16 @@ func (s *Server) proxyPolicy() (proxyproto.PolicyFunc, error) { } } +func extractMetaDataFromContext(ctx context.Context) command.LogMetadata { + metadata := command.LogMetadata{} + + if ctx.Value("metadata") != nil { + metadata = ctx.Value("metadata").(command.LogMetadata) + } + + return metadata +} + func staticProxyPolicy(policy proxyproto.Policy) proxyproto.PolicyFunc { return func(_ net.Addr) (proxyproto.Policy, error) { return policy, nil diff --git a/internal/sshd/sshd_test.go b/internal/sshd/sshd_test.go index c14a9f5f852da5c01a476d1a88e7adb9ec57d6d3..bdbb3c8cf388d49165ba661a3c5b73e3f85b8c46 100644 --- a/internal/sshd/sshd_test.go +++ b/internal/sshd/sshd_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/crypto/ssh" "gitlab.com/gitlab-org/gitlab-shell/v14/client/testserver" + "gitlab.com/gitlab-org/gitlab-shell/v14/internal/command" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/config" "gitlab.com/gitlab-org/gitlab-shell/v14/internal/testhelper" ) @@ -348,6 +349,25 @@ func TestLoginGraceTime(t *testing.T) { verifyStatus(t, s, StatusClosed) } +func TestExtractMetaDataFromContext(t *testing.T) { + rootNameSpace := "flightjs" + project := fmt.Sprintf("%s/Flight", rootNameSpace) + username := "alex-doe" + ctxWithLogMetadata := context.WithValue(context.Background(), "metadata", command.NewLogMetadata(project, username)) + + metadata := extractMetaDataFromContext(ctxWithLogMetadata) + + require.Equal(t, command.LogMetadata{Project: project, Username: username, RootNamespace: rootNameSpace}, metadata) +} + +func TestExtractMetaDataFromContextWithoutMetaData(t *testing.T) { + ctxWithLogMetadata := context.Background() + + metadata := extractMetaDataFromContext(ctxWithLogMetadata) + + require.Equal(t, command.LogMetadata{}, metadata) +} + func setupServer(t *testing.T) *Server { t.Helper() diff --git a/internal/testhelper/requesthandlers/requesthandlers.go b/internal/testhelper/requesthandlers/requesthandlers.go index de1fdf942a52f9896d0be3b82b387202eac539c0..a58b67ae0c3286658fb97026f9d010a5350d193f 100644 --- a/internal/testhelper/requesthandlers/requesthandlers.go +++ b/internal/testhelper/requesthandlers/requesthandlers.go @@ -38,6 +38,7 @@ func BuildAllowedWithGitalyHandlers(t *testing.T, gitalyAddress string) []testse "gl_id": "1", "gl_key_type": "key", "gl_key_id": 123, + "gl_username": "alex-doe", "gitaly": map[string]interface{}{ "repository": map[string]interface{}{ "storage_name": "storage_name",