diff --git a/internal/backup/repository.go b/internal/backup/repository.go
index 73fd549bead129008c967bc2740348d5e14ff8ef..31582d4d92dad355be545e4e0cb2649d2d99ca4e 100644
--- a/internal/backup/repository.go
+++ b/internal/backup/repository.go
@@ -148,14 +148,14 @@ func (rr *remoteRepository) createBundlePatterns(ctx context.Context, out io.Wri
 			Repository: rr.repo,
 			Patterns:   [][]byte{line},
 		}); err != nil {
-			return fmt.Errorf("remote repository: create bundle patterns: %w", err)
+			return fmt.Errorf("remote repository: create bundle patterns: send: %w", err)
 		}
 	}
 	if err := c.Flush(); err != nil {
-		return fmt.Errorf("remote repository: create bundle patterns: %w", err)
+		return fmt.Errorf("remote repository: create bundle patterns: flush: %w", err)
 	}
 	if err := stream.CloseSend(); err != nil {
-		return fmt.Errorf("remote repository: create bundle patterns: %w", err)
+		return fmt.Errorf("remote repository: create bundle patterns: close: %w", err)
 	}
 
 	bundle := streamio.NewReader(func() ([]byte, error) {
@@ -167,7 +167,7 @@ func (rr *remoteRepository) createBundlePatterns(ctx context.Context, out io.Wri
 	})
 
 	if _, err := io.Copy(out, bundle); err != nil {
-		return fmt.Errorf("remote repository: create bundle patterns: %w", err)
+		return fmt.Errorf("remote repository: create bundle patterns: recv: %w", err)
 	}
 
 	return nil
@@ -192,13 +192,27 @@ func (s *createBundleFromRefListSender) Append(msg proto.Message) {
 
 // Send should send the current response message
 func (s *createBundleFromRefListSender) Send() error {
-	return s.stream.Send(&s.chunk)
+	err := s.stream.Send(&s.chunk)
+	// On error, SendMsg aborts the stream. If the error was generated by
+	// the client, the status is returned directly; otherwise, io.EOF is
+	// returned and the status of the stream may be discovered using Recv.
+	if !errors.Is(err, io.EOF) {
+		return err
+	}
+
+	// Recv hands back any messages the server sent before failing ahead of
+	// the final status, so keep receiving until it reports an error.
+	for {
+		if _, recvErr := s.stream.Recv(); recvErr != nil {
+			return recvErr
+		}
+	}
 }
 
 // updateRefsSender chunks requests to the UpdateReferences RPC.
 type updateRefsSender struct {
-	refs []*gitalypb.UpdateReferencesRequest_Update
-	send func([]*gitalypb.UpdateReferencesRequest_Update) error
+	refs   []*gitalypb.UpdateReferencesRequest_Update
+	stream gitalypb.RefService_UpdateReferencesClient
 }
 
 // Reset should create a fresh response message.
@@ -213,7 +227,21 @@ func (s *updateRefsSender) Append(msg proto.Message) {
 
 // Send should send the current response message
 func (s *updateRefsSender) Send() error {
-	return s.send(s.refs)
+	err := s.stream.Send(&gitalypb.UpdateReferencesRequest{
+		Updates: s.refs,
+	})
+	if err != nil {
+		// On error, SendMsg aborts the stream. If the error was generated by
+		// the client, the status is returned directly; otherwise, io.EOF is
+		// returned and the status of the stream may be discovered using Recv.
+		if errors.Is(err, io.EOF) {
+			if _, recvErr := s.stream.CloseAndRecv(); recvErr != nil {
+				return recvErr
+			}
+		}
+	}
+
+	return err
 }
 
 // Remove removes the repository. Does not return an error if the repository
@@ -274,13 +302,7 @@ func (rr *remoteRepository) ResetRefs(ctx context.Context, refs []git.Reference)
 		return fmt.Errorf("send initial request: %w", err)
 	}
 
-	chunker := chunk.New(&updateRefsSender{
-		send: func(updates []*gitalypb.UpdateReferencesRequest_Update) error {
-			return stream.Send(&gitalypb.UpdateReferencesRequest{
-				Updates: updates,
-			})
-		},
-	})
+	chunker := chunk.New(&updateRefsSender{stream: stream})
 
 	for _, ref := range refs[1:] {
 		if err := chunker.Send(&gitalypb.UpdateReferencesRequest_Update{
diff --git a/internal/backup/repository_test.go b/internal/backup/repository_test.go
index fe9789562bb867194c0661f8284475a223fbd4c0..19f592a9bedd9248f012d6f5c9fc81f8e66d1521 100644
--- a/internal/backup/repository_test.go
+++ b/internal/backup/repository_test.go
@@ -1,6 +1,8 @@
 package backup_test
 
 import (
+	"io"
+	"math/rand"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -210,3 +212,45 @@ func TestLocalRepository_SetHeadReference(t *testing.T) {
 	require.Equal(t, expectedHead, actualHead)
 	require.NotEqual(t, newHead, actualHead)
 }
+
+func TestRemoteRepository_CreateBundle_HandleEOF(t *testing.T) {
+	cfg := testcfg.Build(t)
+	ctx := testhelper.Context(t)
+
+	cfg.SocketPath = testserver.RunGitalyServer(t, cfg, setup.RegisterAll)
+
+	conn, err := client.Dial(ctx, cfg.SocketPath)
+	require.NoError(t, err)
+	defer testhelper.MustClose(t, conn)
+
+	// setting nil repository to replicate server returning early error
+	rr := backup.NewRemoteRepository(nil, conn)
+
+	require.ErrorContains(t,
+		rr.CreateBundle(ctx, io.Discard, rand.New(rand.NewSource(0))),
+		"repository not set",
+	)
+}
+
+func TestRemoteRepository_ResetRefs_HandleEOF(t *testing.T) {
+	cfg := testcfg.Build(t)
+	testcfg.BuildGitalyHooks(t, cfg)
+	cfg.SocketPath = testserver.RunGitalyServer(t, cfg, setup.RegisterAll)
+	ctx := testhelper.Context(t)
+
+	conn, err := client.Dial(ctx, cfg.SocketPath)
+	require.NoError(t, err)
+	defer testhelper.MustClose(t, conn)
+
+	repo, _ := gittest.CreateRepository(t, ctx, cfg)
+	rr := backup.NewRemoteRepository(repo, conn)
+
+	// Create a large number of references to pass chunker limit
+	refs := make([]git.Reference, 30000)
+	for i := range refs {
+		// Set references to an invalid ObjectID to trigger error
+		refs[i] = git.NewReference("refs/heads/main", "invalid-object-id")
+	}
+
+	require.ErrorContains(t, rr.ResetRefs(ctx, refs), "invalid object ID")
+}