diff --git a/ee/lib/ai/active_context/embeddings/code/vertex_text.rb b/ee/lib/ai/active_context/embeddings/code/vertex_text.rb
index 73503cb950b419ffd020b74a7fac985c51beb1ef..4c63746f121c30d9f165bd31cc6a3ac0949e3667 100644
--- a/ee/lib/ai/active_context/embeddings/code/vertex_text.rb
+++ b/ee/lib/ai/active_context/embeddings/code/vertex_text.rb
@@ -5,17 +5,71 @@ module ActiveContext
     module Embeddings
       module Code
         class VertexText < ::ActiveContext::Embeddings
-          def self.generate_embeddings(content, unit_primitive:, model: nil, user: nil)
-            action = 'embedding'
-            embeddings = Gitlab::Llm::VertexAi::Embeddings::Text.new(
-              content,
-              user: user,
-              tracking_context: { action: action },
-              unit_primitive: unit_primitive,
-              model: model
-            ).execute
-
-            embeddings.all?(Array) ? embeddings : [embeddings]
+          EMBEDDINGS_MODEL_CLASS = Gitlab::Llm::VertexAi::Embeddings::Text
+
+          class << self
+            # The caller of `generate_embeddings` should already have estimated
+            # the size of `contents` so as not to exceed token limits.
+            # However, we cannot be certain that those estimates are accurate,
+            # so we still need to handle the possibility of a "token limits exceeded" error here.
+            def generate_embeddings(contents, unit_primitive:, model: nil, user: nil)
+              tracking_context = { action: 'embedding' }
+
+              generate_with_recursive_batch_splitting(
+                contents,
+                unit_primitive: unit_primitive,
+                tracking_context: tracking_context,
+                model: model,
+                user: user
+              )
+            end
+
+            private
+
+            # This handles the `TokenLimitExceededError` coming from the embeddings generation call.
+            # If a `TokenLimitExceededError` occurs, the `contents` array is split into 2
+            # and the embeddings generation is called for each half batch.
+            # This has to be done recursively because the new half batch might still exceed limits.
+            def generate_with_recursive_batch_splitting(
+              contents,
+              unit_primitive:,
+              tracking_context:,
+              model: nil,
+              user: nil
+            )
+              embeddings = EMBEDDINGS_MODEL_CLASS.new(
+                contents,
+                user: user,
+                tracking_context: tracking_context,
+                unit_primitive: unit_primitive,
+                model: model
+              ).execute
+
+              embeddings.all?(Array) ? embeddings : [embeddings]
+
+            rescue EMBEDDINGS_MODEL_CLASS::TokenLimitExceededError => e
+              contents_count = contents.length
+              if contents_count == 1
+                # If we are still getting a `TokenLimitExceededError` even with a single content input, raise an error
+                raise StandardError, "Token limit exceeded for single content input: #{e.message.inspect}"
+              end
+
+              # Split the `contents` input into 2 arrays and recursively call
+              # `generate_with_recursive_batch_splitting`
+              embeddings = []
+              half_batch_size = (contents_count / 2.0).ceil
+              contents.each_slice(half_batch_size) do |batch_contents|
+                embeddings += generate_with_recursive_batch_splitting(
+                  batch_contents,
+                  unit_primitive: unit_primitive,
+                  model: model,
+                  user: user,
+                  tracking_context: tracking_context
+                )
+              end
+
+              embeddings
+            end
           end
         end
       end
diff --git a/ee/spec/lib/ai/active_context/embeddings/code/vertex_text_spec.rb b/ee/spec/lib/ai/active_context/embeddings/code/vertex_text_spec.rb
index e16d3d2afef3941ee249d49a61540fc393758021..bab329346bca1521d05e5907784fc3b4c3d90ccd 100644
--- a/ee/spec/lib/ai/active_context/embeddings/code/vertex_text_spec.rb
+++ b/ee/spec/lib/ai/active_context/embeddings/code/vertex_text_spec.rb
@@ -13,6 +13,8 @@
       )
     end
 
+    let(:llm_class) { Gitlab::Llm::VertexAi::Embeddings::Text }
+
     let(:contents) { %w[content-1 content-2 content-3 content-4 content-5] }
     let(:unit_primitive) { 'embeddings_generation' }
     let(:model) { 'test-embedding-model' }
@@ -28,19 +30,14 @@
       ]
     end
 
-    let(:embeddings_model) do
-      instance_double(
-        Gitlab::Llm::VertexAi::Embeddings::Text,
-        execute: embeddings
-      )
-    end
+    let(:embeddings_model) { instance_double(llm_class, execute: embeddings) }
 
     before do
-      allow(Gitlab::Llm::VertexAi::Embeddings::Text).to receive(:new).and_return(embeddings_model)
+      allow(llm_class).to receive(:new).and_return(embeddings_model)
     end
 
     it 'initializes the correct model class with the expected parameters' do
-      expect(Gitlab::Llm::VertexAi::Embeddings::Text).to receive(:new).with(
+      expect(llm_class).to receive(:new).with(
         contents,
         user: user,
         tracking_context: { action: 'embedding' },
@@ -50,7 +47,103 @@
 
       expect(embeddings_model).to receive(:execute).and_return(embeddings)
 
-      generate_embeddings
+      expect(generate_embeddings).to eq embeddings
+    end
+
+    context 'when running into token limits exceeded error' do
+      before do
+        allow(llm_class).to receive(:new) do |arg_contents|
+          if arg_contents.length >= 3
+            embeddings_model_with_error
+          elsif arg_contents == contents_1_2
+            embeddings_model_for_content_1_2
+          elsif arg_contents == contents_3
+            embeddings_model_for_content_3
+          elsif arg_contents == contents_4_5
+            embeddings_model_for_content_4_5
+          end
+        end
+      end
+
+      let(:token_limits_exceeded_error_class) { llm_class::TokenLimitExceededError }
+
+      let(:embeddings_model_with_error) do
+        instance_double(llm_class).tap do |llm_model|
+          allow(llm_model).to receive(:execute).and_raise(token_limits_exceeded_error_class)
+        end
+      end
+
+      let(:embeddings_model_for_content_1_2) do
+        instance_double(llm_class, execute: [[1, 1], [2, 2]])
+      end
+
+      let(:embeddings_model_for_content_3) do
+        instance_double(llm_class, execute: [[3, 3]])
+      end
+
+      let(:embeddings_model_for_content_4_5) do
+        instance_double(llm_class, execute: [[4, 4], [5, 5]])
+      end
+
+      let(:contents_batch_size_5) { contents }
+      let(:contents_batch_size_3) { %w[content-1 content-2 content-3] }
+      let(:contents_1_2) { %w[content-1 content-2] }
+      let(:contents_3) { ['content-3'] }
+      let(:contents_4_5) { %w[content-4 content-5] }
+
+      it 'recursively splits the batch size and eventually succeeds' do
+        # In the `before` setup, we made sure that embeddings generation throws an error
+        # if the size of the `contents` input is 3 or greater.
+
+        # 1 - attempt to generate embeddings for the entire `contents`
+        # %w[content-1 content-2 content-3 content-4 content-5]
+        # with batch_size = 5, this throws an error
+        expect(llm_class).to receive(:new).with(contents_batch_size_5, anything).ordered
+
+        # 2 - the batch is split in 2, and we attempt to generate embeddings for the first half
+        # %w[content-1 content-2 content-3]
+        # with batch_size = 3, this throws an error
+        expect(llm_class).to receive(:new).with(contents_batch_size_3, anything).ordered
+
+        # 3 - split the first half batch from step 2 even further; the second half remains as-is
+        # %w[content-1 content-2] - 1st split of the 1st half
+        # %w[content-3] - 2nd split of the 1st half
+        # %w[content-4 content-5] - no split for the 2nd half of the original batch,
+        # because there are only 2 inputs
+        expect(llm_class).to receive(:new).with(contents_1_2, anything).ordered
+        expect(llm_class).to receive(:new).with(contents_3, anything).ordered
+        expect(llm_class).to receive(:new).with(contents_4_5, anything).ordered
+
+        # After all the recursive batch splitting,
+        # we still expect the call to `generate_embeddings` to
+        # return the embeddings for *all* the contents of the original batch.
+        expect(generate_embeddings).to eq embeddings
+      end
+
+      context 'when running into token limits exceeded for a single input' do
+        let(:embeddings_model_for_content_3) do
+          instance_double(llm_class).tap do |llm_model|
+            allow(llm_model).to receive(:execute).and_raise(token_limits_exceeded_error_class, "some error")
+          end
+        end
+
+        it 'recursively splits the batch size but eventually fails' do
+          # It tries to recursively split the batch until it gets to
+          # the single-input batch for `contents_3`, which raises an error.
+          expect(llm_class).to receive(:new).with(contents_batch_size_5, anything).ordered
+          expect(llm_class).to receive(:new).with(contents_batch_size_3, anything).ordered
+          expect(llm_class).to receive(:new).with(contents_1_2, anything).ordered
+          expect(llm_class).to receive(:new).with(contents_3, anything).ordered
+
+          # It no longer tries to generate embeddings for the batch for `contents_4_5`.
+          expect(llm_class).not_to receive(:new).with(contents_4_5, anything)
+
+          expect { generate_embeddings }.to raise_error(
+            StandardError,
+            "Token limit exceeded for single content input: \"some error\""
+          )
+        end
+      end
     end
   end
 end
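
A minimal standalone Ruby sketch (not part of the patch) of the halving arithmetic exercised above, showing why the spec expects the 5-item batch to be retried as 3 + 2 and then 2 + 1. The `split_until_within_limit` helper and the `over_limit` lambda are hypothetical stand-ins for illustration only; the real code rescues `Gitlab::Llm::VertexAi::Embeddings::Text::TokenLimitExceededError` rather than calling a predicate.

# Illustration only: mirrors the halving strategy of
# `generate_with_recursive_batch_splitting` with a fake "over the limit" check.
def split_until_within_limit(contents, over_limit:)
  return [contents] unless over_limit.call(contents)
  raise "single input still over the limit" if contents.length == 1

  # Same arithmetic as the patch: at most two slices per recursion level.
  half = (contents.length / 2.0).ceil
  contents.each_slice(half).flat_map do |batch|
    split_until_within_limit(batch, over_limit: over_limit)
  end
end

contents = %w[content-1 content-2 content-3 content-4 content-5]
# As in the spec, pretend any batch of 3 or more items exceeds the limit.
batches = split_until_within_limit(contents, over_limit: ->(b) { b.length >= 3 })
batches.each { |b| puts b.inspect }
# ["content-1", "content-2"]
# ["content-3"]
# ["content-4", "content-5"]

Because `each_slice((count / 2.0).ceil)` always yields at most two slices, each recursion level roughly halves the batch, so the recursion depth is bounded by log2 of the original batch size before the single-input error path is reached.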