diff --git a/workhorse/internal/ai_assist/duoworkflow/runner.go b/workhorse/internal/ai_assist/duoworkflow/runner.go index 3657eddfe74c9cf018fb1fc0fff018239d53fe61..3cb17e86abc575ce1356d7d49b82a36266e203fe 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner.go @@ -32,6 +32,7 @@ type websocketConn interface { type workflowStream interface { Send(*pb.ClientEvent) error Recv() (*pb.Action, error) + CloseSend() error } type runner struct { @@ -44,32 +45,45 @@ type runner struct { } func (r *runner) Execute(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + errCh := make(chan error, 2) go func() { for { - if err := r.handleWebSocketMessage(); err != nil { - errCh <- err + select { + case <-ctx.Done(): return + default: + if err := r.handleWebSocketMessage(); err != nil { + errCh <- err + return + } } } }() go func() { for { - action, err := r.wf.Recv() - if err != nil { - if err == io.EOF { - errCh <- nil // Expected error when a workflow ends - } else { - errCh <- fmt.Errorf("duoworkflow: failed to read a gRPC message: %v", err) - } + select { + case <-ctx.Done(): return - } + default: + action, err := r.wf.Recv() + if err != nil { + if err == io.EOF { + errCh <- nil // Expected error when a workflow ends + } else { + errCh <- fmt.Errorf("duoworkflow: failed to read a gRPC message: %v", err) + } + return + } - if err := r.handleAgentAction(ctx, action); err != nil { - errCh <- err - return + if err := r.handleAgentAction(ctx, action); err != nil { + errCh <- err + return + } } } }() diff --git a/workhorse/internal/ai_assist/duoworkflow/runner_test.go b/workhorse/internal/ai_assist/duoworkflow/runner_test.go index 0cc56df006f40caaa634a72a0c23fa66a67445d3..4b11db4ec3fbec5e67f3a6c2941e632648b9da50 100644 --- a/workhorse/internal/ai_assist/duoworkflow/runner_test.go +++ b/workhorse/internal/ai_assist/duoworkflow/runner_test.go @@ -84,6 +84,10 @@ func (m *mockWorkflowStream) Recv() (*pb.Action, error) { return action, nil } +func (m *mockWorkflowStream) CloseSend() error { + return nil +} + func TestRunner_Execute(t *testing.T) { tests := []struct { name string