feat: Modify embedding strategy to first halve the batch size vs. the batc… [LET-5510] (#5434)
Modify embedding strategy to first halve the batch size before reducing the chunk size
This commit is contained in:
committed by
Caren Thomas
parent
0543a60538
commit
09ba075cfa
@@ -714,7 +714,13 @@ class OpenAIClient(LLMClientBase):
|
||||
|
||||
@trace_method
|
||||
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
|
||||
"""Request embeddings given texts and embedding config with chunking and retry logic"""
|
||||
"""Request embeddings given texts and embedding config with chunking and retry logic
|
||||
|
||||
Retry strategy prioritizes reducing batch size before chunk size to maintain retrieval quality:
|
||||
1. Start with batch_size=2048 (texts per request)
|
||||
2. On failure, halve batch_size until it reaches 1
|
||||
3. Only then start reducing chunk_size (for very large individual texts)
|
||||
"""
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
@@ -723,35 +729,48 @@ class OpenAIClient(LLMClientBase):
|
||||
|
||||
# track results by original index to maintain order
|
||||
results = [None] * len(inputs)
|
||||
|
||||
# queue of (start_idx, chunk_inputs) to process
|
||||
chunks_to_process = [(i, inputs[i : i + 2048]) for i in range(0, len(inputs), 2048)]
|
||||
|
||||
min_chunk_size = 256
|
||||
initial_batch_size = 2048
|
||||
chunks_to_process = [(i, inputs[i : i + initial_batch_size], initial_batch_size) for i in range(0, len(inputs), initial_batch_size)]
|
||||
min_chunk_size = 128
|
||||
|
||||
while chunks_to_process:
|
||||
tasks = []
|
||||
task_metadata = []
|
||||
|
||||
for start_idx, chunk_inputs in chunks_to_process:
|
||||
for start_idx, chunk_inputs, current_batch_size in chunks_to_process:
|
||||
task = client.embeddings.create(model=embedding_config.embedding_model, input=chunk_inputs)
|
||||
tasks.append(task)
|
||||
task_metadata.append((start_idx, chunk_inputs))
|
||||
task_metadata.append((start_idx, chunk_inputs, current_batch_size))
|
||||
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
failed_chunks = []
|
||||
for (start_idx, chunk_inputs), result in zip(task_metadata, task_results):
|
||||
for (start_idx, chunk_inputs, current_batch_size), result in zip(task_metadata, task_results):
|
||||
if isinstance(result, Exception):
|
||||
# check if we can retry with smaller chunks
|
||||
if len(chunk_inputs) > min_chunk_size:
|
||||
# split chunk in half and queue for retry
|
||||
current_size = len(chunk_inputs)
|
||||
|
||||
if current_batch_size > 1:
|
||||
new_batch_size = max(1, current_batch_size // 2)
|
||||
logger.warning(
|
||||
f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. "
|
||||
f"Reducing batch size from {current_batch_size} to {new_batch_size} and retrying."
|
||||
)
|
||||
mid = len(chunk_inputs) // 2
|
||||
failed_chunks.append((start_idx, chunk_inputs[:mid]))
|
||||
failed_chunks.append((start_idx + mid, chunk_inputs[mid:]))
|
||||
failed_chunks.append((start_idx, chunk_inputs[:mid], new_batch_size))
|
||||
failed_chunks.append((start_idx + mid, chunk_inputs[mid:], new_batch_size))
|
||||
elif current_size > min_chunk_size:
|
||||
logger.warning(
|
||||
f"Embeddings request failed for single item at {start_idx} with size {current_size}. "
|
||||
f"Splitting individual text content and retrying."
|
||||
)
|
||||
mid = len(chunk_inputs) // 2
|
||||
failed_chunks.append((start_idx, chunk_inputs[:mid], 1))
|
||||
failed_chunks.append((start_idx + mid, chunk_inputs[mid:], 1))
|
||||
else:
|
||||
# can't split further, re-raise the error
|
||||
logger.error(f"Failed to get embeddings for chunk starting at {start_idx} even with minimum size {min_chunk_size}")
|
||||
logger.error(
|
||||
f"Failed to get embeddings for chunk starting at {start_idx} even with batch_size=1 "
|
||||
f"and minimum chunk size {min_chunk_size}. Error: {result}"
|
||||
)
|
||||
raise result
|
||||
else:
|
||||
embeddings = [r.embedding for r in result.data]
|
||||
|
||||
Reference in New Issue
Block a user