import asyncio
import time
from typing import List, Optional, Tuple, cast

from letta.llm_api.llm_client import LLMClient
from letta.llm_api.openai_client import OpenAIClient
from letta.log import get_logger
from letta.otel.tracing import log_event, trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ProviderType
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.services.file_processor.embedder.base_embedder import BaseEmbedder
from letta.settings import model_settings

logger = get_logger(__name__)

# Global semaphore shared across ALL embedding operations to prevent overwhelming the OpenAI API
# This ensures that even when processing multiple files concurrently, we don't exceed rate limits
_GLOBAL_EMBEDDING_SEMAPHORE = asyncio.Semaphore(3)
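# NOTE: asyncio.Semaphore is tied to the running event loop, so this cap of 3
# concurrent embedding requests applies per process/event loop; separately
# running workers each get their own limit.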


class OpenAIEmbedder(BaseEmbedder):
    """OpenAI-based embedding generation"""

    def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
        super().__init__()
        # OpenAI embedder uses the native vector db (PostgreSQL)
        # self.vector_db_type already set to VectorDBProvider.NATIVE by parent

        self.default_embedding_config = (
            EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai")
            if model_settings.openai_api_key
            else EmbeddingConfig.default_config(model_name="letta")
        )
        self.embedding_config = embedding_config or self.default_embedding_config

        # TODO: Unify to global OpenAI client
        self.client: OpenAIClient = cast(
            OpenAIClient,
            LLMClient.create(
                provider_type=ProviderType.openai,
                put_inner_thoughts_first=False,
                actor=None,  # Not necessary
            ),
        )

    @trace_method
    async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
        """Embed a single batch and return embeddings with their original indices"""
        log_event(
            "embedder.batch_started",
            {
                "batch_size": len(batch),
                "model": self.embedding_config.embedding_model,
                "embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
            },
        )

        try:
            embeddings = await self.client.request_embeddings(inputs=batch, embedding_config=self.embedding_config)
            log_event("embedder.batch_completed", {"batch_size": len(batch), "embeddings_generated": len(embeddings)})
            return [(idx, e) for idx, e in zip(batch_indices, embeddings)]
        except Exception as e:
            # if it's a token limit error and we can split, do it
            if self._is_token_limit_error(e) and len(batch) > 1:
                logger.warning(f"Token limit exceeded for batch of size {len(batch)}, splitting in half and retrying")
                log_event(
                    "embedder.batch_split_retry",
                    {
                        "original_batch_size": len(batch),
                        "error": str(e),
                        "split_size": len(batch) // 2,
                    },
                )

                # split batch in half
                mid = len(batch) // 2
                batch1 = batch[:mid]
                batch1_indices = batch_indices[:mid]
                batch2 = batch[mid:]
                batch2_indices = batch_indices[mid:]

                # retry with smaller batches
                result1 = await self._embed_batch(batch1, batch1_indices)
                result2 = await self._embed_batch(batch2, batch2_indices)

                return result1 + result2
            else:
                # re-raise for other errors or if batch size is already 1
                raise
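
    # A worked example of the split-and-retry behavior above (hypothetical numbers):
    # a batch of 8 chunks that trips the per-request token limit is retried as two
    # batches of 4; if one of those still exceeds the limit it is split again into
    # 2 + 2, and so on down to single-chunk batches, at which point the error is
    # simply re-raised.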

    def _is_token_limit_error(self, error: Exception) -> bool:
        """Check whether the error is due to the per-request token limit being exceeded"""
        # convert to string and check for token limit patterns
        error_str = str(error).lower()

        # TODO: This is quite brittle, works for now
        # check for the specific patterns we see in token limit errors
        is_token_limit = (
            "max_tokens_per_request" in error_str
            or ("requested" in error_str and "tokens" in error_str and "max" in error_str and "per request" in error_str)
            or "token limit" in error_str
            or ("bad request to openai" in error_str and "tokens" in error_str and "max" in error_str)
        )

        return is_token_limit
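
    # For illustration (hypothetical message): an error whose string form is
    # "Requested 350000 tokens, max 300000 tokens per request" would match the
    # second pattern above and be treated as a token limit error.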

    @trace_method
    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
        """Generate embeddings for chunks with batching and concurrent processing"""
        if not chunks:
            return []

        # Filter out empty or whitespace-only chunks that would fail embedding
        valid_chunks = [(i, chunk) for i, chunk in enumerate(chunks) if chunk and chunk.strip()]

        if not valid_chunks:
            logger.warning(f"No valid text chunks found for file {file_id}. PDF may contain only images without a text layer.")
            log_event(
                "embedder.no_valid_chunks",
                {"file_id": file_id, "source_id": source_id, "total_chunks": len(chunks), "reason": "All chunks empty or whitespace-only"},
            )
            return []

        if len(valid_chunks) < len(chunks):
            logger.info(f"Filtered out {len(chunks) - len(valid_chunks)} empty chunks from {len(chunks)} total")
            log_event(
                "embedder.chunks_filtered",
                {
                    "file_id": file_id,
                    "original_chunks": len(chunks),
                    "valid_chunks": len(valid_chunks),
                    "filtered_chunks": len(chunks) - len(valid_chunks),
                },
            )

        # Extract just the chunk text and indices for processing
        chunk_indices = [i for i, _ in valid_chunks]  # original chunk positions (not used below)
        chunks_to_embed = [chunk for _, chunk in valid_chunks]

        embedding_start = time.time()
        logger.info(f"Generating embeddings for {len(chunks_to_embed)} chunks using {self.embedding_config.embedding_model}")
        log_event(
            "embedder.generation_started",
            {
                "total_chunks": len(chunks_to_embed),
                "model": self.embedding_config.embedding_model,
                "embedding_endpoint_type": self.embedding_config.embedding_endpoint_type,
                "batch_size": self.embedding_config.batch_size,
                "file_id": file_id,
                "source_id": source_id,
            },
        )

        # Create batches with their original indices
        batches = []
        batch_indices = []

        for i in range(0, len(chunks_to_embed), self.embedding_config.batch_size):
            batch = chunks_to_embed[i : i + self.embedding_config.batch_size]
            indices = list(range(i, min(i + self.embedding_config.batch_size, len(chunks_to_embed))))
            batches.append(batch)
            batch_indices.append(indices)
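
        # Illustrative example (hypothetical sizes): with batch_size=4 and 10 valid
        # chunks, the loop above yields batches of sizes [4, 4, 2] with index lists
        # [0-3], [4-7], and [8-9], so every embedding can be traced back to its
        # position in chunks_to_embed.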

        logger.info(f"Processing {len(batches)} batches")
        log_event(
            "embedder.batching_completed",
            {"total_batches": len(batches), "batch_size": self.embedding_config.batch_size, "total_chunks": len(chunks_to_embed)},
        )

        # Use global semaphore to limit concurrent embedding requests across ALL file processing
        # This prevents rate limiting even when processing multiple files simultaneously
        async def process(batch: List[str], indices: List[int]):
            async with _GLOBAL_EMBEDDING_SEMAPHORE:
                try:
                    return await self._embed_batch(batch, indices)
                except Exception as e:
                    logger.error("Failed to embed batch of size %s: %s", len(batch), e)
                    log_event("embedder.batch_failed", {"batch_size": len(batch), "error": str(e), "error_type": type(e).__name__})
                    raise

        # Execute all batches with global semaphore control to limit concurrency
        tasks = [process(batch, indices) for batch, indices in zip(batches, batch_indices)]

        log_event(
            "embedder.concurrent_processing_started",
            {"concurrent_tasks": len(tasks), "max_concurrent_global": 3},
        )
        results = await asyncio.gather(*tasks)
        log_event("embedder.concurrent_processing_completed", {"batches_processed": len(results)})

        # Flatten results and sort by original index
        indexed_embeddings = []
        for batch_result in results:
            indexed_embeddings.extend(batch_result)

        # Sort by index to maintain original order
        indexed_embeddings.sort(key=lambda x: x[0])

        # Create Passage objects in original order
        passages = []
        for (idx, embedding), text in zip(indexed_embeddings, chunks_to_embed):
            passage = Passage(
                text=text,
                file_id=file_id,
                source_id=source_id,
                embedding=embedding,
                embedding_config=self.embedding_config,
                organization_id=actor.organization_id,
            )
            passages.append(passage)

        embedding_duration = time.time() - embedding_start
        logger.info(f"Successfully generated {len(passages)} embeddings (took {embedding_duration:.2f}s)")
        log_event(
            "embedder.generation_completed",
            {
                "passages_created": len(passages),
                "total_chunks_processed": len(chunks_to_embed),
                "file_id": file_id,
                "source_id": source_id,
                "duration_seconds": embedding_duration,
            },
        )
        return passages
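
# --- Usage sketch (illustrative, not part of the module) ---------------------
# A minimal, hypothetical example of how this embedder is typically driven from
# async file-processing code; the actor, identifiers, and chunk text below are
# assumptions for illustration only.
#
#   embedder = OpenAIEmbedder()
#   passages = await embedder.generate_embedded_passages(
#       file_id="file-123",
#       source_id="source-456",
#       chunks=["first chunk of text", "second chunk of text"],
#       actor=some_user,  # a letta User whose organization_id is set
#   )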