# Review notes on this patch (duoworkflow runner: cancellation + goleak checks in tests).
#
# runner.go, Execute:
# - The ctx.Done() case is only consulted between loop iterations. With the
#   select/default shape, r.handleWebSocketMessage() and r.wf.Recv() run inside
#   the default branch, so a goroutine that is already blocked in one of those
#   calls does not observe cancellation. Those calls presumably block on a
#   network read (their error texts are "failed to read a WS message" and
#   "failed to read a gRPC message") - confirm. If so, the deferred cancel()
#   does not by itself stop the sibling goroutine, and the goroutine leak this
#   patch targets can remain.
# - NOTE(review): the hunk ends before Execute's return path. If Execute returns
#   after the first value received from errCh, the other goroutine may still be
#   running when Execute returns. Consider two changes: unblock the reads on
#   cancellation (e.g. close the WS connection, or make sure the gRPC stream was
#   created with a context that this cancel reaches), and wait for both
#   goroutines (sync.WaitGroup / errgroup) before returning.
# - errCh has capacity 2 and each goroutine sends exactly once before returning,
#   so neither send can block even if only one value is ever received.
# - The gRPC read error is formatted with %v. Using %w would keep the cause
#   inspectable with errors.Is/errors.As.
#
# runner_test.go, TestRunner_Execute:
# - This patch removes wsBlockCh/wfBlockCh. In each case the side with no input
#   (no recvActions in "messages processed", no wsMessages in "workflow actions
#   processed") presumably returns EOF immediately; that depends on the mock
#   implementations, which are not shown here.
# - The added comments ("either could finish first") acknowledge that either
#   goroutine can finish first. When the idle side wins, Execute can return
#   before the other side has processed its input. The require.Len checks on
#   sendEvents/writeMessages then depend on scheduling and may be flaky.
# - Reading mockWf.sendEvents and mockConn.writeMessages after Execute returns
#   can race with a goroutine that is still running. Whether it does depends on
#   whether the mocks synchronize; run with -race to confirm.
# - goleak.VerifyNone (with the same IgnoreTopFunction option) is deferred both
#   in the parent test and in every subtest, so one of the two is redundant.
# - The expectSuccess field is added to the test table but is not referenced in
#   the hunks shown here.
# - The 5s context.WithTimeout only bounds Execute if the goroutines reach the
#   ctx.Done() check. Per the first note, that does not happen while a read is
#   blocked.
diff --git a/workhorse/internal/ai_assist/duoworkflow/runner.go b/workhorse/internal/ai_assist/duoworkflow/runner.go index a8a2d883d9040abe91d7efbe7876beac3216eba0..693cbc13f25e767e846d0269ef73545416f2b896 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner.go @@ -45,32 +45,47 @@ type runner struct { } func (r *runner) Execute(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + errCh := make(chan error, 2) go func() { for { - if err := r.handleWebSocketMessage(); err != nil { - errCh <- err + select { + case <-ctx.Done(): + errCh <- ctx.Err() return + default: + if err := r.handleWebSocketMessage(); err != nil { + errCh <- err + return + } } } }() go func() { for { - action, err := r.wf.Recv() - if err != nil { - if err == io.EOF { - errCh <- nil // Expected error when a workflow ends - } else { - errCh <- fmt.Errorf("duoworkflow: failed to read a gRPC message: %v", err) - } + select { + case <-ctx.Done(): + errCh <- ctx.Err() return - } + default: + action, err := r.wf.Recv() + if err != nil { + if err == io.EOF { + errCh <- nil // Expected when workflow ends + } else { + errCh <- fmt.Errorf("duoworkflow: failed to read a gRPC message: %v", err) + } + return + } - if err := r.handleAgentAction(ctx, action); err != nil { - errCh <- err - return + if err := r.handleAgentAction(ctx, action); err != nil { + errCh <- err + return + } } } }() diff --git a/workhorse/internal/ai_assist/duoworkflow/runner_test.go b/workhorse/internal/ai_assist/duoworkflow/runner_test.go index 4b11db4ec3fbec5e67f3a6c2941e632648b9da50..5795e3c0cac6e878439a734c9cb8de2452e4dfdc 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner_test.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner_test.go @@ -9,11 +9,13 @@ import ( "net/http/httptest" "net/url" "testing" + "time" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" pb 
"gitlab.com/gitlab-org/modelops/applied-ml/code-suggestions/ai-assist/clients/gopb/contract" + "go.uber.org/goleak" "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" ) @@ -89,53 +91,56 @@ func (m *mockWorkflowStream) CloseSend() error { } func TestRunner_Execute(t *testing.T) { + defer goleak.VerifyNone(t, + goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), + ) + tests := []struct { name string wsMessages [][]byte recvActions []*pb.Action writeMsgCount int sendEventsCount int - expectedErrMsg string - wsBlockCh chan bool - wfBlockCh chan bool + expectedErrMsgs []string // Accept multiple possible errors + expectSuccess bool }{ { - name: "ws messages", + name: "messages processed", wsMessages: [][]byte{[]byte(`{"type": "test"}`), []byte(`{"type": "test2"}`)}, - wfBlockCh: make(chan bool), sendEventsCount: 2, - expectedErrMsg: "handleWebSocketMessage: failed to read a WS message: EOF", + // Both goroutines hit EOF - either could finish first + expectedErrMsgs: []string{ + "handleWebSocketMessage: failed to read a WS message: EOF", + "", // gRPC goroutine finishing first (EOF converted to nil) + }, }, { - name: "wf actions", + name: "workflow actions processed", recvActions: []*pb.Action{{ Action: &pb.Action_RunCommand{ - RunCommand: &pb.RunCommandAction{ - Program: "ls", - }, - }, - }, { - Action: &pb.Action_RunCommand{ - RunCommand: &pb.RunCommandAction{ - Program: "ls", - }, + RunCommand: &pb.RunCommandAction{Program: "ls"}, }, }}, - writeMsgCount: 2, - wsBlockCh: make(chan bool), - expectedErrMsg: "", + writeMsgCount: 1, + // Both goroutines hit EOF - either could finish first + expectedErrMsgs: []string{ + "handleWebSocketMessage: failed to read a WS message: EOF", + "", // gRPC goroutine finishing first (EOF converted to nil) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + defer goleak.VerifyNone(t, + goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), + ) + mockConn := 
&mockWebSocketConn{ readMessages: tt.wsMessages, - blockCh: tt.wsBlockCh, } mockWf := &mockWorkflowStream{ recvActions: tt.recvActions, - blockCh: tt.wfBlockCh, } testURL, _ := url.Parse("http://example.com") @@ -150,15 +155,27 @@ func TestRunner_Execute(t *testing.T) { wf: mockWf, } - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err := r.Execute(ctx) - if tt.expectedErrMsg != "" { - require.EqualError(t, err, tt.expectedErrMsg) - } else { - require.NoError(t, err) + // Check if the error matches any of the expected possibilities + errorMatched := false + for _, expectedErr := range tt.expectedErrMsgs { + if expectedErr == "" && err == nil { + errorMatched = true + break + } + if expectedErr != "" && err != nil && err.Error() == expectedErr { + errorMatched = true + break + } } + require.True(t, errorMatched, + "Expected one of %v, got: %v", tt.expectedErrMsgs, err) + require.Len(t, mockWf.sendEvents, tt.sendEventsCount) require.Len(t, mockConn.writeMessages, tt.writeMsgCount) })