import asyncio
import time
from typing import List, Optional, Tuple, cast

from letta.llm_api.llm_client import LLMClient
from letta.llm_api.openai_client import OpenAIClient
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
from letta.settings import model_settings

logger = get_logger(__name__)

# Maximum number of embedding requests allowed in flight at once, shared across
# ALL embedding operations. A single module-level semaphore ensures that even
# when processing multiple files concurrently we don't exceed OpenAI rate limits.
_MAX_CONCURRENT_EMBEDDING_REQUESTS = 3
_GLOBAL_EMBEDDING_SEMAPHORE = asyncio.Semaphore(_MAX_CONCURRENT_EMBEDDING_REQUESTS)


class OpenAIEmbedder(BaseEmbedder):
    """OpenAI-based embedding generation"""

    def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
        """Initialize the embedder.

        Args:
            embedding_config: Embedding configuration to use. When omitted, falls
                back to OpenAI's text-embedding-3-small if an OpenAI API key is
                configured, otherwise the "letta" default config.
        """
        super().__init__()
        # OpenAI embedder uses the native vector db (PostgreSQL)
        # self.vector_db_type already set to VectorDBProvider.NATIVE by parent
        self.default_embedding_config = (
            EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai")
            if model_settings.openai_api_key
            else EmbeddingConfig.default_config(model_name="letta")
        )
        self.embedding_config = embedding_config or self.default_embedding_config

        # TODO: Unify to global OpenAI client
        self.client: OpenAIClient = cast(
            OpenAIClient,
            LLMClient.create(
                provider_type=ProviderType.openai,
                put_inner_thoughts_first=False,
                actor=None,  # Not necessary
            ),
        )

    @trace_method
    async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
        """Embed a single batch and return embeddings with their original indices.

        On a token-limit error the batch is recursively split in half and retried;
        any other error (or a token-limit error on a batch of size 1) is re-raised.

        Args:
            batch: Chunk texts to embed.
            batch_indices: Original index of each chunk in `batch` (same length).

        Returns:
            List of (original_index, embedding_vector) pairs.
        """
        log_event(
            "embedder.batch_started",
            {
                "batch_size": len(batch),
                "model": self.embedding_config.embedding_model,
                "embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
            },
        )
        try:
            embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
            log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
            return list(zip(batch_indices, embeddings))
        except Exception as e:
            # if it's a token limit error and we can split, do it
            if self._is_token_limit_error(e) and len(batch) > 1:
                logger.warning(f"Token limit exceeded for batch of size {len(batch)}, splitting in half and retrying")
                log_event(
                    "embedder.batch_split_retry",
                    {
                        "original_batch_size": len(batch),
                        "error": str(e),
                        "split_size": len(batch) // 2,
                    },
                )

                # split batch in half
                mid = len(batch) // 2
                batch1 = batch[:mid]
                batch1_indices = batch_indices[:mid]
                batch2 = batch[mid:]
                batch2_indices = batch_indices[mid:]

                # retry with smaller batches
                result1 = await self._embed_batch(batch1, batch1_indices)
                result2 = await self._embed_batch(batch2, batch2_indices)
                return result1 + result2
            else:
                # re-raise for other errors or if batch size is already 1
                raise

    def _is_token_limit_error(self, error: Exception) -> bool:
        """Check if the error is due to token limit exceeded.

        Args:
            error: Exception raised by the embeddings request.

        Returns:
            True when the error message matches known token-limit patterns.
        """
        # convert to string and check for token limit patterns
        error_str = str(error).lower()

        # TODO: This is quite brittle, works for now
        # check for the specific patterns we see in token limit errors
        is_token_limit = (
            "max_tokens_per_request" in error_str
            or ("requested" in error_str and "tokens" in error_str and "max" in error_str and "per request" in error_str)
            or "token limit" in error_str
            or ("bad request to openai" in error_str and "tokens" in error_str and "max" in error_str)
        )
        return is_token_limit

    @trace_method
    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
        """Generate embeddings for chunks with batching and concurrent processing.

        Empty/whitespace-only chunks are filtered out before embedding. Batches
        are processed concurrently, bounded by the global embedding semaphore.

        Args:
            file_id: ID of the file the chunks came from.
            source_id: ID of the source the file belongs to.
            chunks: Raw text chunks to embed.
            actor: User performing the operation (supplies organization_id).

        Returns:
            Passage objects in the original chunk order (empty chunks omitted).
        """
        if not chunks:
            return []

        # Filter out empty or whitespace-only chunks that would fail embedding
        valid_chunks = [(i, chunk) for i, chunk in enumerate(chunks) if chunk and chunk.strip()]

        if not valid_chunks:
            logger.warning(f"No valid text chunks found for file {file_id}. PDF may contain only images without text layer.")
            log_event(
                "embedder.no_valid_chunks",
                {"file_id": file_id, "source_id": source_id, "total_chunks": len(chunks), "reason": "All chunks empty or whitespace-only"},
            )
            return []

        if len(valid_chunks) < len(chunks):
            logger.info(f"Filtered out {len(chunks) - len(valid_chunks)} empty chunks from {len(chunks)} total")
            log_event(
                "embedder.chunks_filtered",
                {
                    "file_id": file_id,
                    "original_chunks": len(chunks),
                    "valid_chunks": len(valid_chunks),
                    "filtered_chunks": len(chunks) - len(valid_chunks),
                },
            )

        # Extract just the chunk text for processing (original positions of the
        # surviving chunks are not needed downstream — passages keep the
        # filtered order)
        chunks_to_embed = [chunk for _, chunk in valid_chunks]

        embedding_start = time.time()
        logger.info(f"Generating embeddings for {len(chunks_to_embed)} chunks using {self.embedding_config.embedding_model}")
        log_event(
            "embedder.generation_started",
            {
                "total_chunks": len(chunks_to_embed),
                "model": self.embedding_config.embedding_model,
                "embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
                "batch_size": self.embedding_config.batch_size,
                "file_id": file_id,
                "source_id": source_id,
            },
        )

        # Create batches with their original indices
        batches = []
        batch_indices = []
        for i in range(0, len(chunks_to_embed), self.embedding_config.batch_size):
            batch = chunks_to_embed[i : i + self.embedding_config.batch_size]
            indices = list(range(i, min(i + self.embedding_config.batch_size, len(chunks_to_embed))))
            batches.append(batch)
            batch_indices.append(indices)

        logger.info(f"Processing {len(batches)} batches")
        log_event(
            "embedder.batching_completed",
            {"total_batches": len(batches), "batch_size": self.embedding_config.batch_size, "total_chunks": len(chunks_to_embed)},
        )

        # Use global semaphore to limit concurrent embedding requests across ALL file processing
        # This prevents rate limiting even when processing multiple files simultaneously
        async def process(batch: List[str], indices: List[int]):
            async with _GLOBAL_EMBEDDING_SEMAPHORE:
                try:
                    return await self._embed_batch(batch, indices)
                except Exception as e:
                    logger.error("Failed to embed batch of size %s: %s", len(batch), e)
                    log_event("embedder.batch_failed", {"batch_size": len(batch), "error": str(e), "error_type": type(e).__name__})
                    raise

        # Execute all batches with global semaphore control to limit concurrency
        tasks = [process(batch, indices) for batch, indices in zip(batches, batch_indices)]
        log_event(
            "embedder.concurrent_processing_started",
            {"concurrent_tasks": len(tasks), "max_concurrent_global": _MAX_CONCURRENT_EMBEDDING_REQUESTS},
        )
        results = await asyncio.gather(*tasks)
        log_event("embedder.concurrent_processing_completed", {"batches_processed": len(results)})

        # Flatten results and sort by original index
        indexed_embeddings = []
        for batch_result in results:
            indexed_embeddings.extend(batch_result)

        # Sort by index to maintain original order
        indexed_embeddings.sort(key=lambda x: x[0])

        # Create Passage objects in original order
        passages = []
        for (idx, embedding), text in zip(indexed_embeddings, chunks_to_embed):
            passage = Passage(
                text=text,
                file_id=file_id,
                source_id=source_id,
                embedding=embedding,
                embedding_config=self.embedding_config,
                organization_id=actor.organization_id,
            )
            passages.append(passage)

        embedding_duration = time.time() - embedding_start
        logger.info(f"Successfully generated {len(passages)} embeddings (took {embedding_duration:.2f}s)")
        log_event(
            "embedder.generation_completed",
            {
                "passages_created": len(passages),
                "total_chunks_processed": len(chunks_to_embed),
                "file_id": file_id,
                "source_id": source_id,
                "duration_seconds": embedding_duration,
            },
        )
        return passages