diff --git a/internal/helper/command.go b/internal/helper/command.go index 46cf1543e8dd4dc41c0b44aa4ac722ebc9715eee..fd10c503970aa48445e8e4d0b8f19cf58da2087f 100644 --- a/internal/helper/command.go +++ b/internal/helper/command.go @@ -3,9 +3,12 @@ package helper import ( "fmt" "io" + "log" "os" "os/exec" "syscall" + + "golang.org/x/net/context" ) // Command encapsulates operations with commands creates with NewCommand @@ -21,8 +24,10 @@ func (c *Command) Kill() { } // GitCommandReader creates a git Command with the given args -func GitCommandReader(args ...string) (*Command, error) { - return NewCommand(exec.Command("git", args...), nil, nil) +func GitCommandReader(ctx context.Context, args ...string) (*Command, error) { + // TODO: when we switch to Go 1.7, switch to using + // exec.CommandContext + return NewCommand(CommandWrapper(ctx, "git", args...), nil, nil) } // NewCommand creates a Command from an exec.Cmd @@ -66,6 +71,15 @@ func NewCommand(cmd *exec.Cmd, stdin io.Reader, stdout io.Writer, env ...string) return command, nil } +func cleanUpProcessGroupNoWait(cmd *exec.Cmd) { + process := cmd.Process + if process != nil && process.Pid > 0 { + // Send SIGTERM to the process group of cmd + syscall.Kill(-process.Pid, syscall.SIGTERM) + } + +} + // CleanUpProcessGroup will send a SIGTERM signal to the process group // belonging to the `cmd` process func CleanUpProcessGroup(cmd *exec.Cmd) { @@ -73,11 +87,7 @@ func CleanUpProcessGroup(cmd *exec.Cmd) { return } - process := cmd.Process - if process != nil && process.Pid > 0 { - // Send SIGTERM to the process group of cmd - syscall.Kill(-process.Pid, syscall.SIGTERM) - } + cleanUpProcessGroupNoWait(cmd) // reap our child process cmd.Wait() @@ -97,3 +107,39 @@ func ExitStatus(err error) (int, bool) { return waitStatus.ExitStatus(), true } + +// CommandWrapper ensures that the command is executed within a context, +// and ensures that the process group is terminated with the +func CommandWrapper(ctx context.Context, name string, arg ...string) *exec.Cmd { + command := exec.Command(name, arg...) + + if ctx != nil { + // Create a channel to listen to the command completion + done := make(chan error, 1) + go func() { + done <- command.Wait() + }() + + // Wait for the process to shutdown or the + // context to be complete + go func() { + select { + case <-ctx.Done(): + log.Printf("Context done, killing process") + cleanUpProcessGroupNoWait(command) + + case err := <-done: + if err != nil { + log.Printf("process done with error = %v", err) + } else { + log.Print("process done gracefully without error") + } + cleanUpProcessGroupNoWait(command) + + } + + }() + } + + return command +} diff --git a/internal/service/commit/isancestor.go b/internal/service/commit/isancestor.go index 3afa86fa33c16bd50710ed4931c4a8659aaf5767..9cccd448c1dee846f43f6d773e04fe9b69036b34 100644 --- a/internal/service/commit/isancestor.go +++ b/internal/service/commit/isancestor.go @@ -3,7 +3,6 @@ package commit import ( "io/ioutil" "log" - "os/exec" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -30,13 +29,13 @@ func (s *server) CommitIsAncestor(ctx context.Context, in *pb.CommitIsAncestorRe return nil, grpc.Errorf(codes.InvalidArgument, message) } - ret, err := commitIsAncestorName(repoPath, in.AncestorId, in.ChildId) + ret, err := commitIsAncestorName(ctx, repoPath, in.AncestorId, in.ChildId) return &pb.CommitIsAncestorResponse{Value: ret}, err } // Assumes that `path`, `ancestorID` and `childID` are populated :trollface: -func commitIsAncestorName(path, ancestorID, childID string) (bool, error) { - osCommand := exec.Command("git", "--git-dir", path, "merge-base", "--is-ancestor", ancestorID, childID) +func commitIsAncestorName(ctx context.Context, path, ancestorID, childID string) (bool, error) { + osCommand := helper.CommandWrapper(ctx, "git", "--git-dir", path, "merge-base", "--is-ancestor", ancestorID, childID) cmd, err := helper.NewCommand(osCommand, nil, ioutil.Discard) if err != nil { return false, grpc.Errorf(codes.Internal, err.Error()) diff --git a/internal/service/diff/commit.go b/internal/service/diff/commit.go index 99986519d56a9d798a78425b5c2d5130e75e7e4b..7d6779e8c5634734b47b8ff7e583bdc3fff39df7 100644 --- a/internal/service/diff/commit.go +++ b/internal/service/diff/commit.go @@ -10,6 +10,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + + "golang.org/x/net/context" ) type requestWithLeftRightCommitIds interface { @@ -60,7 +62,7 @@ func (s *server) CommitDiff(in *pb.CommitDiffRequest, stream pb.Diff_CommitDiffS } } - err = eachDiff("CommitDiff", cmdArgs, func(diff *diff.Diff) error { + err = eachDiff(stream.Context(), "CommitDiff", cmdArgs, func(diff *diff.Diff) error { err := stream.Send(&pb.CommitDiffResponse{ FromPath: diff.FromPath, ToPath: diff.ToPath, @@ -135,7 +137,7 @@ func (s *server) CommitDelta(in *pb.CommitDeltaRequest, stream pb.Diff_CommitDel return nil } - err = eachDiff("CommitDelta", cmdArgs, func(diff *diff.Diff) error { + err = eachDiff(stream.Context(), "CommitDelta", cmdArgs, func(diff *diff.Diff) error { delta := &pb.CommitDelta{ FromPath: diff.FromPath, ToPath: diff.ToPath, @@ -178,8 +180,8 @@ func validateRequest(in requestWithLeftRightCommitIds) error { return nil } -func eachDiff(rpc string, cmdArgs []string, callback func(*diff.Diff) error) error { - cmd, err := helper.GitCommandReader(cmdArgs...) +func eachDiff(ctx context.Context, rpc string, cmdArgs []string, callback func(*diff.Diff) error) error { + cmd, err := helper.GitCommandReader(ctx, cmdArgs...) if err != nil { return grpc.Errorf(codes.Internal, "%s: cmd: %v", rpc, err) } diff --git a/internal/service/ref/refname.go b/internal/service/ref/refname.go index d3b402b941407a5e605044fb6237d5490609df98..88631899cea45dff3a1ec387c9d5dc2349a22f78 100644 --- a/internal/service/ref/refname.go +++ b/internal/service/ref/refname.go @@ -28,7 +28,7 @@ func (s *server) FindRefName(ctx context.Context, in *pb.FindRefNameRequest) (*p return nil, grpc.Errorf(codes.InvalidArgument, message) } - ref, err := findRefName(repoPath, in.CommitId, string(in.Prefix)) + ref, err := findRefName(ctx, repoPath, in.CommitId, string(in.Prefix)) if err != nil { return nil, grpc.Errorf(codes.Internal, err.Error()) } @@ -37,8 +37,8 @@ func (s *server) FindRefName(ctx context.Context, in *pb.FindRefNameRequest) (*p } // We assume `path` and `commitID` and `prefix` are non-empty -func findRefName(path, commitID, prefix string) (string, error) { - cmd, err := helper.GitCommandReader("--git-dir", path, "for-each-ref", "--format=%(refname)", "--count=1", prefix, "--contains", commitID) +func findRefName(ctx context.Context, path, commitID, prefix string) (string, error) { + cmd, err := helper.GitCommandReader(ctx, "--git-dir", path, "for-each-ref", "--format=%(refname)", "--count=1", prefix, "--contains", commitID) if err != nil { return "", err } diff --git a/internal/service/ref/refs.go b/internal/service/ref/refs.go index 051e9fb5ac29b724df70bc556343e27ee9c91caa..c831f82d649cfbd87c811d1edeccbf65020e2614 100644 --- a/internal/service/ref/refs.go +++ b/internal/service/ref/refs.go @@ -33,7 +33,7 @@ func handleGitCommand(w refsWriter, r io.Reader) error { return w.Flush() } -func findRefs(writer refsWriter, repo *pb.Repository, pattern string, args ...string) error { +func findRefs(ctx context.Context, writer refsWriter, repo *pb.Repository, pattern string, args ...string) error { repoPath, err := helper.GetRepoPath(repo) if err != nil { return err @@ -49,7 +49,7 @@ func findRefs(writer refsWriter, repo *pb.Repository, pattern string, args ...st args = append(baseArgs, args...) } - cmd, err := helper.GitCommandReader(args...) + cmd, err := helper.GitCommandReader(ctx, args...) if err != nil { return err } @@ -64,18 +64,18 @@ func findRefs(writer refsWriter, repo *pb.Repository, pattern string, args ...st // FindAllBranchNames creates a stream of ref names for all branches in the given repository func (s *server) FindAllBranchNames(in *pb.FindAllBranchNamesRequest, stream pb.Ref_FindAllBranchNamesServer) error { - return findRefs(newFindAllBranchNamesWriter(stream, s.MaxMsgSize), in.Repository, "refs/heads") + return findRefs(stream.Context(), newFindAllBranchNamesWriter(stream, s.MaxMsgSize), in.Repository, "refs/heads") } // FindAllTagNames creates a stream of ref names for all tags in the given repository func (s *server) FindAllTagNames(in *pb.FindAllTagNamesRequest, stream pb.Ref_FindAllTagNamesServer) error { - return findRefs(newFindAllTagNamesWriter(stream, s.MaxMsgSize), in.Repository, "refs/tags") + return findRefs(stream.Context(), newFindAllTagNamesWriter(stream, s.MaxMsgSize), in.Repository, "refs/tags") } -func _findBranchNames(repoPath string) ([][]byte, error) { +func _findBranchNames(ctx context.Context, repoPath string) ([][]byte, error) { var names [][]byte - cmd, err := helper.GitCommandReader("--git-dir", repoPath, "for-each-ref", "refs/heads", "--format=%(refname)") + cmd, err := helper.GitCommandReader(ctx, "--git-dir", repoPath, "for-each-ref", "refs/heads", "--format=%(refname)") if err != nil { return nil, err } @@ -96,10 +96,10 @@ func _findBranchNames(repoPath string) ([][]byte, error) { return names, nil } -func _headReference(repoPath string) ([]byte, error) { +func _headReference(ctx context.Context, repoPath string) ([]byte, error) { var headRef []byte - cmd, err := helper.GitCommandReader("--git-dir", repoPath, "rev-parse", "--symbolic-full-name", "HEAD") + cmd, err := helper.GitCommandReader(ctx, "--git-dir", repoPath, "rev-parse", "--symbolic-full-name", "HEAD") if err != nil { return nil, err } @@ -119,8 +119,8 @@ func _headReference(repoPath string) ([]byte, error) { return headRef, nil } -func defaultBranchName(repoPath string) ([]byte, error) { - branches, err := findBranchNames(repoPath) +func defaultBranchName(ctx context.Context, repoPath string) ([]byte, error) { + branches, err := findBranchNames(ctx, repoPath) if err != nil { return nil, err @@ -137,7 +137,7 @@ func defaultBranchName(repoPath string) ([]byte, error) { } hasMaster := false - headRef, err := headReference(repoPath) + headRef, err := headReference(ctx, repoPath) if err != nil { return nil, err } @@ -167,7 +167,7 @@ func (s *server) FindDefaultBranchName(ctx context.Context, in *pb.FindDefaultBr log.Printf("FindDefaultBranchName: RepoPath=%q", repoPath) - defaultBranchName, err := defaultBranchName(repoPath) + defaultBranchName, err := defaultBranchName(ctx, repoPath) if err != nil { return nil, err } @@ -195,5 +195,5 @@ func (s *server) FindLocalBranches(in *pb.FindLocalBranchesRequest, stream pb.Re sortFlag := "--sort=" + parseSortKey(in.GetSortBy()) writer := newFindLocalBranchesWriter(stream, s.MaxMsgSize) - return findRefs(writer, in.Repository, "refs/heads", formatFlag, sortFlag) + return findRefs(stream.Context(), writer, in.Repository, "refs/heads", formatFlag, sortFlag) } diff --git a/internal/service/ref/refs_test.go b/internal/service/ref/refs_test.go index ae14e48ebd99895bebb16b19f737184c4f354cf3..1c4d57513e32535753d28a50d1f0c6c21b03b418 100644 --- a/internal/service/ref/refs_test.go +++ b/internal/service/ref/refs_test.go @@ -175,7 +175,7 @@ func TestInvalidRepoFindAllTagNamesRequest(t *testing.T) { } func TestHeadReference(t *testing.T) { - headRef, err := headReference(testRepoPath) + headRef, err := headReference(context.Background(), testRepoPath) if err != nil { t.Fatal(err) } @@ -193,47 +193,47 @@ func TestDefaultBranchName(t *testing.T) { testCases := []struct { desc string - findBranchNames func(string) ([][]byte, error) - headReference func(string) ([]byte, error) + findBranchNames func(context.Context, string) ([][]byte, error) + headReference func(context.Context, string) ([]byte, error) expected []byte }{ { desc: "Get first branch when only one branch exists", expected: []byte("refs/heads/foo"), - findBranchNames: func(string) ([][]byte, error) { + findBranchNames: func(context.Context, string) ([][]byte, error) { return [][]byte{[]byte("refs/heads/foo")}, nil }, - headReference: func(string) ([]byte, error) { return nil, nil }, + headReference: func(context.Context, string) ([]byte, error) { return nil, nil }, }, { desc: "Get empy ref if no branches exists", expected: nil, - findBranchNames: func(string) ([][]byte, error) { return [][]byte{}, nil }, - headReference: func(string) ([]byte, error) { return nil, nil }, + findBranchNames: func(context.Context, string) ([][]byte, error) { return [][]byte{}, nil }, + headReference: func(context.Context, string) ([]byte, error) { return nil, nil }, }, { desc: "Get the name of the head reference when more than one branch exists", expected: []byte("refs/heads/bar"), - findBranchNames: func(string) ([][]byte, error) { + findBranchNames: func(context.Context, string) ([][]byte, error) { return [][]byte{[]byte("refs/heads/foo"), []byte("refs/heads/bar")}, nil }, - headReference: func(string) ([]byte, error) { return []byte("refs/heads/bar"), nil }, + headReference: func(context.Context, string) ([]byte, error) { return []byte("refs/heads/bar"), nil }, }, { desc: "Get `ref/heads/master` when several branches exist", expected: []byte("refs/heads/master"), - findBranchNames: func(string) ([][]byte, error) { + findBranchNames: func(context.Context, string) ([][]byte, error) { return [][]byte{[]byte("refs/heads/foo"), []byte("refs/heads/master"), []byte("refs/heads/bar")}, nil }, - headReference: func(string) ([]byte, error) { return nil, nil }, + headReference: func(context.Context, string) ([]byte, error) { return nil, nil }, }, { desc: "Get the name of the first branch when several branches exists and no other conditions are met", expected: []byte("refs/heads/foo"), - findBranchNames: func(string) ([][]byte, error) { + findBranchNames: func(context.Context, string) ([][]byte, error) { return [][]byte{[]byte("refs/heads/foo"), []byte("refs/heads/bar"), []byte("refs/heads/baz")}, nil }, - headReference: func(string) ([]byte, error) { return nil, nil }, + headReference: func(context.Context, string) ([]byte, error) { return nil, nil }, }, } @@ -241,7 +241,7 @@ func TestDefaultBranchName(t *testing.T) { findBranchNames = testCase.findBranchNames headReference = testCase.headReference - defaultBranch, err := defaultBranchName("") + defaultBranch, err := defaultBranchName(context.Background(), "") if err != nil { t.Fatal(err) } diff --git a/internal/service/smarthttp/inforefs.go b/internal/service/smarthttp/inforefs.go index 59026f7a5ac1c6cc55a99a5fa03f89a410e244e2..c25a9813f70bc642e4f59f79f2cc7e88f7d31d9d 100644 --- a/internal/service/smarthttp/inforefs.go +++ b/internal/service/smarthttp/inforefs.go @@ -11,29 +11,31 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + + "golang.org/x/net/context" ) func (s *server) InfoRefsUploadPack(in *pb.InfoRefsRequest, stream pb.SmartHTTP_InfoRefsUploadPackServer) error { w := pbhelper.NewSendWriter(func(p []byte) error { return stream.Send(&pb.InfoRefsResponse{Data: p}) }) - return handleInfoRefs("upload-pack", in.Repository, w) + return handleInfoRefs(stream.Context(), "upload-pack", in.Repository, w) } func (s *server) InfoRefsReceivePack(in *pb.InfoRefsRequest, stream pb.SmartHTTP_InfoRefsReceivePackServer) error { w := pbhelper.NewSendWriter(func(p []byte) error { return stream.Send(&pb.InfoRefsResponse{Data: p}) }) - return handleInfoRefs("receive-pack", in.Repository, w) + return handleInfoRefs(stream.Context(), "receive-pack", in.Repository, w) } -func handleInfoRefs(service string, repo *pb.Repository, w io.Writer) error { +func handleInfoRefs(ctx context.Context, service string, repo *pb.Repository, w io.Writer) error { repoPath, err := helper.GetRepoPath(repo) if err != nil { return err } - cmd, err := helper.GitCommandReader(service, "--stateless-rpc", "--advertise-refs", repoPath) + cmd, err := helper.GitCommandReader(ctx, service, "--stateless-rpc", "--advertise-refs", repoPath) if err != nil { return grpc.Errorf(codes.Internal, "GetInfoRefs: cmd: %v", err) } diff --git a/internal/service/smarthttp/receive_pack.go b/internal/service/smarthttp/receive_pack.go index fa0caed2ef2ec66aa67253177671bff54551a98a..63e1fe5b87ce35e460932fc5d70e58d7578ce176 100644 --- a/internal/service/smarthttp/receive_pack.go +++ b/internal/service/smarthttp/receive_pack.go @@ -3,7 +3,6 @@ package smarthttp import ( "fmt" "log" - "os/exec" "gitlab.com/gitlab-org/gitaly/internal/helper" @@ -40,7 +39,7 @@ func (s *server) PostReceivePack(stream pb.SmartHTTP_PostReceivePackServer) erro log.Printf("PostReceivePack: RepoPath=%q GlID=%q GlRepository=%q", repoPath, req.GlId, req.GlRepository) - osCommand := exec.Command("git", "receive-pack", "--stateless-rpc", repoPath) + osCommand := helper.CommandWrapper(stream.Context(), "git", "receive-pack", "--stateless-rpc", repoPath) cmd, err := helper.NewCommand(osCommand, stdin, stdout, env...) if err != nil { diff --git a/internal/service/smarthttp/upload_pack.go b/internal/service/smarthttp/upload_pack.go index f34ee1b7faade73981d3eed62ad238c7f1c5cbcf..f4b15ccd2b2839fe848b9eca9861df464e1b47a4 100644 --- a/internal/service/smarthttp/upload_pack.go +++ b/internal/service/smarthttp/upload_pack.go @@ -2,7 +2,6 @@ package smarthttp import ( "log" - "os/exec" "gitlab.com/gitlab-org/gitaly/internal/helper" @@ -35,7 +34,7 @@ func (s *server) PostUploadPack(stream pb.SmartHTTP_PostUploadPackServer) error log.Printf("PostUploadPack: RepoPath=%q", repoPath) - osCommand := exec.Command("git", "upload-pack", "--stateless-rpc", repoPath) + osCommand := helper.CommandWrapper(stream.Context(), "git", "upload-pack", "--stateless-rpc", repoPath) cmd, err := helper.NewCommand(osCommand, stdin, stdout) if err != nil {