diff --git a/internal/backchannel/backchannel.go b/internal/backchannel/backchannel.go index a429ea8773b9a88b9d6b4784ac1ddb35ecba1e91..a7c03e6c8389f219212e8712efec4d16852006da 100644 --- a/internal/backchannel/backchannel.go +++ b/internal/backchannel/backchannel.go @@ -38,25 +38,6 @@ import ( // magicBytes are sent by the client to server to identify as a multiplexing aware client. var magicBytes = []byte("backchannel") -// muxConfig returns a new config to use with the multiplexing session. -func muxConfig(logger io.Writer) *yamux.Config { - cfg := yamux.DefaultConfig() - cfg.LogOutput = logger - // The server only accepts a single stream from the client, which is the client's gRPC stream. - // The backchannel server should only receive a single stream from the server. As such, we can - // limit maximum pending streams to 1 as there should never be more streams waiting. - cfg.AcceptBacklog = 1 - // gRPC is already configured to send keep alives so we don't need yamux to do this for us. - // gRPC is a better choice as it sends the keep alives also to non-multiplexed connections. - cfg.EnableKeepAlive = false - // MaxStreamWindowSize configures the maximum receive buffer size for each stream. The sender - // is allowed to send the configured amount of bytes without receiving an acknowledgement from the - // receiver. This is can have a big impact on throughput as the latency increases, as the sender - // can't proceed sending without receiving an acknowledgement back. - cfg.MaxStreamWindowSize = 16 * 1024 * 1024 - return cfg -} - // connCloser wraps a net.Conn and calls the provided close function instead when Close // is called. type connCloser struct { @@ -66,3 +47,39 @@ type connCloser struct { // Close calls the provided close function.
func (cc connCloser) Close() error { return cc.close() } + +type options struct { + yamuxConfig *yamux.Config +} + +// An Option sets options, such as the yamux configuration, for the backchannel. +type Option func(*options) + +// WithYamuxConfig replaces the default yamux configuration used by the backchannel, including its LogOutput. +func WithYamuxConfig(yamuxConfig *yamux.Config) Option { + return func(opts *options) { opts.yamuxConfig = yamuxConfig } +} + +func defaultBackchannelOptions(logger io.Writer) *options { + yamuxConf := yamux.DefaultConfig() + // The server only accepts a single stream from the client, which is the client's gRPC stream. + // The backchannel server should only receive a single stream from the server. As such, we can + // limit maximum pending streams to 1 as there should never be more streams waiting. + yamuxConf.AcceptBacklog = 1 + + // MaxStreamWindowSize configures the maximum receive buffer size for each stream. The sender + // is allowed to send the configured amount of bytes without receiving an acknowledgement from the + // receiver. This can have a big impact on throughput as the latency increases, as the sender + // can't proceed sending without receiving an acknowledgement back. + yamuxConf.MaxStreamWindowSize = 16 * 1024 * 1024 + + // gRPC is already configured to send keep alives so we don't need yamux to do this for us. + // gRPC is a better choice as it sends the keep alives also to non-multiplexed connections. + yamuxConf.EnableKeepAlive = false + + yamuxConf.LogOutput = logger + + return &options{ + yamuxConfig: yamuxConf, + } +} diff --git a/internal/backchannel/client.go b/internal/backchannel/client.go index fd252d349f654927d02a15727d1f624eaf627b1a..5af88cd0c9346faec7c1909d12a2ea6661c2e365 100644 --- a/internal/backchannel/client.go +++ b/internal/backchannel/client.go @@ -26,27 +26,29 @@ type ServerFactory func() Server // ClientHandshaker implements the client side handshake of the multiplexed connection.
type ClientHandshaker struct { - logger *logrus.Entry - serverFactory ServerFactory + logger *logrus.Entry + serverFactory ServerFactory + backchannelOpts []Option } // NewClientHandshaker returns a new client side implementation of the backchannel. The provided // logger is used to log multiplexing errors. -func NewClientHandshaker(logger *logrus.Entry, serverFactory ServerFactory) ClientHandshaker { - return ClientHandshaker{logger: logger, serverFactory: serverFactory} +func NewClientHandshaker(logger *logrus.Entry, serverFactory ServerFactory, opts ...Option) ClientHandshaker { + return ClientHandshaker{logger: logger, serverFactory: serverFactory, backchannelOpts: opts} } // ClientHandshake returns TransportCredentials that perform the client side multiplexing handshake and // start the backchannel Server on the established connections. The transport credentials are used to intiliaze the // connection prior to the multiplexing. func (ch ClientHandshaker) ClientHandshake(tc credentials.TransportCredentials) credentials.TransportCredentials { - return clientHandshake{TransportCredentials: tc, serverFactory: ch.serverFactory, logger: ch.logger} + return clientHandshake{TransportCredentials: tc, serverFactory: ch.serverFactory, logger: ch.logger, backchannelOpts: ch.backchannelOpts} } type clientHandshake struct { credentials.TransportCredentials - serverFactory ServerFactory - logger *logrus.Entry + serverFactory ServerFactory + logger *logrus.Entry + backchannelOpts []Option } func (ch clientHandshake) ClientHandshake(ctx context.Context, serverName string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { @@ -90,9 +92,13 @@ func (ch clientHandshake) serve(ctx context.Context, conn net.Conn) (net.Conn, e } logger := ch.logger.WriterLevel(logrus.ErrorLevel) + opts := defaultBackchannelOptions(logger) + for _, opt := range ch.backchannelOpts { + opt(opts) + } // Initiate the multiplexing session.
- muxSession, err := yamux.Client(conn, muxConfig(logger)) + muxSession, err := yamux.Client(conn, opts.yamuxConfig) if err != nil { logger.Close() return nil, fmt.Errorf("open multiplexing session: %w", err) diff --git a/internal/backchannel/server.go b/internal/backchannel/server.go index d318e929af6a589d22b71379a4a97943ec691724..378b1a847e36ecfe7e225cabe4ff3741e5cd2241 100644 --- a/internal/backchannel/server.go +++ b/internal/backchannel/server.go @@ -70,9 +70,10 @@ func withSessionInfo(authInfo credentials.AuthInfo, id ID, muxSession *yamux.Ses // ServerHandshaker implements the server side handshake of the multiplexed connection. type ServerHandshaker struct { - registry *Registry - logger *logrus.Entry - dialOpts []grpc.DialOption + registry *Registry + logger *logrus.Entry + dialOpts []grpc.DialOption + backchannelOpts []Option } // Magic is used by listenmux to retrieve the magic string for @@ -83,8 +84,8 @@ func (s *ServerHandshaker) Magic() string { return string(magicBytes) } // are handshaked prior to initializing the multiplexing session. The Registry is used to store the backchannel connections. // DialOptions can be used to set custom dial options for the backchannel connections. They must not contain a dialer or // transport credentials as those set by the handshaker. -func NewServerHandshaker(logger *logrus.Entry, reg *Registry, dialOpts []grpc.DialOption) *ServerHandshaker { - return &ServerHandshaker{registry: reg, logger: logger, dialOpts: dialOpts} +func NewServerHandshaker(logger *logrus.Entry, reg *Registry, dialOpts []grpc.DialOption, opts ...Option) *ServerHandshaker { + return &ServerHandshaker{registry: reg, logger: logger, dialOpts: dialOpts, backchannelOpts: opts} } // Handshake establishes a gRPC ClientConn back to the backchannel client @@ -98,9 +99,13 @@ func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInf // session as well.
logger := s.logger.WriterLevel(logrus.ErrorLevel) + opts := defaultBackchannelOptions(logger) + for _, opt := range s.backchannelOpts { + opt(opts) + } // Open the server side of the multiplexing session. - muxSession, err := yamux.Server(conn, muxConfig(logger)) + muxSession, err := yamux.Server(conn, opts.yamuxConfig) if err != nil { logger.Close() return nil, nil, fmt.Errorf("create multiplexing session: %w", err) diff --git a/internal/middleware/cancelhandler/cancelhandler.go b/internal/middleware/cancelhandler/cancelhandler.go index 8ee8705227da06acc18cd36c67fa30f653bb0bb3..d8f5f4c53fca2b920475e2d6792f5b965cee7657 100644 --- a/internal/middleware/cancelhandler/cancelhandler.go +++ b/internal/middleware/cancelhandler/cancelhandler.go @@ -2,7 +2,9 @@ package cancelhandler import ( "context" + "errors" + "github.com/hashicorp/yamux" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -22,7 +24,24 @@ func Stream(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerIn } func wrapErr(ctx context.Context, err error) error { - if err == nil || ctx.Err() == nil { + if err == nil { + return nil + } + + // gRPC may not be aware that sidechannel clients have hung up. Therefore, + // we have to handle yamux errors here. There are two errors for this scenario: + // - The client called Close(). A flagFIN is sent to the server. + // + The server starts a force close timer, but is still able to write + // + The client starts a force close timer + // - When the timers are due, any read/write operations raise + // ErrStreamClosed. They send a flagRST to the other side.
+ // - When either side receives flagRST, any read/write operations raise + // ErrConnectionReset. + if errors.Is(err, yamux.ErrStreamClosed) || errors.Is(err, yamux.ErrConnectionReset) { + return status.Errorf(codes.Canceled, "%v", err) + } + + if ctx.Err() == nil { return err } diff --git a/internal/sidechannel/proxy_test.go b/internal/sidechannel/proxy_test.go index 1501a02d33376aa296575c690b3151fb7e00bdeb..14a2c87c1e5a9f9c4229e5f18219bb886cff841a 100644 --- a/internal/sidechannel/proxy_test.go +++ b/internal/sidechannel/proxy_test.go @@ -72,6 +72,7 @@ func TestUnaryProxy(t *testing.T) { } return &healthpb.HealthCheckResponse{}, nil }, + nil, ) proxyAddr := startServer( @@ -86,6 +87,7 @@ func TestUnaryProxy(t *testing.T) { ctxOut := metadata.IncomingToOutgoing(ctx) return healthpb.NewHealthClient(conn).Check(ctxOut, request) }, + nil, ) ctx, cancel := testhelper.Context() diff --git a/internal/sidechannel/sidechannel.go b/internal/sidechannel/sidechannel.go index 74c7d0ab8f580db6d52b9eb5b2f8283e33b6f26a..56f517e0a5eb9e18e604c6d80b844c9be0933e71 100644 --- a/internal/sidechannel/sidechannel.go +++ b/internal/sidechannel/sidechannel.go @@ -9,6 +9,7 @@ import ( "strconv" "time" + "github.com/hashicorp/yamux" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel" "gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/client" @@ -30,6 +31,33 @@ const ( sidechannelMetadataKey = "gitaly-sidechannel-id" ) +type options struct { + yamuxConfig *yamux.Config +} + +func defaultSidechannelOptions(logger io.Writer) *options { + yamuxConf := yamux.DefaultConfig() + + // At the moment, these settings are a subset of the backchannel yamux + // configuration defined in internal/backchannel/backchannel.go. It's
+ yamuxConf.MaxStreamWindowSize = 16 * 1024 * 1024 + yamuxConf.EnableKeepAlive = false + yamuxConf.LogOutput = logger + + return &options{ + yamuxConfig: yamuxConf, + } +} + +// An Option sets options, such as the yamux configuration, for the sidechannel. +type Option func(*options) + +// WithYamuxConfig replaces the default yamux configuration used by the sidechannel, including its LogOutput. +func WithYamuxConfig(yamuxConfig *yamux.Config) Option { + return func(opts *options) { opts.yamuxConfig = yamuxConfig } +} + // OpenSidechannel opens a sidechannel connection from the stream opener // extracted from the current peer connection. func OpenSidechannel(ctx context.Context) (_ *ServerConn, err error) { @@ -132,7 +160,13 @@ func NewServerHandshaker(registry *Registry) *ServerHandshaker { // NewClientHandshaker is used to enable sidechannel support on outbound // gRPC connections. -func NewClientHandshaker(logger *logrus.Entry, registry *Registry) client.Handshaker { +func NewClientHandshaker(logger *logrus.Entry, registry *Registry, opts ...Option) client.Handshaker { + // NOTE(review): logger.Logger.Out likely bypasses the logrus formatter and level, and this config replaces the backchannel's per-connection error writer; confirm intended. + sidechannelOpts := defaultSidechannelOptions(logger.Logger.Out) + for _, opt := range opts { + opt(sidechannelOpts) + } + return backchannel.NewClientHandshaker( logger, func() backchannel.Server { @@ -140,5 +174,6 @@ func NewClientHandshaker(logger *logrus.Entry, registry *Registry) client.Handsh lm.Register(NewServerHandshaker(registry)) return grpc.NewServer(grpc.Creds(lm)) }, + backchannel.WithYamuxConfig(sidechannelOpts.yamuxConfig), ) } diff --git a/internal/sidechannel/sidechannel_test.go b/internal/sidechannel/sidechannel_test.go index 285765b512089a41e4b0d7bb7c9af15e194b4ef0..d1c80922f920564ff078b8e7aaf48204fea3154c 100644 --- a/internal/sidechannel/sidechannel_test.go +++ b/internal/sidechannel/sidechannel_test.go @@ -8,11 +8,16 @@ import ( "net" "sync" "testing" + "time" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "gitlab.com/gitlab-org/gitaly/v14/internal/backchannel"
"gitlab.com/gitlab-org/gitaly/v14/internal/listenmux" + "gitlab.com/gitlab-org/gitaly/v14/internal/middleware/cancelhandler" + "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" healthpb "google.golang.org/grpc/health/grpc_health_v1" ) @@ -41,6 +46,7 @@ func TestSidechannel(t *testing.T) { } return &healthpb.HealthCheckResponse{}, conn.Close() }, + nil, ) conn, registry := dial(t, addr) @@ -98,6 +104,7 @@ func TestSidechannelConcurrency(t *testing.T) { return &healthpb.HealthCheckResponse{}, conn.Close() }, + nil, ) conn, registry := dial(t, addr) @@ -141,15 +148,62 @@ func TestSidechannelConcurrency(t *testing.T) { } } -func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string { +func TestSidechannelCancelled(t *testing.T) { + addr := startServer( + t, + func(ctx context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + conn, err := OpenSidechannel(ctx) + if err != nil { + return nil, err + } + defer conn.Close() + + if _, err := io.Copy(io.Discard, conn); err != nil { + return nil, err + } + + responseData := make([]byte, 64*1024) + for { + // Write into yamux connection until reaching error.
+ if _, err = conn.Write(responseData); err != nil { + return nil, err + } + } + }, + []grpc.ServerOption{ + grpc.UnaryInterceptor(cancelhandler.Unary), + grpc.StreamInterceptor(cancelhandler.Stream), + }, + withYamuxCfgFastTimeout(), + ) + + conn, registry := dial(t, addr, withYamuxCfgFastTimeout()) + client := healthpb.NewHealthClient(conn) + + ctxOut, waiter := RegisterSidechannel(context.Background(), registry, func(conn *ClientConn) error { + // Half-close the stream without writing any data or waiting for the response + return conn.CloseWrite() + }) + defer waiter.Close() + + _, err := client.Check(ctxOut, &healthpb.HealthCheckRequest{}) + testhelper.RequireGrpcError(t, err, codes.Canceled) +} + +func startServer(t *testing.T, th testHandler, grpcOpts []grpc.ServerOption, sidechannelOpts ...Option) string { t.Helper() + opts := defaultSidechannelOptions(logrus.StandardLogger().Out) + for _, opt := range sidechannelOpts { + opt(opts) + } + lm := listenmux.New(insecure.NewCredentials()) - lm.Register(backchannel.NewServerHandshaker(newLogger(), backchannel.NewRegistry(), nil)) + lm.Register(backchannel.NewServerHandshaker(newLogger(), backchannel.NewRegistry(), nil, backchannel.WithYamuxConfig(opts.yamuxConfig))) - opts = append(opts, grpc.Creds(lm)) + grpcOpts = append(grpcOpts, grpc.Creds(lm)) - s := grpc.NewServer(opts...) + s := grpc.NewServer(grpcOpts...) t.Cleanup(func() { s.Stop() }) handler := &server{testHandler: th} @@ -164,9 +218,9 @@ func startServer(t *testing.T, th testHandler, opts ...grpc.ServerOption) string return lis.Addr().String() } -func dial(t *testing.T, addr string) (*grpc.ClientConn, *Registry) { +func dial(t *testing.T, addr string, opts ...Option) (*grpc.ClientConn, *Registry) { registry := NewRegistry() - clientHandshaker := NewClientHandshaker(newLogger(), registry) + clientHandshaker := NewClientHandshaker(newLogger(), registry, opts...)
dialOpt := grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())) conn, err := grpc.Dial(addr, dialOpt) @@ -203,3 +257,10 @@ type server struct { func (s *server) Check(context context.Context, request *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { return s.testHandler(context, request) } + +// withYamuxCfgFastTimeout shortens the yamux StreamCloseTimeout to 10ms so that force-close timers on half-closed streams fire quickly in tests. +func withYamuxCfgFastTimeout() Option { + return func(opts *options) { + opts.yamuxConfig.StreamCloseTimeout = 10 * time.Millisecond + } +}