diff --git a/internal/git/gitpipe/pipeline_test.go b/internal/git/gitpipe/pipeline_test.go
index 2771cb8f802484f4bd7c0f749920b34bbfd01b58..6bee6799f33365051016ac727767573ba54f1918 100644
--- a/internal/git/gitpipe/pipeline_test.go
+++ b/internal/git/gitpipe/pipeline_test.go
@@ -379,7 +379,7 @@ func TestPipeline_forEachRef(t *testing.T) {
 
 	// We certainly don't want to hard-code all the references, so we just cross-check with the
 	// localrepo implementation to verify that both return the same data.
-	refs, err := repo.GetReferences(ctx)
+	refs, err := repo.GetReferences(ctx, false)
 	require.NoError(t, err)
 	require.Equal(t, len(refs), len(objectsByRef))
 
diff --git a/internal/git/localrepo/refs.go b/internal/git/localrepo/refs.go
index ac9c69fd1c135389dca195cac2cfde4dc9253ae0..b6068156701aa3bb59fd1cd2540c082c69485ee5 100644
--- a/internal/git/localrepo/refs.go
+++ b/internal/git/localrepo/refs.go
@@ -64,7 +64,7 @@ func (repo *Repo) ResolveRevision(ctx context.Context, revision git.Revision) (g
 // GetReference looks up and returns the given reference. Returns a
 // ReferenceNotFound error if the reference was not found.
 func (repo *Repo) GetReference(ctx context.Context, reference git.ReferenceName) (git.Reference, error) {
-	refs, err := repo.getReferences(ctx, 1, reference.String())
+	refs, err := repo.getReferences(ctx, 1, false, reference.String())
 	if err != nil {
 		return git.Reference{}, err
 	}
@@ -82,17 +82,31 @@ func (repo *Repo) GetReference(ctx context.Context, reference git.ReferenceName)
 // HasBranches determines whether there is at least one branch in the
 // repository.
 func (repo *Repo) HasBranches(ctx context.Context) (bool, error) {
-	refs, err := repo.getReferences(ctx, 1, "refs/heads/")
+	refs, err := repo.getReferences(ctx, 1, false, "refs/heads/")
 	return len(refs) > 0, err
 }
 
-// GetReferences returns references matching any of the given patterns. If no patterns are given,
-// all references are returned.
-func (repo *Repo) GetReferences(ctx context.Context, patterns ...string) ([]git.Reference, error) {
-	return repo.getReferences(ctx, 0, patterns...)
+// GetReferences returns references matching any of the given patterns. If no patterns are given,
+// all references are returned. If head is true, HEAD is resolved and prepended to the result as a
+// reference named "HEAD"; it is omitted if resolving it yields git.ErrReferenceNotFound.
+func (repo *Repo) GetReferences(ctx context.Context, head bool, patterns ...string) ([]git.Reference, error) {
+	return repo.getReferences(ctx, 0, head, patterns...)
 }
 
-func (repo *Repo) getReferences(ctx context.Context, limit uint, patterns ...string) ([]git.Reference, error) {
+func (repo *Repo) getReferences(ctx context.Context, limit uint, head bool, patterns ...string) ([]git.Reference, error) {
+	var refs []git.Reference
+	if head {
+		headOid, err := repo.ResolveRevision(ctx, git.Revision("HEAD"))
+		switch {
+		case errors.Is(err, git.ErrReferenceNotFound):
+			// ignore missing HEAD
+		case err != nil:
+			return nil, err
+		default:
+			refs = append(refs, git.NewReference("HEAD", headOid.String()))
+		}
+	}
+
 	flags := []git.Option{git.Flag{Name: "--format=%(refname)%00%(objectname)%00%(symref)"}}
 	if limit > 0 {
 		flags = append(flags, git.Flag{Name: fmt.Sprintf("--count=%d", limit)})
@@ -109,7 +123,6 @@ func (repo *Repo) getReferences(ctx context.Context, limit uint, patterns ...str
 
 	scanner := bufio.NewScanner(cmd)
 
-	var refs []git.Reference
 	for scanner.Scan() {
 		line := bytes.SplitN(scanner.Bytes(), []byte{0}, 3)
 		if len(line) != 3 {
@@ -135,7 +148,7 @@ func (repo *Repo) getReferences(ctx context.Context, limit uint, patterns ...str
 
 // GetBranches returns all branches.
 func (repo *Repo) GetBranches(ctx context.Context) ([]git.Reference, error) {
-	return repo.GetReferences(ctx, "refs/heads/")
+	return repo.GetReferences(ctx, false, "refs/heads/")
 }
 
 // UpdateRef updates reference from oldValue to newValue. If oldValue is a
diff --git a/internal/git/localrepo/refs_test.go b/internal/git/localrepo/refs_test.go
index 4e9e911ba215badbbd8741be527bfce7f2465f2a..7745f4c7fe9729d9a1eb56448018393e77a268d5 100644
--- a/internal/git/localrepo/refs_test.go
+++ b/internal/git/localrepo/refs_test.go
@@ -148,11 +148,17 @@ func TestRepo_GetReferences(t *testing.T) {
 
 	repo, _ := setupRepo(t, false)
 
+	headOid, err := repo.ResolveRevision(ctx, git.Revision("HEAD"))
+	require.NoError(t, err)
+
+	head := git.NewReference("HEAD", headOid.String())
+
 	masterBranch, err := repo.GetReference(ctx, "refs/heads/master")
 	require.NoError(t, err)
 
 	testcases := []struct {
 		desc     string
+		head     bool
 		patterns []string
 		match    func(t *testing.T, refs []git.Reference)
 	}{
@@ -200,11 +206,19 @@ func TestRepo_GetReferences(t *testing.T) {
 				require.Empty(t, refs)
 			},
 		},
+		{
+			desc:     "master branch with head",
+			head:     true,
+			patterns: []string{"refs/heads/master"},
+			match: func(t *testing.T, refs []git.Reference) {
+				require.Equal(t, []git.Reference{head, masterBranch}, refs)
+			},
+		},
 	}
 
 	for _, tc := range testcases {
 		t.Run(tc.desc, func(t *testing.T) {
-			refs, err := repo.GetReferences(ctx, tc.patterns...)
+			refs, err := repo.GetReferences(ctx, tc.head, tc.patterns...)
 			require.NoError(t, err)
 			tc.match(t, refs)
 		})
diff --git a/internal/git/localrepo/remote_test.go b/internal/git/localrepo/remote_test.go
index bc97cb2228f3b10d29985ea5bfc7afda97298a62..3b0984fac2acf45267e5f682d13885957eedd943 100644
--- a/internal/git/localrepo/remote_test.go
+++ b/internal/git/localrepo/remote_test.go
@@ -600,10 +600,10 @@ func TestRepo_Push(t *testing.T) {
 
 			require.Equal(t, tc.sshCommand, string(gitSSHCommand))
 
-			actual, err := pushRepo.GetReferences(ctx)
+			actual, err := pushRepo.GetReferences(ctx, false)
 			require.NoError(t, err)
 
-			expected, err := sourceRepo.GetReferences(ctx, tc.expectedFilter...)
+			expected, err := sourceRepo.GetReferences(ctx, false, tc.expectedFilter...)
 			require.NoError(t, err)
 
 			require.Equal(t, expected, actual)
diff --git a/internal/git/updateref/updateref_test.go b/internal/git/updateref/updateref_test.go
index e6affcc991899ead719d48ca6b2935aae137efae..4a75134282bc66a83d875899b1c4cd7c14cb0341 100644
--- a/internal/git/updateref/updateref_test.go
+++ b/internal/git/updateref/updateref_test.go
@@ -149,7 +149,7 @@ func TestBulkOperation(t *testing.T) {
 
 	require.NoError(t, updater.Wait())
 
-	refs, err := repo.GetReferences(ctx, "refs/")
+	refs, err := repo.GetReferences(ctx, false, "refs/")
 	require.NoError(t, err)
 	require.Greater(t, len(refs), 1000, "At least 1000 refs should be present")
 }
diff --git a/internal/gitaly/service/cleanup/apply_bfg_object_map_stream_test.go b/internal/gitaly/service/cleanup/apply_bfg_object_map_stream_test.go
index d64965eae6ea83fa3ecdfaef7990fff760cb7fdc..fd1b50c16631017bd9f7659581de75342f4e1117 100644
--- a/internal/gitaly/service/cleanup/apply_bfg_object_map_stream_test.go
+++ b/internal/gitaly/service/cleanup/apply_bfg_object_map_stream_test.go
@@ -66,7 +66,7 @@ func TestApplyBfgObjectMapStreamSuccess(t *testing.T) {
 	require.NoError(t, err)
 
 	// Ensure that the internal refs are gone, but the others still exist
-	refs, err := repo.GetReferences(ctx, "refs/")
+	refs, err := repo.GetReferences(ctx, false, "refs/")
 	require.NoError(t, err)
 
 	refNames := make([]string, len(refs))
diff --git a/internal/gitaly/service/ref/delete_refs.go b/internal/gitaly/service/ref/delete_refs.go
index 19899814de93f5224375fead74c72e70b54bbf6d..490ea0cb2572e3e8b460b8bee3c54096ae48bc2d 100644
--- a/internal/gitaly/service/ref/delete_refs.go
+++ b/internal/gitaly/service/ref/delete_refs.go
@@ -86,7 +86,7 @@ func (s *server) refsToRemove(ctx context.Context, repo *localrepo.Repo, req *gi
 		prefixes[i] = string(prefix)
 	}
 
-	existingRefs, err := repo.GetReferences(ctx)
+	existingRefs, err := repo.GetReferences(ctx, false)
 	if err != nil {
 		return nil, err
 	}
diff --git a/internal/gitaly/service/ref/delete_refs_test.go b/internal/gitaly/service/ref/delete_refs_test.go
index 81a687eec58c0845e41e412cce4cc94ab8bb99c9..d6ad6bdae148161ca24923cfd53025a3fda93f02 100644
--- a/internal/gitaly/service/ref/delete_refs_test.go
+++ b/internal/gitaly/service/ref/delete_refs_test.go
@@ -61,7 +61,7 @@ func TestSuccessfulDeleteRefs(t *testing.T) {
 			require.NoError(t, err)
 
 			// Ensure that the internal refs are gone, but the others still exist
-			refs, err := localrepo.NewTestRepo(t, cfg, repo).GetReferences(ctx, "refs/")
+			refs, err := localrepo.NewTestRepo(t, cfg, repo).GetReferences(ctx, false, "refs/")
 			require.NoError(t, err)
 
 			refNames := make([]string, len(refs))
diff --git a/internal/gitaly/service/remote/update_remote_mirror.go b/internal/gitaly/service/remote/update_remote_mirror.go
index bc7ed1dd22bec956a395ba5ddda984eacc78f352..d61345aee9065febf22332a33c22f9cf1a24a513 100644
--- a/internal/gitaly/service/remote/update_remote_mirror.go
+++ b/internal/gitaly/service/remote/update_remote_mirror.go
@@ -101,7 +101,7 @@ func (s *server) updateRemoteMirror(stream gitalypb.RemoteService_UpdateRemoteMi
 		return fmt.Errorf("get remote references: %w", err)
 	}
 
-	localRefs, err := repo.GetReferences(ctx, "refs/heads/", "refs/tags/")
+	localRefs, err := repo.GetReferences(ctx, false, "refs/heads/", "refs/tags/")
 	if err != nil {
 		return fmt.Errorf("get local references: %w", err)
 	}
diff --git a/internal/gitaly/service/repository/calculate_checksum.go b/internal/gitaly/service/repository/calculate_checksum.go
index 7c246131316f432107844bec3e56f3131a58a37d..5625d4feb964630ebd05a4acd8c3e17387f3b398 100644
--- a/internal/gitaly/service/repository/calculate_checksum.go
+++ b/internal/gitaly/service/repository/calculate_checksum.go
@@ -1,11 +1,11 @@
 package repository
 
 import (
-	"bufio"
 	"bytes"
 	"context"
 	"crypto/sha1"
 	"encoding/hex"
+	"fmt"
 	"math/big"
 	"regexp"
 	"strings"
@@ -19,35 +19,38 @@
 var refWhitelist = regexp.MustCompile(`HEAD|(refs/(heads|tags|keep-around|merge-requests|environments|notes)/)`)
 
 func (s *server) CalculateChecksum(ctx context.Context, in *gitalypb.CalculateChecksumRequest) (*gitalypb.CalculateChecksumResponse, error) {
-	repo := in.GetRepository()
+	repo := s.localrepo(in.GetRepository())
 
-	repoPath, err := s.locator.GetRepoPath(repo)
-	if err != nil {
-		return nil, err
-	}
+	refs, err := repo.GetReferences(ctx, true)
+	if len(refs) == 0 || err != nil {
+		if s.isValidRepo(ctx, in.GetRepository()) {
+			return &gitalypb.CalculateChecksumResponse{Checksum: git.ZeroOID.String()}, nil
+		}
 
-	cmd, err := s.gitCmdFactory.New(ctx, repo, git.SubCmd{Name: "show-ref", Flags: []git.Option{git.Flag{Name: "--head"}}})
-	if err != nil {
-		if _, ok := status.FromError(err); ok {
+		// status.FromError(nil) reports ok, so guard on err: an invalid repository without
+		// any references must yield the DataLoss error below rather than a (nil, nil) result.
+		if _, ok := status.FromError(err); err != nil && ok {
 			return nil, err
 		}
 
-		return nil, status.Errorf(codes.Internal, "CalculateChecksum: gitCommand: %v", err)
+		repoPath, err := s.locator.GetRepoPath(repo)
+		if err != nil {
+			return nil, err
+		}
+
+		return nil, status.Errorf(codes.DataLoss, "CalculateChecksum: not a git repository '%s'", repoPath)
 	}
 
 	var checksum *big.Int
 
-	scanner := bufio.NewScanner(cmd)
-	for scanner.Scan() {
-		ref := scanner.Bytes()
-
-		if !refWhitelist.Match(ref) {
+	for _, ref := range refs {
+		if !refWhitelist.MatchString(ref.Name.String()) {
 			continue
 		}
 
 		h := sha1.New()
 		// hash.Hash will never return an error.
-		_, _ = h.Write(ref)
+		_, _ = fmt.Fprintf(h, "%s %s", ref.Target, ref.Name)
 
 		hash := hex.EncodeToString(h.Sum(nil))
 		hashIntBase16, _ := (&big.Int{}).SetString(hash, 16)
@@ -59,18 +62,13 @@
 		}
 	}
 
-	if err := scanner.Err(); err != nil {
-		return nil, status.Errorf(codes.Internal, err.Error())
-	}
-
-	if err := cmd.Wait(); checksum == nil || err != nil {
-		if s.isValidRepo(ctx, repo) {
-			return &gitalypb.CalculateChecksumResponse{Checksum: git.ZeroOID.String()}, nil
-		}
-
-		return nil, status.Errorf(codes.DataLoss, "CalculateChecksum: not a git repository '%s'", repoPath)
-	}
-
+	if checksum == nil {
+		// No reference matched refWhitelist. This happens when HEAD could not be resolved and all
+		// remaining references live outside the whitelisted namespaces; without this guard,
+		// checksum.Bytes() below would dereference a nil *big.Int.
+		return &gitalypb.CalculateChecksumResponse{Checksum: git.ZeroOID.String()}, nil
+	}
+
 	return &gitalypb.CalculateChecksumResponse{Checksum: hex.EncodeToString(checksum.Bytes())}, nil
 }
 
diff --git a/internal/gitaly/service/repository/fetch_remote_test.go b/internal/gitaly/service/repository/fetch_remote_test.go
index 8dc6751278df473c6dd60ea014337fc1933dbae9..1cf2797900f442fcb7c4273ea5143d396fdac9ca 100644
--- a/internal/gitaly/service/repository/fetch_remote_test.go
+++ b/internal/gitaly/service/repository/fetch_remote_test.go
@@ -199,9 +199,9 @@ func TestFetchRemote_withDefaultRefmaps(t *testing.T) {
 	require.NoError(t, err)
 	require.NotNil(t, resp)
 
-	sourceRefs, err := sourceRepo.GetReferences(ctx)
+	sourceRefs, err := sourceRepo.GetReferences(ctx, false)
 	require.NoError(t, err)
-	targetRefs, err := targetRepo.GetReferences(ctx)
+	targetRefs, err := targetRepo.GetReferences(ctx, false)
 	require.NoError(t, err)
 	require.Equal(t, sourceRefs, targetRefs)
 }