feat: Robustify openai embedding [LET-4256] (#4478)

* Robustify embedding

* Remove unnecessary imports

* Add test embeddings
This commit is contained in:
Matthew Zhou
2025-09-08 17:18:54 -07:00
committed by GitHub
parent 5eca45cb56
commit 73e8ee73fc
2 changed files with 186 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import os
from typing import List, Optional
@@ -319,13 +320,53 @@ class OpenAIClient(LLMClientBase):
@trace_method
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
    """Request embeddings given texts and embedding config with chunking and retry logic.

    Inputs are split into chunks of at most 2048 texts (the per-request limit) and
    embedded concurrently. A failed chunk is retried by splitting it in half, down
    to a minimum chunk size, after which the underlying error is re-raised.

    Args:
        inputs: Texts to embed. Output order matches input order.
        embedding_config: Config naming the embedding model/endpoint to use.

    Returns:
        One embedding vector per input text, in the same order as `inputs`.

    Raises:
        Exception: The provider error, if a chunk keeps failing at the minimum size.
    """
    if not inputs:
        return []

    kwargs = self._prepare_client_kwargs_embedding(embedding_config)
    client = AsyncOpenAI(**kwargs)

    # TODO: add total usage
    max_chunk_size = 2048  # per-request input limit
    min_chunk_size = 256  # below this, a failure is treated as fatal rather than retried

    # track results by original index to maintain order
    results: List[Optional[List[float]]] = [None] * len(inputs)
    # queue of (start_idx, chunk_inputs) to process
    chunks_to_process = [(i, inputs[i : i + max_chunk_size]) for i in range(0, len(inputs), max_chunk_size)]

    while chunks_to_process:
        tasks = []
        task_metadata = []
        for start_idx, chunk_inputs in chunks_to_process:
            tasks.append(client.embeddings.create(model=embedding_config.embedding_model, input=chunk_inputs))
            task_metadata.append((start_idx, chunk_inputs))

        # gather with return_exceptions so one bad chunk doesn't cancel the rest
        task_results = await asyncio.gather(*tasks, return_exceptions=True)

        failed_chunks = []
        for (start_idx, chunk_inputs), result in zip(task_metadata, task_results):
            if isinstance(result, Exception):
                # check if we can retry with smaller chunks
                if len(chunk_inputs) > min_chunk_size:
                    # split chunk in half and queue for retry
                    mid = len(chunk_inputs) // 2
                    failed_chunks.append((start_idx, chunk_inputs[:mid]))
                    failed_chunks.append((start_idx + mid, chunk_inputs[mid:]))
                else:
                    # can't split further, re-raise the error
                    logger.error(f"Failed to get embeddings for chunk starting at {start_idx} even with minimum size {min_chunk_size}")
                    raise result
            else:
                # success: write each embedding back at its original position
                for i, item in enumerate(result.data):
                    results[start_idx + i] = item.embedding

        chunks_to_process = failed_chunks

    return results
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:

View File

@@ -1,11 +1,13 @@
import glob
import json
import os
from unittest.mock import AsyncMock, patch
import pytest
from letta.config import LettaConfig
from letta.llm_api.llm_client import LLMClient
from letta.llm_api.openai_client import OpenAIClient
from letta.schemas.embedding_config import EmbeddingConfig
from letta.server.server import SyncServer
@@ -60,3 +62,142 @@ async def test_embeddings(embedding_config: EmbeddingConfig, default_user):
embeddings = await embedding_client.request_embeddings([test_input], embedding_config)
assert len(embeddings) == 1
assert len(embeddings[0]) == embedding_config.embedding_dim
@pytest.mark.asyncio
async def test_openai_embedding_chunking(default_user):
    """Test that large inputs are split into 2048-sized chunks"""
    embedding_config = EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-3-small",
        embedding_dim=1536,
    )
    client = OpenAIClient(actor=default_user)

    with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai:
        mock_client = AsyncMock()
        mock_openai.return_value = mock_client

        async def fake_create(**kwargs):
            batch = kwargs["input"]
            # every individual request must respect the 2048-item chunk limit
            assert len(batch) <= 2048
            response = AsyncMock()
            response.data = [AsyncMock(embedding=[0.1] * 1536) for _ in batch]
            return response

        mock_client.embeddings.create.side_effect = fake_create

        # 5000 inputs should be embedded via 3 requests: 2048 + 2048 + 904
        test_inputs = [f"Input {i}" for i in range(5000)]
        embeddings = await client.request_embeddings(test_inputs, embedding_config)

        assert len(embeddings) == 5000
        assert mock_client.embeddings.create.call_count == 3
@pytest.mark.asyncio
async def test_openai_embedding_retry_logic(default_user):
    """Test that failed chunks are retried with halved size"""
    embedding_config = EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-3-small",
        embedding_dim=1536,
    )
    client = OpenAIClient(actor=default_user)

    with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai:
        mock_client = AsyncMock()
        mock_openai.return_value = mock_client
        call_count = 0

        async def fake_create(**kwargs):
            nonlocal call_count
            call_count += 1
            batch_size = len(kwargs["input"])
            # only full-size chunks fail, and only during the first round of calls
            if batch_size == 2048 and call_count <= 2:
                raise Exception("Too many inputs")
            response = AsyncMock()
            response.data = [AsyncMock(embedding=[0.1] * 1536) for _ in range(batch_size)]
            return response

        mock_client.embeddings.create.side_effect = fake_create

        test_inputs = [f"Input {i}" for i in range(3000)]
        embeddings = await client.request_embeddings(test_inputs, embedding_config)

        assert len(embeddings) == 3000
        # initial round issues 2 requests (2048 and 952); the 2048 chunk fails and
        # is retried as two 1024 halves, so more calls happen than a failure-free run
        assert call_count > 3
@pytest.mark.asyncio
async def test_openai_embedding_order_preserved(default_user):
    """Test that order is maintained despite chunking and retries"""
    embedding_config = EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-3-small",
        embedding_dim=1536,
    )
    client = OpenAIClient(actor=default_user)

    with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai:
        mock_client = AsyncMock()
        mock_openai.return_value = mock_client

        async def fake_create(**kwargs):
            # encode each input's index as the first element of its embedding
            response = AsyncMock()
            response.data = [
                AsyncMock(embedding=[float(int(text.split()[-1]))] + [0.0] * 1535)
                for text in kwargs["input"]
            ]
            return response

        mock_client.embeddings.create.side_effect = fake_create

        test_inputs = [f"Text {i}" for i in range(100)]
        embeddings = await client.request_embeddings(test_inputs, embedding_config)

        assert len(embeddings) == 100
        # each returned vector must sit at the position of the input that produced it
        for idx, vector in enumerate(embeddings):
            assert vector[0] == float(idx)
@pytest.mark.asyncio
async def test_openai_embedding_minimum_chunk_failure(default_user):
    """Test that persistent failures at minimum chunk size raise error"""
    embedding_config = EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-3-small",
        embedding_dim=1536,
    )
    client = OpenAIClient(actor=default_user)

    with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai:
        mock_client = AsyncMock()
        mock_openai.return_value = mock_client

        async def fake_create(**kwargs):
            # every request fails, regardless of chunk size
            raise Exception("API error")

        mock_client.embeddings.create.side_effect = fake_create

        # 300 inputs: retries shrink the chunk to the 256 minimum, then the error surfaces
        test_inputs = [f"Input {i}" for i in range(300)]
        with pytest.raises(Exception, match="API error"):
            await client.request_embeddings(test_inputs, embedding_config)