diff --git a/internal/git/catfile/catfile.go b/internal/git/catfile/catfile.go
index 144595edafe58c9cc4e8d5c8fe2ac1c5cba4a800..b43cf15b799f7cee54e1e1ee0f4128165d764847 100644
--- a/internal/git/catfile/catfile.go
+++ b/internal/git/catfile/catfile.go
@@ -57,6 +57,9 @@ func ParseObjectInfo(stdout *bufio.Reader) (*ObjectInfo, error) {
 	}
 
 	info := strings.Split(infoLine, " ")
+	if len(info) != 3 {
+		return nil, fmt.Errorf("strings split: expected 3 strings, got %d: %q", len(info), info)
+	}
 
 	objectSizeStr := info[2]
 	objectSize, err := strconv.ParseInt(objectSizeStr, 10, 64)
diff --git a/internal/service/blob/get_blob.go b/internal/service/blob/get_blob.go
index 573292e669edafb980ed1d198041c444af650af0..8b72cf1a8818762c558782cf53837343665e05bc 100644
--- a/internal/service/blob/get_blob.go
+++ b/internal/service/blob/get_blob.go
@@ -1,22 +1,20 @@
 package blob
 
 import (
-	"bufio"
 	"fmt"
-	"io"
-	"os/exec"
 
-	"gitlab.com/gitlab-org/gitaly/internal/command"
-	"gitlab.com/gitlab-org/gitaly/internal/git/catfile"
 	"gitlab.com/gitlab-org/gitaly/internal/helper"
 
 	pb "gitlab.com/gitlab-org/gitaly-proto/go"
-	"gitlab.com/gitlab-org/gitaly/streamio"
 
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 )
 
+// blobSender emits one response message; size and oid are only set on the first message of each blob.
+type blobSender func(size int64, oid string, data []byte) error
+
+// GetBlob might get deprecated in favour of the more versatile GetBlobs.
 func (s *server) GetBlob(in *pb.GetBlobRequest, stream pb.BlobService_GetBlobServer) error {
 	if err := validateRequest(in); err != nil {
 		return grpc.Errorf(codes.InvalidArgument, "GetBlob: %v", err)
@@ -27,63 +25,15 @@ func (s *server) GetBlob(in *pb.GetBlobRequest, stream pb.BlobService_GetBlobSer
 		return err
 	}
 
-	stdinReader, stdinWriter := io.Pipe()
-
-	cmdArgs := []string{"--git-dir", repoPath, "cat-file", "--batch"}
-	cmd, err := command.New(stream.Context(), exec.Command(command.GitPath(), cmdArgs...), stdinReader, nil, nil)
-	if err != nil {
-		return grpc.Errorf(codes.Internal, "GetBlob: cmd: %v", err)
-	}
-	defer stdinWriter.Close()
-	defer stdinReader.Close()
-
-	if _, err := fmt.Fprintln(stdinWriter, in.Oid); err != nil {
-		return grpc.Errorf(codes.Internal, "GetBlob: stdin write: %v", err)
-	}
-	stdinWriter.Close()
-
-	stdout := bufio.NewReader(cmd)
-
-	objectInfo, err := catfile.ParseObjectInfo(stdout)
-	if err != nil {
-		return grpc.Errorf(codes.Internal, "GetBlob: %v", err)
-	}
-	if objectInfo.Type != "blob" {
-		return helper.DecorateError(codes.Unavailable, stream.Send(&pb.GetBlobResponse{}))
-	}
-
-	readLimit := objectInfo.Size
-	if in.Limit >= 0 && in.Limit < readLimit {
-		readLimit = in.Limit
-	}
-	firstMessage := &pb.GetBlobResponse{
-		Size: objectInfo.Size,
-		Oid:  objectInfo.Oid,
-	}
-
-	if readLimit == 0 {
-		return helper.DecorateError(codes.Unavailable, stream.Send(firstMessage))
-	}
-
-	sw := streamio.NewWriter(func(p []byte) error {
-		msg := &pb.GetBlobResponse{}
-		if firstMessage != nil {
-			msg = firstMessage
-			firstMessage = nil
+	return getBlobs(stream.Context(), repoPath, []string{in.Oid}, in.Limit, func(size int64, oid string, data []byte) error {
+		resp := &pb.GetBlobResponse{
+			Size: size,
+			Oid:  oid,
+			Data: data,
 		}
-		msg.Data = p
-		return stream.Send(msg)
-	})
-	n, err := io.Copy(sw, io.LimitReader(stdout, readLimit))
-	if err != nil {
-		return grpc.Errorf(codes.Unavailable, "GetBlob: send: %v", err)
-	}
-	if n != readLimit {
-		return grpc.Errorf(codes.Unavailable, "GetBlob: short send: %d/%d bytes", n, objectInfo.Size)
-	}
-
-	return nil
+
+		return stream.Send(resp)
+	})
 }
 
 func validateRequest(in *pb.GetBlobRequest) error {
diff --git a/internal/service/blob/get_blobs.go b/internal/service/blob/get_blobs.go
index 951b28610439751fc3810bb3c7d56f398c56bc37..d21266461c0bcd7511669b8a596d2b7e9f826747 100644
--- a/internal/service/blob/get_blobs.go
+++ b/internal/service/blob/get_blobs.go
@@ -1,10 +1,137 @@
 package blob
 
 import (
+	"bufio"
+	"context"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"os/exec"
+
 	pb "gitlab.com/gitlab-org/gitaly-proto/go"
+	"gitlab.com/gitlab-org/gitaly/internal/command"
+	"gitlab.com/gitlab-org/gitaly/internal/git/catfile"
 	"gitlab.com/gitlab-org/gitaly/internal/helper"
+	"gitlab.com/gitlab-org/gitaly/streamio"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
 )
 
-func (*server) GetBlobs(*pb.GetBlobsRequest, pb.BlobService_GetBlobsServer) error {
-	return helper.Unimplemented
+// GetBlobs streams the blobs named in the request, in order, splitting each
+// one into response messages as described on getBlobs.
+func (s *server) GetBlobs(in *pb.GetBlobsRequest, stream pb.BlobService_GetBlobsServer) error {
+	if err := validateRequests(in); err != nil {
+		return grpc.Errorf(codes.InvalidArgument, "GetBlobs: %v", err)
+	}
+
+	repoPath, err := helper.GetRepoPath(in.Repository)
+	if err != nil {
+		return err
+	}
+
+	return getBlobs(stream.Context(), repoPath, in.Oids, in.Limit, func(size int64, oid string, data []byte) error {
+		resp := &pb.GetBlobsResponse{
+			Size: size,
+			Oid:  oid,
+			Data: data,
+		}
+
+		return stream.Send(resp)
+	})
+}
+
+// getBlobs runs a single `git cat-file --batch` process and streams every
+// blob in oids, in order, through sender. The first call for a blob carries
+// its size and oid (with nil data if no bytes are to be sent); follow-up
+// chunks of the same blob carry size 0 and an empty oid. A non-negative limit
+// caps the number of bytes sent per blob; the remainder is read and discarded.
+// If an oid does not resolve to a blob, one empty message is sent and the
+// remaining oids are not processed.
+func getBlobs(ctx context.Context, repoPath string, oids []string, limit int64, sender blobSender) error {
+	stdinReader, stdinWriter := io.Pipe()
+
+	cmdArgs := []string{"--git-dir", repoPath, "cat-file", "--batch"}
+	cmd, err := command.New(ctx, exec.Command(command.GitPath(), cmdArgs...), stdinReader, nil, nil)
+	if err != nil {
+		return grpc.Errorf(codes.Internal, "getBlob: cmd: %v", err)
+	}
+	defer stdinWriter.Close()
+	defer stdinReader.Close()
+
+	stdout := bufio.NewReader(cmd)
+
+	var (
+		firstMessage bool
+		objectInfo   *catfile.ObjectInfo
+	)
+
+	sw := streamio.NewWriter(func(p []byte) error {
+		if firstMessage {
+			firstMessage = false
+			return sender(objectInfo.Size, objectInfo.Oid, p)
+		}
+		return sender(0, "", p)
+	})
+
+	for _, oid := range oids {
+		firstMessage = true
+		if _, err = fmt.Fprintln(stdinWriter, oid); err != nil {
+			return grpc.Errorf(codes.Internal, "getBlob: stdin write: %v", err)
+		}
+
+		objectInfo, err = catfile.ParseObjectInfo(stdout)
+		if err != nil {
+			return grpc.Errorf(codes.Internal, "getBlob: %v", err)
+		}
+		if objectInfo.Type != "blob" {
+			return helper.DecorateError(codes.Unavailable, sender(0, "", nil))
+		}
+
+		readLimit := objectInfo.Size
+		if limit >= 0 && limit < readLimit {
+			readLimit = limit
+		}
+
+		if readLimit == 0 {
+			err := sender(objectInfo.Size, objectInfo.Oid, nil)
+			if err != nil {
+				return grpc.Errorf(codes.Unavailable, "getBlob: send: %v", err)
+			}
+		}
+
+		n, err := io.Copy(sw, io.LimitReader(stdout, readLimit))
+		if err != nil {
+			return grpc.Errorf(codes.Unavailable, "getBlob: send: %v", err)
+		}
+		if n != readLimit {
+			return grpc.Errorf(codes.Unavailable, "getBlob: short send: %d/%d bytes", n, readLimit)
+		}
+
+		// Skip the unsent part of the object plus the LF cat-file writes after its contents.
+		if rest := objectInfo.Size - readLimit + 1; rest > 0 {
+			n, err := io.Copy(ioutil.Discard, io.LimitReader(stdout, rest))
+			if err != nil {
+				return grpc.Errorf(codes.Unavailable, "getBlob: read: %v", err)
+			}
+			if n != rest {
+				return grpc.Errorf(codes.Unavailable, "getBlob: short read: %d/%d bytes", n, rest)
+			}
+		}
+	}
+	stdinWriter.Close()
+
+	return cmd.Wait()
+}
+
+// validateRequests rejects a request with no oids or with an empty oid.
+func validateRequests(in *pb.GetBlobsRequest) error {
+	if len(in.GetOids()) == 0 {
+		return fmt.Errorf("no Oids specified")
+	}
+	for _, oid := range in.GetOids() {
+		if len(oid) == 0 {
+			return fmt.Errorf("empty Oid found")
+		}
+	}
+	return nil
 }
diff --git a/internal/service/blob/get_blobs_test.go b/internal/service/blob/get_blobs_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..f35ed0a1616cb36ee752e1e2b0df3431bfc99547
--- /dev/null
+++ b/internal/service/blob/get_blobs_test.go
@@ -0,0 +1,240 @@
+package blob
+
+import (
+	"bytes"
+	"io"
+	"testing"
+
+	pb "gitlab.com/gitlab-org/gitaly-proto/go"
+	"gitlab.com/gitlab-org/gitaly/internal/testhelper"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestSuccessfulGetBlobs(t *testing.T) {
+	server := runBlobServer(t)
+	defer server.Stop()
+
+	client, conn := newBlobClient(t, serverSocketPath)
+	defer conn.Close()
+	maintenanceMdBlobData := testhelper.MustReadFile(t, "testdata/maintenance-md-blob.txt")
+	testCases := []struct {
+		desc  string
+		oids  []string
+		blobs []blob
+		limit int
+	}{
+		{
+			desc:  "unlimited fetch",
+			oids:  []string{"95d9f0a5e7bb054e9dd3975589b8dfc689e20e88"},
+			limit: -1,
+			blobs: []blob{{data: maintenanceMdBlobData, size: int64(len(maintenanceMdBlobData))}},
+		},
+		{
+			desc:  "limit larger than blob size",
+			oids:  []string{"95d9f0a5e7bb054e9dd3975589b8dfc689e20e88"},
+			limit: len(maintenanceMdBlobData) + 1,
+			blobs: []blob{{data: maintenanceMdBlobData, size: int64(len(maintenanceMdBlobData))}},
+		},
+		{
+			desc:  "limit zero",
+			oids:  []string{"95d9f0a5e7bb054e9dd3975589b8dfc689e20e88"},
+			limit: 0,
+			blobs: []blob{{size: int64(len(maintenanceMdBlobData))}},
+		},
+		{
+			desc:  "limit greater than zero, less than blob size",
+			oids:  []string{"95d9f0a5e7bb054e9dd3975589b8dfc689e20e88"},
+			limit: 10,
+			blobs: []blob{{data: maintenanceMdBlobData[:10], size: int64(len(maintenanceMdBlobData))}},
+		},
+		{
+			desc:  "large blob",
+			oids:  []string{"08cf843fd8fe1c50757df0a13fcc44661996b4df"},
+			limit: 10,
+			blobs: []blob{{data: []byte{0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 0x4a, 0x46, 0x49, 0x46}, size: 111803}},
+		},
+		{
+			desc:  "two identical blobs, no limit",
+			oids:  []string{"95d9f0a5e7bb054e9dd3975589b8dfc689e20e88", "95d9f0a5e7bb054e9dd3975589b8dfc689e20e88"},
+			limit: -1,
+			blobs: []blob{
+				{data: maintenanceMdBlobData, size: int64(len(maintenanceMdBlobData))},
+				{data: maintenanceMdBlobData, size: int64(len(maintenanceMdBlobData))},
+			},
+		},
+		{
+			desc:  "two identical blobs, with limit",
+			oids:  []string{"95d9f0a5e7bb054e9dd3975589b8dfc689e20e88", "95d9f0a5e7bb054e9dd3975589b8dfc689e20e88"},
+			limit: 20,
+			blobs: []blob{
+				{data: maintenanceMdBlobData[:20], size: int64(len(maintenanceMdBlobData))},
+				{data: maintenanceMdBlobData[:20], size: int64(len(maintenanceMdBlobData))},
+			},
+		},
+		{
+			desc:  "two blobs, with limit",
+			oids:  []string{"95d9f0a5e7bb054e9dd3975589b8dfc689e20e88", "08cf843fd8fe1c50757df0a13fcc44661996b4df"},
+			limit: 10,
+			blobs: []blob{
+				{data: maintenanceMdBlobData[:10], size: int64(len(maintenanceMdBlobData))},
+				{data: []byte{0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 0x4a, 0x46, 0x49, 0x46}, size: 111803},
+			},
+		},
+	}
+	for _, tc := range testCases {
+		t.Run(tc.desc, func(t *testing.T) {
+			request := &pb.GetBlobsRequest{
+				Repository: testRepo,
+				Oids:       tc.oids,
+				Limit:      int64(tc.limit),
+			}
+
+			ctx, cancel := testhelper.Context()
+			defer cancel()
+
+			stream, err := client.GetBlobs(ctx, request)
+			require.NoError(t, err, "initiate RPC")
+
+			blobs, err := getAllBlobs(stream)
+			require.NoError(t, err, "consume response")
+
+			require.Equal(t, len(tc.oids), len(blobs))
+
+			for i, blob := range tc.blobs {
+				t.Logf("testing oid[%d] %q", i, tc.oids[i])
+				require.Equal(t, blob.size, blobs[i].size, "real blob size")
+
+				require.NotEmpty(t, blobs[i].oid)
+				require.Equal(t, tc.oids[i], blobs[i].oid)
+				require.Equal(t, len(blob.data), len(blobs[i].data), "returned data should have the same size")
+				require.True(t, bytes.Equal(blob.data, blobs[i].data), "returned data exactly as expected for oid %q", tc.oids[i])
+			}
+		})
+	}
+}
+
+func TestGetBlobsNotFound(t *testing.T) {
+	server := runBlobServer(t)
+	defer server.Stop()
+
+	client, conn := newBlobClient(t, serverSocketPath)
+	defer conn.Close()
+
+	tests := []struct {
+		desc  string
+		req   pb.GetBlobsRequest
+		found int // number of blobs streamed before the missing oid ends the RPC
+	}{
+		{
+			desc: "first of two is non-exist",
+			req: pb.GetBlobsRequest{
+				Repository: testRepo,
+				Oids:       []string{"doesnotexist", "95d9f0a5e7bb054e9dd3975589b8dfc689e20e88"}, // Second exist
+			},
+		},
+		{
+			desc: "second of three is non-exist",
+			req: pb.GetBlobsRequest{
+				Repository: testRepo,
+				Oids:       []string{"95d9f0a5e7bb054e9dd3975589b8dfc689e20e88", "doesnotexist", "95d9f0a5e7bb054e9dd3975589b8dfc689e20e88"}, // first and third exist
+			},
+			found: 1,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.desc, func(t *testing.T) {
+			ctx, cancel := testhelper.Context()
+			defer cancel()
+
+			stream, err := client.GetBlobs(ctx, &tc.req)
+			require.NoError(t, err)
+
+			blobs, err := getAllBlobs(stream)
+			require.NoError(t, err)
+
+			require.Len(t, blobs, tc.found)
+		})
+	}
+}
+
+func TestFailedGetBlobsRequestDueToValidationError(t *testing.T) {
+	server := runBlobServer(t)
+	defer server.Stop()
+
+	client, conn := newBlobClient(t, serverSocketPath)
+	defer conn.Close()
+	oid := "d42783470dc29fde2cf459eb3199ee1d7e3f3a72"
+
+	tests := []struct {
+		desc string
+		req  pb.GetBlobsRequest
+	}{
+		{
+			desc: "repo does not exist",
+			req:  pb.GetBlobsRequest{Repository: &pb.Repository{StorageName: "fake", RelativePath: "path"}, Oids: []string{oid}},
+		},
+		{
+			desc: "repo is nil",
+			req:  pb.GetBlobsRequest{Repository: nil, Oids: []string{oid}},
+		},
+		{
+			desc: "oid list is empty",
+			req:  pb.GetBlobsRequest{Repository: testRepo},
+		},
+		{
+			desc: "one oid is empty string",
+			req:  pb.GetBlobsRequest{Repository: testRepo, Oids: []string{"foo", "", "bar"}},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.desc, func(t *testing.T) {
+			ctx, cancel := testhelper.Context()
+			defer cancel()
+
+			stream, err := client.GetBlobs(ctx, &tc.req)
+			require.NoError(t, err)
+			_, err = stream.Recv()
+			require.NotEqual(t, io.EOF, err)
+			require.Error(t, err)
+		})
+	}
+}
+
+type blob struct {
+	oid  string
+	size int64
+	data []byte
+}
+
+// getAllBlobs drains stream, starting a new blob whenever a response carries a non-empty oid.
+func getAllBlobs(stream pb.BlobService_GetBlobsClient) ([]*blob, error) {
+	var (
+		blobs   []*blob
+		curBlob = &blob{}
+		err     error
+	)
+
+	resp, err := stream.Recv()
+	for err == nil {
+		if resp.GetOid() != "" {
+			if curBlob.oid != "" {
+				blobs = append(blobs, curBlob)
+			}
+			curBlob = &blob{oid: resp.GetOid(), size: resp.GetSize()}
+		}
+		curBlob.data = append(curBlob.data, resp.GetData()...)
+		resp, err = stream.Recv()
+	}
+
+	if curBlob.oid != "" {
+		blobs = append(blobs, curBlob)
+	}
+
+	if err != io.EOF {
+		return nil, err
+	}
+	return blobs, nil
+}