From 0eed3722cc388ea7363fb670cd6cb3341ea91952 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 8 Sep 2025 17:18:54 -0700 Subject: [PATCH] feat: Robustify openai embedding [LET-4256] (#4478) * Robustify embedding * Remove unnecessary imports * Add test embeddings --- letta/llm_api/openai_client.py | 49 +++++++++++- tests/test_embeddings.py | 141 +++++++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+), 4 deletions(-) diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 21c94b0d..7f29da9a 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -1,3 +1,4 @@ +import asyncio import os from typing import List, Optional @@ -319,13 +320,53 @@ class OpenAIClient(LLMClientBase): @trace_method async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]: - """Request embeddings given texts and embedding config""" + """Request embeddings given texts and embedding config with chunking and retry logic""" + if not inputs: + return [] + kwargs = self._prepare_client_kwargs_embedding(embedding_config) client = AsyncOpenAI(**kwargs) - response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs) - # TODO: add total usage - return [r.embedding for r in response.data] + # track results by original index to maintain order + results = [None] * len(inputs) + + # queue of (start_idx, chunk_inputs) to process + chunks_to_process = [(i, inputs[i : i + 2048]) for i in range(0, len(inputs), 2048)] + + min_chunk_size = 256 + + while chunks_to_process: + tasks = [] + task_metadata = [] + + for start_idx, chunk_inputs in chunks_to_process: + task = client.embeddings.create(model=embedding_config.embedding_model, input=chunk_inputs) + tasks.append(task) + task_metadata.append((start_idx, chunk_inputs)) + + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + failed_chunks = [] + for (start_idx, chunk_inputs), result in 
zip(task_metadata, task_results): + if isinstance(result, Exception): + # check if we can retry with smaller chunks + if len(chunk_inputs) > min_chunk_size: + # split chunk in half and queue for retry + mid = len(chunk_inputs) // 2 + failed_chunks.append((start_idx, chunk_inputs[:mid])) + failed_chunks.append((start_idx + mid, chunk_inputs[mid:])) + else: + # can't split further, re-raise the error + logger.error(f"Failed to get embeddings for chunk starting at {start_idx} even with minimum size {min_chunk_size}") + raise result + else: + embeddings = [r.embedding for r in result.data] + for i, embedding in enumerate(embeddings): + results[start_idx + i] = embedding + + chunks_to_process = failed_chunks + + return results @trace_method def handle_llm_error(self, e: Exception) -> Exception: diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 6dd38862..a4c13791 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -1,11 +1,13 @@ import glob import json import os +from unittest.mock import AsyncMock, patch import pytest from letta.config import LettaConfig from letta.llm_api.llm_client import LLMClient +from letta.llm_api.openai_client import OpenAIClient from letta.schemas.embedding_config import EmbeddingConfig from letta.server.server import SyncServer @@ -60,3 +62,142 @@ async def test_embeddings(embedding_config: EmbeddingConfig, default_user): embeddings = await embedding_client.request_embeddings([test_input], embedding_config) assert len(embeddings) == 1 assert len(embeddings[0]) == embedding_config.embedding_dim + + +@pytest.mark.asyncio +async def test_openai_embedding_chunking(default_user): + """Test that large inputs are split into 2048-sized chunks""" + embedding_config = EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_model="text-embedding-3-small", + embedding_dim=1536, + ) + + client = OpenAIClient(actor=default_user) + + with 
patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + async def mock_create(**kwargs): + input_size = len(kwargs["input"]) + assert input_size <= 2048 # verify chunking + mock_response = AsyncMock() + mock_response.data = [AsyncMock(embedding=[0.1] * 1536) for _ in range(input_size)] + return mock_response + + mock_client.embeddings.create.side_effect = mock_create + + # test with 5000 inputs (should be split into 3 chunks: 2048, 2048, 904) + test_inputs = [f"Input {i}" for i in range(5000)] + embeddings = await client.request_embeddings(test_inputs, embedding_config) + + assert len(embeddings) == 5000 + assert mock_client.embeddings.create.call_count == 3 + + +@pytest.mark.asyncio +async def test_openai_embedding_retry_logic(default_user): + """Test that failed chunks are retried with halved size""" + embedding_config = EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_model="text-embedding-3-small", + embedding_dim=1536, + ) + + client = OpenAIClient(actor=default_user) + + with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + call_count = 0 + + async def mock_create(**kwargs): + nonlocal call_count + call_count += 1 + input_size = len(kwargs["input"]) + + # fail on first attempt for large chunks only + if input_size == 2048 and call_count <= 2: + raise Exception("Too many inputs") + + mock_response = AsyncMock() + mock_response.data = [AsyncMock(embedding=[0.1] * 1536) for _ in range(input_size)] + return mock_response + + mock_client.embeddings.create.side_effect = mock_create + + test_inputs = [f"Input {i}" for i in range(3000)] + embeddings = await client.request_embeddings(test_inputs, embedding_config) + + assert len(embeddings) == 3000 + # initial: 2 chunks (2048, 952) + # after retry: first 2048 splits into 2x1024, so 
total 3 successful calls + 1 failed = 4 + assert call_count > 3 + + +@pytest.mark.asyncio +async def test_openai_embedding_order_preserved(default_user): + """Test that order is maintained despite chunking and retries""" + embedding_config = EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_model="text-embedding-3-small", + embedding_dim=1536, + ) + + client = OpenAIClient(actor=default_user) + + with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + async def mock_create(**kwargs): + # return embeddings where first element = input index + mock_response = AsyncMock() + mock_response.data = [] + for text in kwargs["input"]: + idx = int(text.split()[-1]) + embedding = [float(idx)] + [0.0] * 1535 + mock_response.data.append(AsyncMock(embedding=embedding)) + return mock_response + + mock_client.embeddings.create.side_effect = mock_create + + test_inputs = [f"Text {i}" for i in range(100)] + embeddings = await client.request_embeddings(test_inputs, embedding_config) + + assert len(embeddings) == 100 + for i in range(100): + assert embeddings[i][0] == float(i) + + +@pytest.mark.asyncio +async def test_openai_embedding_minimum_chunk_failure(default_user): + """Test that persistent failures at minimum chunk size raise error""" + embedding_config = EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_model="text-embedding-3-small", + embedding_dim=1536, + ) + + client = OpenAIClient(actor=default_user) + + with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + async def mock_create(**kwargs): + raise Exception("API error") + + mock_client.embeddings.create.side_effect = mock_create + + # test with 300 inputs - the single 300-item chunk is split once into 2x150 (not larger than the 256 minimum), then fails + test_inputs = [f"Input 
{i}" for i in range(300)] + + with pytest.raises(Exception, match="API error"): + await client.request_embeddings(test_inputs, embedding_config)