diff --git a/internal/git/pktline/pktline.go b/internal/git/pktline/pktline.go
index 665d6cbde64d5a1d7e4843be2e7b5f2c5ed72e99..532fba8577bde8beb8ba9db5f4f21bc0947ef733 100644
--- a/internal/git/pktline/pktline.go
+++ b/internal/git/pktline/pktline.go
@@ -10,6 +10,8 @@ import (
 	"io"
 	"strconv"
 	"sync"
+
+	"gitlab.com/gitlab-org/gitaly/v14/streamio"
 )
 
 const (
@@ -161,13 +163,18 @@ type errNotSideband struct{ pkt string }
 
 func (err *errNotSideband) Error() string { return fmt.Sprintf("invalid sideband packet: %q", err.pkt) }
 
-// EachSidebandPacket iterates over a side-band-64k pktline stream. For
-// each packet, it will call fn with the band ID and the packet. Fn must
-// not retain the packet.
+// EachSidebandPacket iterates over a side-band-64k pktline stream until
+// it reaches a flush packet. For each packet, it will call fn with the
+// band ID and the packet. Fn must not retain the packet. If the stream
+// ends before a flush packet, io.ErrUnexpectedEOF is returned.
 func EachSidebandPacket(r io.Reader, fn func(byte, []byte) error) error {
 	scanner := NewScanner(r)
 
 	for scanner.Scan() {
+		if IsFlush(scanner.Bytes()) {
+			return nil
+		}
+
 		data := Data(scanner.Bytes())
 		if len(data) == 0 {
 			return &errNotSideband{scanner.Text()}
@@ -177,5 +184,62 @@ func EachSidebandPacket(r io.Reader, fn func(byte, []byte) error) error {
 		}
 	}
 
-	return scanner.Err()
+	if err := scanner.Err(); err != nil {
+		return err
+	}
+
+	// The stream ended without a terminating flush packet.
+	return io.ErrUnexpectedEOF
+}
+
+// SingleBandReader unwraps a flush-terminated side-band-64k stream. It
+// expects a sequence of sideband packets all for the same band and
+// returns a reader yielding their concatenated payloads. Packets with
+// an empty payload are skipped. The returned reader will return EOF
+// when it encounters a flush packet, and io.ErrUnexpectedEOF if the
+// input ends before one. Anything else in the input stream will result
+// in a read error.
+func SingleBandReader(r io.Reader, band byte) io.Reader {
+	scanner := NewScanner(r)
+
+	return streamio.NewReader(func() ([]byte, error) {
+		for scanner.Scan() {
+			pkt := scanner.Bytes()
+
+			if IsFlush(pkt) {
+				return nil, io.EOF
+			}
+
+			// pkt still carries its 4-byte length prefix, so the band
+			// byte is pkt[4] and the payload starts at pkt[5].
+			if len(pkt) < 5 {
+				return nil, &errNotSideband{string(pkt)}
+			}
+
+			if b := pkt[4]; b != band {
+				return nil, errUnexpectedSideband(b)
+			}
+
+			// Returning an empty payload would make Read return
+			// (0, nil), which io.Reader discourages, so skip it.
+			if payload := pkt[5:]; len(payload) > 0 {
+				return payload, nil
+			}
+		}
+
+		if err := scanner.Err(); err != nil {
+			return nil, err
+		}
+
+		// The input ended without a terminating flush packet.
+		return nil, io.ErrUnexpectedEOF
+	})
+}
+
+// errUnexpectedSideband is returned by SingleBandReader when it reads a
+// packet for a band other than the one it was asked to unwrap.
+type errUnexpectedSideband byte
+
+func (b errUnexpectedSideband) Error() string {
+	return fmt.Sprintf("unexpected band: %d", byte(b))
 }
diff --git a/internal/git/pktline/pkt_line_test.go b/internal/git/pktline/pktline_test.go
similarity index 81%
rename from internal/git/pktline/pkt_line_test.go
rename to internal/git/pktline/pktline_test.go
index 32694a7e0ab4eac1d122d50f873a0585fa2d9ae0..dd26ab875ce127a9632af6ec1ab3ed985e58a38b 100644
--- a/internal/git/pktline/pkt_line_test.go
+++ b/internal/git/pktline/pktline_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"errors"
 	"io"
+	"io/ioutil"
 	"math"
 	"math/rand"
 	"strings"
@@ -279,16 +280,23 @@ func TestEachSidebandPacket(t *testing.T) {
 	}{
 		{
 			desc: "empty",
+			in:   "0000",
 			out:  map[byte]string{},
 		},
 		{
 			desc:     "empty with failing callback: callback does not run",
+			in:       "0000",
 			out:      map[byte]string{},
 			callback: func(byte, []byte) error { panic("oh no") },
 		},
 		{
 			desc: "valid stream",
-			in:   "0008\x00foo0008\x01bar0008\xfequx0008\xffbaz",
+			in:   "0008\x00foo0008\x01bar0008\xfequx0008\xffbaz0000",
+			out:  map[byte]string{0: "foo", 1: "bar", 254: "qux", 255: "baz"},
+		},
+		{
+			desc: "valid stream trailing garbage",
+			in:   "0008\x00foo0008\x01bar0008\xfequx0008\xffbaz0000 garbage!!",
 			out:  map[byte]string{0: "foo", 1: "bar", 254: "qux", 255: "baz"},
 		},
 		{
@@ -297,6 +305,11 @@ func TestEachSidebandPacket(t *testing.T) {
 			callback: func(byte, []byte) error { return callbackError },
 			err:      callbackError,
 		},
+		{
+			desc: "valid stream except missing flush",
+			in:   "0008\x00foo0008\x01bar0008\xfequx0008\xffbaz",
+			err:  io.ErrUnexpectedEOF,
+		},
 		{
 			desc: "interrupted stream",
 			in:   "ffff\x10hello world!!",
@@ -334,3 +347,76 @@ func TestEachSidebandPacket(t *testing.T) {
 		})
 	}
 }
+
+func TestSingleBandReader(t *testing.T) {
+	testCases := []struct {
+		desc string
+		in   string
+		out  string
+		err  error
+	}{
+		{
+			desc: "empty",
+			in:   "0000",
+			out:  "",
+		},
+		{
+			desc: "valid stream",
+			in:   "0008\x00foo0008\x00bar0008\x00qux0008\x00baz0000",
+			out:  "foobarquxbaz",
+		},
+		{
+			desc: "valid stream trailing garbage",
+			in:   "0008\x00foo0008\x00bar0008\x00qux0008\x00baz0000 garbage!!",
+			out:  "foobarquxbaz",
+		},
+		{
+			desc: "empty packets are skipped",
+			in:   "0005\x000008\x00foo0005\x000000",
+			out:  "foo",
+		},
+		{
+			desc: "valid stream except missing flush",
+			in:   "0008\x00foo0008\x00bar0008\x00qux0008\x00baz",
+			err:  io.ErrUnexpectedEOF,
+		},
+		{
+			desc: "interrupted stream",
+			in:   "ffff\x00hello world!!",
+			err:  io.ErrUnexpectedEOF,
+		},
+		{
+			desc: "stream without band",
+			in:   "0004",
+			err:  &errNotSideband{pkt: "0004"},
+		},
+		{
+			desc: "stream with wrong band",
+			in:   "0005\x01",
+			err:  errUnexpectedSideband(1),
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.desc, func(t *testing.T) {
+			out, err := ioutil.ReadAll(SingleBandReader(strings.NewReader(tc.in), 0))
+			if tc.err != nil {
+				require.Equal(t, tc.err, err)
+				return
+			}
+
+			require.NoError(t, err)
+			require.Equal(t, tc.out, string(out))
+		})
+	}
+}
+
+func TestSingleBandReaderSkipsEmptyPackets(t *testing.T) {
+	// An empty-payload packet must not cause Read to return (0, nil).
+	r := SingleBandReader(strings.NewReader("0005\x000008\x00foo0000"), 0)
+
+	buf := make([]byte, 16)
+	n, err := r.Read(buf)
+	require.NoError(t, err)
+	require.Equal(t, "foo", string(buf[:n]))
+}
diff --git a/internal/gitaly/service/hook/pack_objects.go b/internal/gitaly/service/hook/pack_objects.go
index 82c396503b46e013d5f63905e3f6e351546a2045..f1be31c1d18cab61406d90c6fb3b7e9d20521e4b 100644
--- a/internal/gitaly/service/hook/pack_objects.go
+++ b/internal/gitaly/service/hook/pack_objects.go
@@ -204,6 +204,11 @@ func (s *server) runPackObjects(ctx context.Context, w io.Writer, repo *gitalypb
 		return fmt.Errorf("git-pack-objects: stderr: %q err: %w", stderrBuf.String(), err)
 	}
 
+	// Terminate the output with a flush packet so that readers can tell a
+	// complete stream from a truncated one.
+	if err := pktline.WriteFlush(w); err != nil {
+		return fmt.Errorf("write flush packet: %w", err)
+	}
+
 	return nil
 }
 