letta-server/letta/services/file_processor/embedder/openai_embedder.py

import asyncio
from typing import List, Optional, Tuple

import openai

from letta.log import get_logger
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.passage import Passage
from letta.schemas.user import User
from letta.settings import model_settings

logger = get_logger(__name__)

class OpenAIEmbedder:
    """OpenAI-based embedding generation"""

    def __init__(self, embedding_config: Optional[EmbeddingConfig] = None):
        self.embedding_config = embedding_config or EmbeddingConfig.default_config(model_name="text-embedding-3-small", provider="openai")

        # TODO: Unify to global OpenAI client
        self.client = openai.AsyncOpenAI(api_key=model_settings.openai_api_key)
        self.max_batch = 1024
        self.max_concurrent_requests = 20

    async def _embed_batch(self, batch: List[str], batch_indices: List[int]) -> List[Tuple[int, List[float]]]:
        """Embed a single batch and return embeddings with their original indices"""
        response = await self.client.embeddings.create(model=self.embedding_config.embedding_model, input=batch)
        return [(idx, res.embedding) for idx, res in zip(batch_indices, response.data)]
    async def generate_embedded_passages(self, file_id: str, source_id: str, chunks: List[str], actor: User) -> List[Passage]:
        """Generate embeddings for chunks with batching and concurrent processing"""
        if not chunks:
            return []

        logger.info(f"Generating embeddings for {len(chunks)} chunks using {self.embedding_config.embedding_model}")

        # Create batches with their original indices
        batches = []
        batch_indices = []
        for i in range(0, len(chunks), self.max_batch):
            batch = chunks[i : i + self.max_batch]
            indices = list(range(i, min(i + self.max_batch, len(chunks))))
            batches.append(batch)
            batch_indices.append(indices)

        logger.info(f"Processing {len(batches)} batches")
        # Bound the number of in-flight requests; max_concurrent_requests was
        # previously unused even though the gather below claimed semaphore control
        semaphore = asyncio.Semaphore(self.max_concurrent_requests)

        async def process(batch: List[str], indices: List[int]) -> List[Tuple[int, List[float]]]:
            async with semaphore:
                try:
                    return await self._embed_batch(batch, indices)
                except Exception as e:
                    logger.error(f"Failed to embed batch of size {len(batch)}: {e}")
                    raise

        # Execute all batches concurrently with semaphore control
        tasks = [process(batch, indices) for batch, indices in zip(batches, batch_indices)]
        results = await asyncio.gather(*tasks)
        # Flatten results and sort by original index
        indexed_embeddings = []
        for batch_result in results:
            indexed_embeddings.extend(batch_result)

        # Sort by index to maintain original order
        indexed_embeddings.sort(key=lambda x: x[0])

        # Create Passage objects in original order
        passages = []
        for (_idx, embedding), text in zip(indexed_embeddings, chunks):
            passage = Passage(
                text=text,
                file_id=file_id,
                source_id=source_id,
                embedding=embedding,
                embedding_config=self.embedding_config,
                organization_id=actor.organization_id,
            )
            passages.append(passage)

        logger.info(f"Successfully generated {len(passages)} embeddings")
        return passages
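
For reference, a minimal usage sketch (not part of the original file): it assumes the module above is importable, that model_settings.openai_api_key is configured, and that the caller already holds a letta User to pass as the actor. The chunk texts and file/source IDs are made up for illustration.

# Usage sketch (illustrative; assumes a configured OpenAI API key and an
# existing letta `User` to act as `actor`).
import asyncio
from typing import List

from letta.schemas.passage import Passage
from letta.schemas.user import User


async def embed_example(actor: User) -> List[Passage]:
    embedder = OpenAIEmbedder()  # falls back to text-embedding-3-small
    chunks = ["first chunk of text", "second chunk of text"]  # made-up chunks
    return await embedder.generate_embedded_passages(
        file_id="file-123",  # made-up identifiers
        source_id="source-456",
        chunks=chunks,
        actor=actor,
    )

# passages = asyncio.run(embed_example(actor))

Because all batches run under a single asyncio.gather with a shared semaphore, the caller needs only one await; the index sort inside generate_embedded_passages restores the original chunk order regardless of which batches finish first.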