Files
letta-server/letta/services/file_processor/embedder/openai_embedder.py

227 lines
9.7 KiB
Python

import asyncio
import time
from typing import List, Optional, Tuple, cast
from letta.llm_api.llm_client import LLMClient
from letta.llm_api.openai_client import OpenAIClient
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
from letta.settings import model_settings
logger = get_logger(__name__)

# Process-wide limiter for in-flight embedding requests. Every OpenAIEmbedder
# instance acquires this same semaphore around each batch request, so even when
# several files are processed concurrently at most 3 embedding requests are
# outstanding at any moment (keeps us under provider rate limits).
# NOTE(review): this is constructed at import time; on Python < 3.10 asyncio
# primitives bind to the loop current at construction — confirm the server only
# ever uses this from a single event loop.
_GLOBAL_EMBEDDING_SEMAPHORE = asyncio.Semaphore(value=3)
class OpenAIEmbedder(BaseEmbedder):
    """OpenAI-based embedding generation.

    Text chunks are grouped into batches of ``embedding_config.batch_size`` and
    embedded concurrently through an ``OpenAIClient``; the number of concurrent
    requests is bounded by the module-wide ``_GLOBAL_EMBEDDING_SEMAPHORE``. Each
    resulting embedding is wrapped in a ``Passage``.
    """

    def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
        """Create an embedder.

        Args:
            embedding_config: Embedding configuration to use. When omitted (or falsy),
                ``text-embedding-3-small`` via the ``openai`` provider is used if
                ``model_settings.openai_api_key`` is set, otherwise the ``letta``
                default config.
        """
        super().__init__()
        # OpenAI embedder uses the native vector db (PostgreSQL)
        # self.vector_db_type already set to VectorDBProvider.NATIVE by parent
        self.default_embedding_config = (
            EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai")
            if model_settings.openai_api_key
            else EmbeddingConfig.default_config(model_name="letta")
        )
        self.embedding_config = embedding_config or self.default_embedding_config
        # TODO: Unify to global OpenAI client
        self.client: OpenAIClient = cast(
            OpenAIClient,
            LLMClient.create(
                provider_type=ProviderType.openai,
                put_inner_thoughts_first=False,
                actor=None,  # Not necessary
            ),
        )

    @trace_method
    async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
        """Embed a single batch and return embeddings paired with their original indices.

        If the request fails with a token-limit error (see ``_is_token_limit_error``)
        and the batch holds more than one input, the batch is split in half and each
        half is embedded recursively, one after the other.

        Args:
            batch: Texts to embed.
            batch_indices: Index of each text in the caller's list; same length as ``batch``.

        Returns:
            ``(index, embedding)`` pairs, one per input, in input order.

        Raises:
            ValueError: If the client returns a different number of embeddings than inputs.
            Exception: Any other client error, or a token-limit error on a single-input
                batch, is re-raised unchanged.
        """
        log_event(
            "embedder.batch_started",
            {
                "batch_size": len(batch),
                "model": self.embedding_config.embedding_model,
                "embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
            },
        )
        try:
            embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
        except Exception as e:
            # Only a token-limit error on a batch we can still split is recoverable here;
            # re-raise for other errors or if batch size is already 1.
            if not (self._is_token_limit_error(e) and len(batch) > 1):
                raise
            logger.warning(f"Token limit exceeded for batch of size {len(batch)}, splitting in half and retrying")
            log_event(
                "embedder.batch_split_retry",
                {
                    "original_batch_size": len(batch),
                    "error": str(e),
                    "split_size": len(batch) // 2,
                },
            )
            mid = len(batch) // 2
            # Retry the halves sequentially; this already runs inside the caller's
            # semaphore slot, so it does not add concurrent requests.
            first_half = await self._embed_batch(batch[:mid], batch_indices[:mid])
            second_half = await self._embed_batch(batch[mid:], batch_indices[mid:])
            return first_half + second_half

        # zip() would silently drop (and later misalign) results on a short response.
        if len(embeddings) != len(batch):
            raise ValueError(f"Embedding provider returned {len(embeddings)} embeddings for a batch of {len(batch)} inputs")
        log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
        return list(zip(batch_indices, embeddings))

    def _is_token_limit_error(self, error: Exception) -> bool:
        """Heuristically decide whether ``error`` means a request exceeded a token limit.

        The check is a case-insensitive substring match on ``str(error)``.
        """
        error_str = str(error).lower()
        # TODO: This is quite brittle, works for now
        # check for the specific patterns we see in token limit errors
        is_token_limit = (
            "max_tokens_per_request" in error_str
            or ("requested" in error_str and "tokens" in error_str and "max" in error_str and "per request" in error_str)
            or "token limit" in error_str
            or ("bad request to openai" in error_str and "tokens" in error_str and "max" in error_str)
        )
        return is_token_limit

    @trace_method
    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
        """Embed ``chunks`` and return one ``Passage`` per non-blank chunk.

        Empty or whitespace-only chunks are dropped before embedding. The remaining
        chunks are batched by ``embedding_config.batch_size`` and embedded concurrently,
        with concurrency across all embedders capped by ``_GLOBAL_EMBEDDING_SEMAPHORE``.

        Args:
            file_id: Stored on every created passage.
            source_id: Stored on every created passage.
            chunks: Text chunks to embed.
            actor: User whose ``organization_id`` is stamped on every passage.

        Returns:
            Passages in the same relative order as the non-blank input chunks, or an
            empty list when there is nothing to embed.

        Raises:
            Exception: The first error raised by any batch; remaining in-flight batches
                are cancelled before it propagates.
        """
        if not chunks:
            return []

        # Filter out empty or whitespace-only chunks that would fail embedding
        chunks_to_embed = [chunk for chunk in chunks if chunk and chunk.strip()]
        if not chunks_to_embed:
            logger.warning(f"No valid text chunks found for file {file_id}. PDF may contain only images without text layer.")
            log_event(
                "embedder.no_valid_chunks",
                {"file_id": file_id, "source_id": source_id, "total_chunks": len(chunks), "reason": "All chunks empty or whitespace-only"},
            )
            return []

        if len(chunks_to_embed) < len(chunks):
            logger.info(f"Filtered out {len(chunks) - len(chunks_to_embed)} empty chunks from {len(chunks)} total")
            log_event(
                "embedder.chunks_filtered",
                {
                    "file_id": file_id,
                    "original_chunks": len(chunks),
                    "valid_chunks": len(chunks_to_embed),
                    "filtered_chunks": len(chunks) - len(chunks_to_embed),
                },
            )

        batch_size = self.embedding_config.batch_size
        # perf_counter is monotonic, so the reported duration is immune to clock adjustments.
        embedding_start = time.perf_counter()
        logger.info(f"Generating embeddings for {len(chunks_to_embed)} chunks using {self.embedding_config.embedding_model}")
        log_event(
            "embedder.generation_started",
            {
                "total_chunks": len(chunks_to_embed),
                "model": self.embedding_config.embedding_model,
                "embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
                "batch_size": batch_size,
                "file_id": file_id,
                "source_id": source_id,
            },
        )

        # Create batches together with each chunk's index into chunks_to_embed
        batches: List[List[str]] = []
        batch_indices: List[List[int]] = []
        for start in range(0, len(chunks_to_embed), batch_size):
            batch = chunks_to_embed[start : start + batch_size]
            batches.append(batch)
            batch_indices.append(list(range(start, start + len(batch))))

        logger.info(f"Processing {len(batches)} batches")
        log_event(
            "embedder.batching_completed",
            {"total_batches": len(batches), "batch_size": batch_size, "total_chunks": len(chunks_to_embed)},
        )

        # Use global semaphore to limit concurrent embedding requests across ALL file processing
        # This prevents rate limiting even when processing multiple files simultaneously
        async def process(batch: List[str], indices: List[int]) -> List[Tuple[int, List[float]]]:
            async with _GLOBAL_EMBEDDING_SEMAPHORE:
                try:
                    return await self._embed_batch(batch, indices)
                except Exception as e:
                    logger.error("Failed to embed batch of size %s: %s", len(batch), e)
                    log_event("embedder.batch_failed", {"batch_size": len(batch), "error": str(e), "error_type": type(e).__name__})
                    raise

        tasks = [asyncio.create_task(process(batch, indices)) for batch, indices in zip(batches, batch_indices)]
        log_event(
            "embedder.concurrent_processing_started",
            # 3 mirrors the capacity of _GLOBAL_EMBEDDING_SEMAPHORE; keep in sync.
            {"concurrent_tasks": len(tasks), "max_concurrent_global": 3},
        )
        try:
            results = await asyncio.gather(*tasks)
        except Exception:
            # gather() does not cancel the other tasks when one fails. Cancel them so
            # queued batches stop issuing API requests whose results would be discarded,
            # then drain them so no task outcome is left unretrieved.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise
        log_event("embedder.concurrent_processing_completed", {"batches_processed": len(results)})

        # Flatten results and sort by index to maintain original order
        indexed_embeddings = sorted(
            (pair for batch_result in results for pair in batch_result),
            key=lambda pair: pair[0],
        )

        # Look the text up by index (rather than zipping with chunks_to_embed) so every
        # embedding is attached to the exact chunk it was computed from.
        passages = [
            Passage(
                text=chunks_to_embed[idx],
                file_id=file_id,
                source_id=source_id,
                embedding=embedding,
                embedding_config=self.embedding_config,
                organization_id=actor.organization_id,
            )
            for idx, embedding in indexed_embeddings
        ]

        embedding_duration = time.perf_counter() - embedding_start
        logger.info(f"Successfully generated {len(passages)} embeddings (took {embedding_duration:.2f}s)")
        log_event(
            "embedder.generation_completed",
            {
                "passages_created": len(passages),
                "total_chunks_processed": len(chunks_to_embed),
                "file_id": file_id,
                "source_id": source_id,
                "duration_seconds": embedding_duration,
            },
        )
        return passages