diff --git a/cmd/gitaly/main.go b/cmd/gitaly/main.go index 3be1e38e8b1b570c669bda4c1eeacb4a82de67ce..9decd10749e89a54dfde779af6ced3861eabd0fc 100644 --- a/cmd/gitaly/main.go +++ b/cmd/gitaly/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "net" @@ -8,6 +9,7 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -21,6 +23,7 @@ import ( "gitlab.com/gitlab-org/gitaly/internal/tempdir" "gitlab.com/gitlab-org/gitaly/internal/version" "gitlab.com/gitlab-org/labkit/tracing" + "google.golang.org/grpc" ) var ( @@ -175,11 +178,15 @@ func run(insecureListeners, secureListeners []net.Listener) error { } defer ruby.Stop() + var servers []*grpc.Server + serverErrors := make(chan error, len(insecureListeners)+len(secureListeners)) if len(insecureListeners) > 0 { insecureServer := server.NewInsecure(ruby) defer insecureServer.Stop() + servers = append(servers, insecureServer) + for _, listener := range insecureListeners { // Must pass the listener as a function argument because there is a race // between 'go' and 'for'. 
@@ -193,6 +200,8 @@ func run(insecureListeners, secureListeners []net.Listener) error {
 		secureServer := server.NewSecure(ruby)
 		defer secureServer.Stop()
 
+		servers = append(servers, secureServer)
+
 		for _, listener := range secureListeners {
 			go func(l net.Listener) {
 				serverErrors <- secureServer.Serve(l)
@@ -202,9 +211,48 @@ func run(insecureListeners, secureListeners []net.Listener) error {
 
 	select {
 	case s := <-termCh:
+		// Give in-flight RPCs a bounded amount of time to finish before the
+		// servers are forcibly stopped.
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+
 		err = fmt.Errorf("received signal %q", s)
+		for _, srv := range servers {
+			// Use stopErr rather than shadowing err: err is run's return
+			// value and must keep reporting the received signal.
+			if stopErr := gracefulStopServer(ctx, srv); stopErr != nil {
+				log.Warnf("error while attempting a graceful stop: %v", stopErr)
+			}
+		}
 	case err = <-serverErrors:
 	}
 
 	return err
 }
+
+// gracefulStopServer stops srv gracefully, falling back to a hard stop if ctx
+// is done before the graceful stop completes.
+//
+// GracefulStop blocks until pending RPCs have finished, so it runs in its own
+// goroutine. If ctx is done first, srv.Stop closes the remaining connections
+// (which also unblocks the in-flight GracefulStop call) and the function waits
+// for that goroutine to exit before returning ctx.Err(), so no goroutine
+// outlives the call. It returns nil when the graceful stop finishes in time.
+func gracefulStopServer(ctx context.Context, srv *grpc.Server) error {
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		srv.GracefulStop()
+	}()
+
+	select {
+	case <-done:
+		return nil
+	case <-ctx.Done():
+		// Deadline hit: force-close what is left, then wait for the
+		// GracefulStop goroutine to observe the hard stop and return.
+		srv.Stop()
+		<-done
+		return ctx.Err()
+	}
+}