diff --git a/cmd/gitaly-hooks/hooks.go b/cmd/gitaly-hooks/hooks.go
index c10ae9d9a9a9a8dcb2663ec3efe199214286f2e5..3ba7f49cfbb688d359fcf8a0e6ddf0baff49b65f 100644
--- a/cmd/gitaly-hooks/hooks.go
+++ b/cmd/gitaly-hooks/hooks.go
@@ -314,10 +314,21 @@ func preReceiveHook(ctx context.Context, payload gitcmd.HooksPayload, hookClient
 		EnvironmentVariables: os.Environ(),
 		GitPushOptions:       gitPushOptions(),
 	}); err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := preReceiveHookStream.Recv(); recvErr != nil {
+				return fmt.Errorf("error when receiving data for pre-receive hook: %w", recvErr)
+			}
+		}
 		return fmt.Errorf("error when sending request for pre-receive hook: %w", err)
 	}
 
 	f := sendFunc(streamio.NewWriter(func(p []byte) error {
+		// Do not call Recv here: this writer runs while stream.Handler receives
+		// on the same stream, and RecvMsg is not safe for concurrent use. If Send
+		// fails with io.EOF, the stream status is surfaced by that receive loop.
 		return preReceiveHookStream.Send(&gitalypb.PreReceiveHookRequest{Stdin: p})
 	}), preReceiveHookStream, os.Stdin)
 
@@ -343,10 +354,21 @@ func postReceiveHook(ctx context.Context, payload gitcmd.HooksPayload, hookClien
 		EnvironmentVariables: os.Environ(),
 		GitPushOptions:       gitPushOptions(),
 	}); err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := postReceiveHookStream.Recv(); recvErr != nil {
+				return fmt.Errorf("error when receiving data for post-receive hook: %w", recvErr)
+			}
+		}
 		return fmt.Errorf("error when sending request for post-receive hook: %w", err)
 	}
 
 	f := sendFunc(streamio.NewWriter(func(p []byte) error {
+		// Do not call Recv here: this writer runs while stream.Handler receives
+		// on the same stream, and RecvMsg is not safe for concurrent use. If Send
+		// fails with io.EOF, the stream status is surfaced by that receive loop.
 		return postReceiveHookStream.Send(&gitalypb.PostReceiveHookRequest{Stdin: p})
 	}), postReceiveHookStream, os.Stdin)
 
@@ -388,10 +410,21 @@ func referenceTransactionHook(ctx context.Context, payload gitcmd.HooksPayload,
 		EnvironmentVariables: os.Environ(),
 		State:                state,
 	}); err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := referenceTransactionHookStream.Recv(); recvErr != nil {
+				return fmt.Errorf("error when receiving data for reference-transaction hook: %w", recvErr)
+			}
+		}
 		return fmt.Errorf("error when sending request for reference-transaction hook: %w", err)
 	}
 
 	f := sendFunc(streamio.NewWriter(func(p []byte) error {
+		// Do not call Recv here: this writer runs while stream.Handler receives
+		// on the same stream, and RecvMsg is not safe for concurrent use. If Send
+		// fails with io.EOF, the stream status is surfaced by that receive loop.
 		return referenceTransactionHookStream.Send(&gitalypb.ReferenceTransactionHookRequest{Stdin: p})
 	}), referenceTransactionHookStream, os.Stdin)
 
@@ -420,10 +453,21 @@ func procReceiveHook(ctx context.Context, payload gitcmd.HooksPayload, hookClien
 		Repository:           payload.Repo,
 		EnvironmentVariables: os.Environ(),
 	}); err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := hookStream.Recv(); recvErr != nil {
+				return fmt.Errorf("receiving output for proc-receive hook: %w", recvErr)
+			}
+		}
 		return fmt.Errorf("sending first proc-receive request: %w", err)
 	}
 
 	f := sendFunc(streamio.NewWriter(func(p []byte) error {
+		// Do not call Recv here: this writer runs while stream.Handler receives
+		// on the same stream, and RecvMsg is not safe for concurrent use. If Send
+		// fails with io.EOF, the stream status is surfaced by that receive loop.
 		return hookStream.Send(&gitalypb.ProcReceiveHookRequest{Stdin: p})
 	}), hookStream, os.Stdin)
 
diff --git a/internal/cli/gitaly/subcmd_hooks.go b/internal/cli/gitaly/subcmd_hooks.go
index 55997621d4c62cfe7c00f258243c37f12ad48fa3..a31a5f300449e0f72fafbc4322e9f88bf0705fea 100644
--- a/internal/cli/gitaly/subcmd_hooks.go
+++ b/internal/cli/gitaly/subcmd_hooks.go
@@ -110,12 +110,31 @@ func setRepoHooks(ctx context.Context, conn *grpc.ClientConn, reader io.Reader,
 			RelativePath: relativePath,
 		},
 	}); err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
+				return recvErr
+			}
+		}
 		return err
 	}
 
 	// Configure streamWriter to transmit tarball data to stream.
 	streamWriter := streamio.NewWriter(func(p []byte) error {
-		return stream.Send(&gitalypb.SetCustomHooksRequest{Data: p})
+		err := stream.Send(&gitalypb.SetCustomHooksRequest{Data: p})
+		if err != nil {
+			// On error, SendMsg aborts the stream. If the error was generated by
+			// the client, the status is returned directly; otherwise, io.EOF is
+			// returned and the status of the stream may be discovered using Recv.
+			if errors.Is(err, io.EOF) {
+				if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
+					return recvErr
+				}
+			}
+		}
+		return err
 	})
 
 	if _, err := io.Copy(streamWriter, reader); err != nil {
diff --git a/internal/gitaly/gitalyclient/receive_pack.go b/internal/gitaly/gitalyclient/receive_pack.go
index 0610f9e103450b41daf959dd43dd10b4c89d5e07..90efe36b980829c99233b728935e3b0abfbfe879 100644
--- a/internal/gitaly/gitalyclient/receive_pack.go
+++ b/internal/gitaly/gitalyclient/receive_pack.go
@@ -2,6 +2,7 @@ package gitalyclient
 
 import (
 	"context"
+	"errors"
 	"io"
 
 	"gitlab.com/gitlab-org/gitaly/v16/internal/stream"
@@ -22,10 +23,21 @@ func ReceivePack(ctx context.Context, conn *grpc.ClientConn, stdin io.Reader, st
 	}
 
 	if err = receivePackStream.Send(req); err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := receivePackStream.Recv(); recvErr != nil {
+				return 0, recvErr
+			}
+		}
 		return 0, err
 	}
 
 	inWriter := streamio.NewWriter(func(p []byte) error {
+		// Do not call Recv here: this writer runs while stream.Handler receives
+		// on the same stream, and RecvMsg is not safe for concurrent use. If Send
+		// fails with io.EOF, the stream status is surfaced by that receive loop.
 		return receivePackStream.Send(&gitalypb.SSHReceivePackRequest{Stdin: p})
 	})
 
diff --git a/internal/gitaly/gitalyclient/upload_archive.go b/internal/gitaly/gitalyclient/upload_archive.go
index 03e883d028fb682da39d52fff659e6630a92d771..bc3c0fd736a69ae4c9153c12438f028cd4844fbb 100644
--- a/internal/gitaly/gitalyclient/upload_archive.go
+++ b/internal/gitaly/gitalyclient/upload_archive.go
@@ -2,6 +2,7 @@ package gitalyclient
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"io"
 
@@ -23,10 +24,21 @@ func UploadArchive(ctx context.Context, conn *grpc.ClientConn, stdin io.Reader,
 	}
 
 	if err = uploadPackStream.Send(req); err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := uploadPackStream.Recv(); recvErr != nil {
+				return 0, recvErr
+			}
+		}
 		return 0, err
 	}
 
 	inWriter := streamio.NewWriter(func(p []byte) error {
+		// Do not call Recv here: this writer runs while stream.Handler receives
+		// on the same stream, and RecvMsg is not safe for concurrent use. If Send
+		// fails with io.EOF, the stream status is surfaced by that receive loop.
 		return uploadPackStream.Send(&gitalypb.SSHUploadArchiveRequest{Stdin: p})
 	})
 
diff --git a/internal/gitaly/storage/raftmgr/grpc_transport.go b/internal/gitaly/storage/raftmgr/grpc_transport.go
index 30cb063be665716f3e128fff1cbfa7966c8891e6..2a885b6b65074b163d2a65ff19c952bcecbec1ab 100644
--- a/internal/gitaly/storage/raftmgr/grpc_transport.go
+++ b/internal/gitaly/storage/raftmgr/grpc_transport.go
@@ -181,6 +181,14 @@ func (t *GrpcTransport) sendToNode(ctx context.Context, addr string, reqs []*git
 
 	for _, req := range reqs {
 		if err := stream.Send(req); err != nil {
+			// On error, SendMsg aborts the stream. If the error was generated by
+			// the client, the status is returned directly; otherwise, io.EOF is
+			// returned and the status of the stream may be discovered using Recv.
+			if errors.Is(err, io.EOF) {
+				if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
+					return fmt.Errorf("send request to address %s: %w", addr, recvErr)
+				}
+			}
 			return fmt.Errorf("send request to address %s: %w", addr, err)
 		}
 	}
@@ -332,6 +340,14 @@ func (t *GrpcTransport) SendSnapshot(ctx context.Context, pk *gitalypb.RaftParti
 			},
 		},
 	}); err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
+				return fmt.Errorf("failed to send raft message: %w", recvErr)
+			}
+		}
 		return fmt.Errorf("failed to send raft message: %w", err)
 	}
 
@@ -341,11 +357,22 @@ func (t *GrpcTransport) SendSnapshot(ctx context.Context, pk *gitalypb.RaftParti
 		case <-stream.Context().Done():
 			return fmt.Errorf("context cancelled while sending snapshot: %w", ctx.Err())
 		default:
-			return stream.Send(&gitalypb.RaftSnapshotMessageRequest{
+			err := stream.Send(&gitalypb.RaftSnapshotMessageRequest{
 				RaftSnapshotPayload: &gitalypb.RaftSnapshotMessageRequest_Chunk{
 					Chunk: p,
 				},
 			})
+			if err != nil {
+				// On error, SendMsg aborts the stream. If the error was generated by
+				// the client, the status is returned directly; otherwise, io.EOF is
+				// returned and the status of the stream may be discovered using Recv.
+				if errors.Is(err, io.EOF) {
+					if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
+						return recvErr
+					}
+				}
+			}
+			return err
 		}
 	})
 	sent, err := io.Copy(sw, snapshot.file)