feat: Robustify openai embedding [LET-4256] (#4478)
* Robustify embedding * Remove unnecessary imports * Add test embeddings
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -319,13 +320,53 @@ class OpenAIClient(LLMClientBase):
|
||||
|
||||
@trace_method
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
    """Request embeddings for the given texts, with chunking and retry logic.

    Inputs are submitted in parallel chunks of at most ``chunk_size`` items per
    API call. A chunk whose request raises is split in half and re-queued; this
    repeats until the chunk succeeds or shrinks to ``min_chunk_size``, at which
    point the underlying error is re-raised.

    Args:
        inputs: Texts to embed. An empty list short-circuits to ``[]``.
        embedding_config: Names the embedding model/endpoint to use.

    Returns:
        One embedding vector per input, in the same order as ``inputs``.

    Raises:
        Exception: The underlying API error, when a minimum-size chunk fails.
    """
    if not inputs:
        return []

    kwargs = self._prepare_client_kwargs_embedding(embedding_config)
    client = AsyncOpenAI(**kwargs)

    # track results by original index so chunking/retries never reorder output
    results: List[Optional[List[float]]] = [None] * len(inputs)

    # OpenAI's embeddings endpoint caps a single request at 2048 inputs
    chunk_size = 2048
    min_chunk_size = 256

    # queue of (start_idx, chunk_inputs) still awaiting successful embedding
    chunks_to_process = [(i, inputs[i : i + chunk_size]) for i in range(0, len(inputs), chunk_size)]

    while chunks_to_process:
        tasks = []
        task_metadata = []
        for start_idx, chunk_inputs in chunks_to_process:
            tasks.append(client.embeddings.create(model=embedding_config.embedding_model, input=chunk_inputs))
            task_metadata.append((start_idx, chunk_inputs))

        # return_exceptions=True: one failed chunk must not cancel the others
        task_results = await asyncio.gather(*tasks, return_exceptions=True)

        failed_chunks = []
        for (start_idx, chunk_inputs), result in zip(task_metadata, task_results):
            if isinstance(result, Exception):
                if len(chunk_inputs) > min_chunk_size:
                    # split the failed chunk in half and queue both halves for retry
                    mid = len(chunk_inputs) // 2
                    failed_chunks.append((start_idx, chunk_inputs[:mid]))
                    failed_chunks.append((start_idx + mid, chunk_inputs[mid:]))
                else:
                    # can't split further, re-raise the error
                    logger.error(f"Failed to get embeddings for chunk starting at {start_idx} even with minimum size {min_chunk_size}")
                    raise result
            else:
                # success: write each embedding back at its original position
                for i, item in enumerate(result.data):
                    results[start_idx + i] = item.embedding

        chunks_to_process = failed_chunks

    # TODO: add total usage
    return results
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@@ -60,3 +62,142 @@ async def test_embeddings(embedding_config: EmbeddingConfig, default_user):
|
||||
embeddings = await embedding_client.request_embeddings([test_input], embedding_config)
|
||||
assert len(embeddings) == 1
|
||||
assert len(embeddings[0]) == embedding_config.embedding_dim
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_embedding_chunking(default_user):
    """Test that large inputs are split into 2048-sized chunks"""
    embedding_config = EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-3-small",
        embedding_dim=1536,
    )

    client = OpenAIClient(actor=default_user)

    with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai:
        fake_client = AsyncMock()
        mock_openai.return_value = fake_client

        async def fake_create(**kwargs):
            batch = kwargs["input"]
            # every request must respect the 2048-item chunk limit
            assert len(batch) <= 2048
            response = AsyncMock()
            response.data = [AsyncMock(embedding=[0.1] * 1536) for _ in batch]
            return response

        fake_client.embeddings.create.side_effect = fake_create

        # 5000 inputs -> chunks of 2048, 2048, and 904
        embeddings = await client.request_embeddings([f"Input {i}" for i in range(5000)], embedding_config)

        assert len(embeddings) == 5000
        assert fake_client.embeddings.create.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_embedding_retry_logic(default_user):
    """Test that failed chunks are retried with halved size"""
    embedding_config = EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-3-small",
        embedding_dim=1536,
    )

    client = OpenAIClient(actor=default_user)

    with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai:
        mock_client = AsyncMock()
        mock_openai.return_value = mock_client

        call_count = 0

        async def mock_create(**kwargs):
            nonlocal call_count
            call_count += 1
            input_size = len(kwargs["input"])

            # fail only the single full-size 2048 chunk, on its first attempt
            if input_size == 2048 and call_count <= 2:
                raise Exception("Too many inputs")

            mock_response = AsyncMock()
            mock_response.data = [AsyncMock(embedding=[0.1] * 1536) for _ in range(input_size)]
            return mock_response

        mock_client.embeddings.create.side_effect = mock_create

        test_inputs = [f"Input {i}" for i in range(3000)]
        embeddings = await client.request_embeddings(test_inputs, embedding_config)

        assert len(embeddings) == 3000
        # initial round: 2 chunks (2048 fails, 952 succeeds)
        # retry round: the failed 2048 splits into 2x1024, both succeed
        # -> 1 failed call + 3 successful calls = 4 total
        assert call_count == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_embedding_order_preserved(default_user):
    """Test that order is maintained despite chunking and retries"""
    embedding_config = EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-3-small",
        embedding_dim=1536,
    )

    client = OpenAIClient(actor=default_user)

    with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai:
        fake_client = AsyncMock()
        mock_openai.return_value = fake_client

        async def fake_create(**kwargs):
            # encode each input's index as the first component of its embedding
            response = AsyncMock()
            response.data = [
                AsyncMock(embedding=[float(int(text.split()[-1]))] + [0.0] * 1535)
                for text in kwargs["input"]
            ]
            return response

        fake_client.embeddings.create.side_effect = fake_create

        embeddings = await client.request_embeddings([f"Text {i}" for i in range(100)], embedding_config)

        assert len(embeddings) == 100
        # each returned vector must sit at the index its input came from
        for expected_idx, vector in enumerate(embeddings):
            assert vector[0] == float(expected_idx)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_embedding_minimum_chunk_failure(default_user):
    """Test that persistent failures at minimum chunk size raise error"""
    embedding_config = EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-3-small",
        embedding_dim=1536,
    )

    client = OpenAIClient(actor=default_user)

    with patch("letta.llm_api.openai_client.AsyncOpenAI") as mock_openai:
        failing_client = AsyncMock()
        mock_openai.return_value = failing_client

        async def always_fail(**kwargs):
            raise Exception("API error")

        failing_client.embeddings.create.side_effect = always_fail

        # 300 inputs: the chunk keeps splitting on failure until it can't
        # shrink below the minimum, at which point the error surfaces
        inputs = [f"Input {i}" for i in range(300)]

        with pytest.raises(Exception, match="API error"):
            await client.request_embeddings(inputs, embedding_config)
|
||||
|
||||
Reference in New Issue
Block a user