feat: Modify embedding strategy to first halve the batch size vs. the chunk size [LET-5510] (#5434)

Modify embedding strategy to first halve the batch size vs. the chunk size
Matthew Zhou
2025-10-14 13:50:25 -07:00
committed by Caren Thomas
parent 0543a60538
commit 09ba075cfa
2 changed files with 38 additions and 19 deletions

@@ -714,7 +714,13 @@ class OpenAIClient(LLMClientBase):
     @trace_method
     async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
-        """Request embeddings given texts and embedding config with chunking and retry logic"""
+        """Request embeddings given texts and embedding config with chunking and retry logic
+
+        Retry strategy prioritizes reducing batch size before chunk size to maintain retrieval quality:
+        1. Start with batch_size=2048 (texts per request)
+        2. On failure, halve batch_size until it reaches 1
+        3. Only then start reducing chunk_size (for very large individual texts)
+        """
         if not inputs:
             return []
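
For intuition, the batch-size ladder that step 2 of the docstring walks down is short: halving from 2048 reaches 1 after eleven steps. A quick standalone check (illustrative only, not part of the commit):

```python
# Batch sizes tried before the strategy falls back to splitting text content.
sizes = [2048]
while sizes[-1] > 1:
    sizes.append(sizes[-1] // 2)
print(sizes)  # [2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]
assert len(sizes) - 1 == 11  # eleven halvings, since 2048 == 2**11
```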
@@ -723,35 +729,48 @@ class OpenAIClient(LLMClientBase):
         # track results by original index to maintain order
         results = [None] * len(inputs)
         # queue of (start_idx, chunk_inputs) to process
-        chunks_to_process = [(i, inputs[i : i + 2048]) for i in range(0, len(inputs), 2048)]
-        min_chunk_size = 256
+        initial_batch_size = 2048
+        chunks_to_process = [(i, inputs[i : i + initial_batch_size], initial_batch_size) for i in range(0, len(inputs), initial_batch_size)]
+        min_chunk_size = 128
         while chunks_to_process:
             tasks = []
             task_metadata = []
-            for start_idx, chunk_inputs in chunks_to_process:
+            for start_idx, chunk_inputs, current_batch_size in chunks_to_process:
                 task = client.embeddings.create(model=embedding_config.embedding_model, input=chunk_inputs)
                 tasks.append(task)
-                task_metadata.append((start_idx, chunk_inputs))
+                task_metadata.append((start_idx, chunk_inputs, current_batch_size))
             task_results = await asyncio.gather(*tasks, return_exceptions=True)
             failed_chunks = []
-            for (start_idx, chunk_inputs), result in zip(task_metadata, task_results):
+            for (start_idx, chunk_inputs, current_batch_size), result in zip(task_metadata, task_results):
                 if isinstance(result, Exception):
-                    # check if we can retry with smaller chunks
-                    if len(chunk_inputs) > min_chunk_size:
-                        # split chunk in half and queue for retry
+                    current_size = len(chunk_inputs)
+                    if current_batch_size > 1:
+                        new_batch_size = max(1, current_batch_size // 2)
+                        logger.warning(
+                            f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. "
+                            f"Reducing batch size from {current_batch_size} to {new_batch_size} and retrying."
+                        )
                         mid = len(chunk_inputs) // 2
-                        failed_chunks.append((start_idx, chunk_inputs[:mid]))
-                        failed_chunks.append((start_idx + mid, chunk_inputs[mid:]))
+                        failed_chunks.append((start_idx, chunk_inputs[:mid], new_batch_size))
+                        failed_chunks.append((start_idx + mid, chunk_inputs[mid:], new_batch_size))
+                    elif current_size > min_chunk_size:
+                        logger.warning(
+                            f"Embeddings request failed for single item at {start_idx} with size {current_size}. "
+                            f"Splitting individual text content and retrying."
+                        )
+                        mid = len(chunk_inputs) // 2
+                        failed_chunks.append((start_idx, chunk_inputs[:mid], 1))
+                        failed_chunks.append((start_idx + mid, chunk_inputs[mid:], 1))
                     else:
                         # can't split further, re-raise the error
-                        logger.error(f"Failed to get embeddings for chunk starting at {start_idx} even with minimum size {min_chunk_size}")
+                        logger.error(
+                            f"Failed to get embeddings for chunk starting at {start_idx} even with batch_size=1 "
+                            f"and minimum chunk size {min_chunk_size}. Error: {result}"
+                        )
                         raise result
                 else:
                     embeddings = [r.embedding for r in result.data]
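
Putting the new logic in one place: below is a minimal, self-contained sketch of the retry cascade, with the OpenAI call replaced by a hypothetical `fake_embed` backend that rejects batches larger than `max_accepted` (both names are illustrative, not part of the commit). It exercises only the first lever, batch halving; the chunk-content splitting in the `elif` branch is omitted for brevity.

```python
def embed_with_retries(inputs, embed, initial_batch_size=2048):
    """Embed `inputs`, halving the batch size whenever a request fails."""
    results = [None] * len(inputs)
    # queue of (start_idx, chunk, batch_size), mirroring chunks_to_process above
    queue = [
        (i, inputs[i : i + initial_batch_size], initial_batch_size)
        for i in range(0, len(inputs), initial_batch_size)
    ]
    while queue:
        start_idx, chunk, batch_size = queue.pop()
        try:
            for offset, vector in enumerate(embed(chunk)):
                results[start_idx + offset] = vector
        except Exception:
            if batch_size > 1:
                # first lever: halve the batch and requeue both halves
                mid = len(chunk) // 2
                queue.append((start_idx, chunk[:mid], batch_size // 2))
                queue.append((start_idx + mid, chunk[mid:], batch_size // 2))
            else:
                raise  # batch_size == 1; the real code would split text content next
    return results


def fake_embed(chunk, max_accepted=512):
    """Stand-in backend: refuses any request with more than `max_accepted` texts."""
    if len(chunk) > max_accepted:
        raise RuntimeError("Too many inputs")
    return [[0.0] for _ in chunk]


vectors = embed_with_retries([f"text-{i}" for i in range(3000)], fake_embed)
assert all(v is not None for v in vectors)
```

Because batch size halves before any text is split, a transient rate limit degrades request size, not retrieval quality: every text is still embedded whole until batch_size bottoms out at 1.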


@@ -102,7 +102,7 @@ async def test_openai_embedding_chunking(default_user):
 @pytest.mark.asyncio
 async def test_openai_embedding_retry_logic(default_user):
-    """Test that failed chunks are retried with halved size"""
+    """Test that failed chunks are retried with reduced batch size"""
     embedding_config = EmbeddingConfig(
         embedding_endpoint_type="openai",
         embedding_endpoint="https://api.openai.com/v1",
@@ -123,7 +123,7 @@ async def test_openai_embedding_retry_logic(default_user):
         call_count += 1
         input_size = len(kwargs["input"])
-        # fail on first attempt for large chunks only
+        # fail on first attempt for large batches only
         if input_size == 2048 and call_count <= 2:
             raise Exception("Too many inputs")
@@ -138,7 +138,7 @@ async def test_openai_embedding_retry_logic(default_user):
     assert len(embeddings) == 3000
     # initial: 2 chunks (2048, 952)
-    # after retry: first 2048 splits into 2x1024, so total 3 successful calls + 2 failed = 5
+    # after retry: first 2048 splits into 2x1024 with reduced batch_size, so total 3 successful calls + 2 failed = 5
     assert call_count > 3
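
As a sanity check on the arithmetic in those test comments, a throwaway snippet assuming the same 3000-input, 2048-batch setup as the test:

```python
# Initial chunking of 3000 inputs at batch size 2048, then the retry split.
n_inputs, batch = 3000, 2048
sizes = [min(batch, n_inputs - i) for i in range(0, n_inputs, batch)]
assert sizes == [2048, 952]               # initial: 2 chunks (2048, 952)
mid = 2048 // 2
assert [mid, 2048 - mid] == [1024, 1024]  # failed 2048 batch splits into 2x1024
```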