From 201ae9921852927ff6c4f38b2cde3c8a78aba427 Mon Sep 17 00:00:00 2001 From: Mustafa Bayar Date: Tue, 30 Jul 2024 23:39:33 +0200 Subject: [PATCH] repository: Return error from stream on EOF As documented on the ClientStream interface: SendMsg is generally called by generated code. On error, SendMsg aborts the stream. If the error was generated by the client, the status is returned directly; otherwise, io.EOF is returned and the status of the stream may be discovered using RecvMsg. There are many places where we use the chunker to send data, but the majority of those are server streams, which do not have the same issue. Therefore, instead of implementing error handling directly in the chunker, we decided to only touch the places that use the chunker for client-side streaming and implement the error handling there. --- internal/backup/repository.go | 50 +++++++++++++++++++++--------- internal/backup/repository_test.go | 44 ++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/internal/backup/repository.go b/internal/backup/repository.go index 73fd549bead..31582d4d92d 100644 --- a/internal/backup/repository.go +++ b/internal/backup/repository.go @@ -148,14 +148,14 @@ func (rr *remoteRepository) createBundlePatterns(ctx context.Context, out io.Wri Repository: rr.repo, Patterns: [][]byte{line}, }); err != nil { - return fmt.Errorf("remote repository: create bundle patterns: %w", err) + return fmt.Errorf("remote repository: create bundle patterns: send: %w", err) } } if err := c.Flush(); err != nil { - return fmt.Errorf("remote repository: create bundle patterns: %w", err) + return fmt.Errorf("remote repository: create bundle patterns: flush: %w", err) } if err := stream.CloseSend(); err != nil { - return fmt.Errorf("remote repository: create bundle patterns: %w", err) + return fmt.Errorf("remote repository: create bundle patterns: close: %w", err) } bundle := streamio.NewReader(func() ([]byte, error) { @@ -167,7 +167,7 @@ func (rr 
*remoteRepository) createBundlePatterns(ctx context.Context, out io.Wri }) if _, err := io.Copy(out, bundle); err != nil { - return fmt.Errorf("remote repository: create bundle patterns: %w", err) + return fmt.Errorf("remote repository: create bundle patterns: recv: %w", err) } return nil @@ -192,13 +192,25 @@ func (s *createBundleFromRefListSender) Append(msg proto.Message) { // Send should send the current response message func (s *createBundleFromRefListSender) Send() error { - return s.stream.Send(&s.chunk) + err := s.stream.Send(&s.chunk) + if errors.Is(err, io.EOF) { + // On error, SendMsg aborts the stream. If the error was generated by the + // client, the status is returned directly; otherwise, io.EOF is returned and + // the status is reported by Recv once any buffered responses are drained. + for { + if _, recvErr := s.stream.Recv(); recvErr != nil { + return recvErr + } + } + } + + return err } // updateRefsSender chunks requests to the UpdateReferences RPC. type updateRefsSender struct { - refs []*gitalypb.UpdateReferencesRequest_Update - send func([]*gitalypb.UpdateReferencesRequest_Update) error + refs []*gitalypb.UpdateReferencesRequest_Update + stream gitalypb.RefService_UpdateReferencesClient } // Reset should create a fresh response message. @@ -213,7 +225,21 @@ func (s *updateRefsSender) Append(msg proto.Message) { // Send should send the current response message func (s *updateRefsSender) Send() error { - return s.send(s.refs) + err := s.stream.Send(&gitalypb.UpdateReferencesRequest{ + Updates: s.refs, + }) + if err != nil { + // On error, SendMsg aborts the stream. If the error was generated by + // the client, the status is returned directly; otherwise, io.EOF is + // returned and the status of the stream may be discovered using CloseAndRecv. + if errors.Is(err, io.EOF) { + if _, recvErr := s.stream.CloseAndRecv(); recvErr != nil { + return recvErr + } + } + } + + return err } // Remove removes the repository.
 Does not return an error if the repository @@ -274,13 +300,7 @@ func (rr *remoteRepository) ResetRefs(ctx context.Context, refs []git.Reference) return fmt.Errorf("send initial request: %w", err) } - chunker := chunk.New(&updateRefsSender{ - send: func(updates []*gitalypb.UpdateReferencesRequest_Update) error { - return stream.Send(&gitalypb.UpdateReferencesRequest{ - Updates: updates, - }) - }, - }) + chunker := chunk.New(&updateRefsSender{stream: stream}) for _, ref := range refs[1:] { if err := chunker.Send(&gitalypb.UpdateReferencesRequest_Update{ diff --git a/internal/backup/repository_test.go b/internal/backup/repository_test.go index fe9789562bb..19f592a9bed 100644 --- a/internal/backup/repository_test.go +++ b/internal/backup/repository_test.go @@ -1,6 +1,8 @@ package backup_test import ( + "io" + "math/rand" "testing" "github.com/stretchr/testify/require" @@ -210,3 +212,45 @@ func TestLocalRepository_SetHeadReference(t *testing.T) { require.Equal(t, expectedHead, actualHead) require.NotEqual(t, newHead, actualHead) } + +func TestRemoteRepository_CreateBundle_HandleEOF(t *testing.T) { + cfg := testcfg.Build(t) + ctx := testhelper.Context(t) + + cfg.SocketPath = testserver.RunGitalyServer(t, cfg, setup.RegisterAll) + + conn, err := client.Dial(ctx, cfg.SocketPath) + require.NoError(t, err) + defer testhelper.MustClose(t, conn) + + // Use a nil repository so that the server rejects the stream early. + rr := backup.NewRemoteRepository(nil, conn) + + require.ErrorContains(t, + rr.CreateBundle(ctx, io.Discard, rand.New(rand.NewSource(0))), + "repository not set", + ) +} + +func TestRemoteRepository_ResetRefs_HandleEOF(t *testing.T) { + cfg := testcfg.Build(t) + testcfg.BuildGitalyHooks(t, cfg) + cfg.SocketPath = testserver.RunGitalyServer(t, cfg, setup.RegisterAll) + ctx := testhelper.Context(t) + + conn, err := client.Dial(ctx, cfg.SocketPath) + require.NoError(t, err) + defer testhelper.MustClose(t, conn) + + repo, _ := gittest.CreateRepository(t, ctx, cfg) + rr := 
backup.NewRemoteRepository(repo, conn) + + // Create enough references to exceed the chunker limit, forcing multiple sends + refs := make([]git.Reference, 30000) + for i := range refs { + // Use an invalid object ID so that the server rejects the update + refs[i] = git.NewReference("refs/heads/main", "invalid-object-id") + } + + require.ErrorContains(t, rr.ResetRefs(ctx, refs), "invalid object ID") +} -- GitLab