From 09ba075cfa3ed8150b9b7bc4bc0e521e583f55fe Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 14 Oct 2025 13:50:25 -0700 Subject: [PATCH] =?UTF-8?q?feat:=20Modify=20embedding=20strategy=20to=20fi?= =?UTF-8?q?rst=20halve=20the=20batch=20size=20v.s.=20the=20batc=E2=80=A6?= =?UTF-8?q?=20[LET-5510]=20(#5434)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Modify embedding strategy to first halve the batch size before reducing the chunk size --- letta/llm_api/openai_client.py | 51 +++++++++++++++++++++++----------- tests/test_embeddings.py | 6 ++-- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 0e8b8e39..4ba11674 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -714,7 +714,13 @@ class OpenAIClient(LLMClientBase): @trace_method async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]: - """Request embeddings given texts and embedding config with chunking and retry logic""" + """Request embeddings given texts and embedding config with chunking and retry logic + + Retry strategy prioritizes reducing batch size before chunk size to maintain retrieval quality: + 1. Start with batch_size=2048 (texts per request) + 2. On failure, halve batch_size until it reaches 1 + 3. 
Only then start reducing chunk_size (for very large individual texts) + """ if not inputs: return [] @@ -723,35 +729,48 @@ class OpenAIClient(LLMClientBase): # track results by original index to maintain order results = [None] * len(inputs) - - # queue of (start_idx, chunk_inputs) to process - chunks_to_process = [(i, inputs[i : i + 2048]) for i in range(0, len(inputs), 2048)] - - min_chunk_size = 256 + initial_batch_size = 2048 + chunks_to_process = [(i, inputs[i : i + initial_batch_size], initial_batch_size) for i in range(0, len(inputs), initial_batch_size)] + min_chunk_size = 128 while chunks_to_process: tasks = [] task_metadata = [] - for start_idx, chunk_inputs in chunks_to_process: + for start_idx, chunk_inputs, current_batch_size in chunks_to_process: task = client.embeddings.create(model=embedding_config.embedding_model, input=chunk_inputs) tasks.append(task) - task_metadata.append((start_idx, chunk_inputs)) + task_metadata.append((start_idx, chunk_inputs, current_batch_size)) task_results = await asyncio.gather(*tasks, return_exceptions=True) failed_chunks = [] - for (start_idx, chunk_inputs), result in zip(task_metadata, task_results): + for (start_idx, chunk_inputs, current_batch_size), result in zip(task_metadata, task_results): if isinstance(result, Exception): - # check if we can retry with smaller chunks - if len(chunk_inputs) > min_chunk_size: - # split chunk in half and queue for retry + current_size = len(chunk_inputs) + + if current_batch_size > 1: + new_batch_size = max(1, current_batch_size // 2) + logger.warning( + f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. " + f"Reducing batch size from {current_batch_size} to {new_batch_size} and retrying." 
+ ) mid = len(chunk_inputs) // 2 - failed_chunks.append((start_idx, chunk_inputs[:mid])) - failed_chunks.append((start_idx + mid, chunk_inputs[mid:])) + failed_chunks.append((start_idx, chunk_inputs[:mid], new_batch_size)) + failed_chunks.append((start_idx + mid, chunk_inputs[mid:], new_batch_size)) + elif current_size > min_chunk_size: + logger.warning( + f"Embeddings request failed for single item at {start_idx} with size {current_size}. " + f"Splitting individual text content and retrying." + ) + mid = len(chunk_inputs) // 2 + failed_chunks.append((start_idx, chunk_inputs[:mid], 1)) + failed_chunks.append((start_idx + mid, chunk_inputs[mid:], 1)) else: - # can't split further, re-raise the error - logger.error(f"Failed to get embeddings for chunk starting at {start_idx} even with minimum size {min_chunk_size}") + logger.error( + f"Failed to get embeddings for chunk starting at {start_idx} even with batch_size=1 " + f"and minimum chunk size {min_chunk_size}. Error: {result}" + ) raise result else: embeddings = [r.embedding for r in result.data] diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 9ae44a9d..1baaa784 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -102,7 +102,7 @@ async def test_openai_embedding_chunking(default_user): @pytest.mark.asyncio async def test_openai_embedding_retry_logic(default_user): - """Test that failed chunks are retried with halved size""" + """Test that failed chunks are retried with reduced batch size""" embedding_config = EmbeddingConfig( embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", @@ -123,7 +123,7 @@ async def test_openai_embedding_retry_logic(default_user): call_count += 1 input_size = len(kwargs["input"]) - # fail on first attempt for large chunks only + # fail on first attempt for large batches only if input_size == 2048 and call_count <= 2: raise Exception("Too many inputs") @@ -138,7 +138,7 @@ async def 
test_openai_embedding_retry_logic(default_user): assert len(embeddings) == 3000 # initial: 2 chunks (2048, 952) - # after retry: first 2048 splits into 2x1024, so total 3 successful calls + 2 failed = 5 + # after retry: first 2048 splits into 2x1024 with reduced batch_size, so total 3 successful calls + 2 failed = 5 assert call_count > 3