Fix: prevent empty embedding batches from causing memory spikes (#6230)

Root cause: When splitting failed embedding batches, mid=0 for single
items created empty chunks. These empty chunks were then processed,
creating hundreds of no-op tasks that consumed memory.

Crash pattern from logs:
- 600+ 'batch_size=0' embedding tasks created
- Memory spiked 531 MB → 4.9 GB
- Pod crashed

Fixes:
1. Skip empty chunks before creating tasks
2. Guard chunk splits to prevent empty slices (mid = max(1, len//2))
3. Break early if all chunks are empty

This prevents the asyncio.gather() from creating thousands of empty
coroutines that exhaust memory.
This commit is contained in:
Kian Jones
2025-11-17 17:52:40 -08:00
committed by Caren Thomas
parent a6b19bf3aa
commit ddb6f3836e

View File

@@ -779,6 +779,10 @@ class OpenAIClient(LLMClientBase):
task_metadata = []
for start_idx, chunk_inputs, current_batch_size in chunks_to_process:
if not chunk_inputs:
logger.warning(f"Skipping empty chunk at start_idx={start_idx}")
continue
logger.info(
f"Creating embedding task: start_idx={start_idx}, batch_size={len(chunk_inputs)}, "
f"first_input_len={len(chunk_inputs[0]) if chunk_inputs else 0}, "
@@ -788,6 +792,10 @@ class OpenAIClient(LLMClientBase):
tasks.append(task)
task_metadata.append((start_idx, chunk_inputs, current_batch_size))
if not tasks:
logger.warning("All chunks were empty, skipping embedding request")
break
task_results = await asyncio.gather(*tasks, return_exceptions=True)
failed_chunks = []
@@ -801,17 +809,21 @@ class OpenAIClient(LLMClientBase):
f"Embeddings request failed for batch starting at {start_idx} with size {current_size}. "
f"Reducing batch size from {current_batch_size} to {new_batch_size} and retrying."
)
mid = len(chunk_inputs) // 2
failed_chunks.append((start_idx, chunk_inputs[:mid], new_batch_size))
failed_chunks.append((start_idx + mid, chunk_inputs[mid:], new_batch_size))
mid = max(1, len(chunk_inputs) // 2)
if chunk_inputs[:mid]:
failed_chunks.append((start_idx, chunk_inputs[:mid], new_batch_size))
if chunk_inputs[mid:]:
failed_chunks.append((start_idx + mid, chunk_inputs[mid:], new_batch_size))
elif current_size > min_chunk_size:
logger.warning(
f"Embeddings request failed for single item at {start_idx} with size {current_size}. "
f"Splitting individual text content and retrying."
)
mid = len(chunk_inputs) // 2
failed_chunks.append((start_idx, chunk_inputs[:mid], 1))
failed_chunks.append((start_idx + mid, chunk_inputs[mid:], 1))
mid = max(1, len(chunk_inputs) // 2)
if chunk_inputs[:mid]:
failed_chunks.append((start_idx, chunk_inputs[:mid], 1))
if chunk_inputs[mid:]:
failed_chunks.append((start_idx + mid, chunk_inputs[mid:], 1))
else:
chunk_preview = str(chunk_inputs)[:500] if chunk_inputs else "None"
logger.error(