From 223c8832054fc7e9c3aedd0eede2a93553fa35ec Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Thu, 21 Aug 2025 16:52:23 -0700
Subject: [PATCH] feat: Make embedding async (#4091)

---
 letta/services/agent_serialization_manager.py | 109 ++++++++++++++----
 tests/test_managers.py                        |   7 ++
 2 files changed, 92 insertions(+), 24 deletions(-)

diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py
index 722801cc..9155049d 100644
--- a/letta/services/agent_serialization_manager.py
+++ b/letta/services/agent_serialization_manager.py
@@ -1,3 +1,4 @@
+import asyncio
 import uuid
 from datetime import datetime, timezone
 from typing import Any, Dict, List, Optional
@@ -569,33 +570,42 @@ class AgentSerializationManager:
             imported_count += 1
 
         # 5. Process files for chunking/embedding (depends on files and sources)
-        if should_use_pinecone():
-            embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
-        else:
-            embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
-        file_processor = FileProcessor(
-            file_parser=self.file_parser,
-            embedder=embedder,
-            actor=actor,
-            using_pinecone=self.using_pinecone,
-        )
+        # Start background tasks for file processing
+        background_tasks = []
+        if schema.files and any(f.content for f in schema.files):
+            if should_use_pinecone():
+                embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
+            else:
+                embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
+            file_processor = FileProcessor(
+                file_parser=self.file_parser,
+                embedder=embedder,
+                actor=actor,
+                using_pinecone=self.using_pinecone,
+            )
 
-        for file_schema in schema.files:
-            if file_schema.content:  # Only process files with content
-                file_db_id = file_to_db_ids[file_schema.id]
-                source_db_id = file_to_db_ids[file_schema.source_id]
+            for file_schema in schema.files:
+                if file_schema.content:  # Only process files with content
+                    file_db_id = file_to_db_ids[file_schema.id]
+                    source_db_id = file_to_db_ids[file_schema.source_id]
 
-                # Get the created file metadata (with caching)
-                if file_db_id not in file_metadata_cache:
-                    file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
-                file_metadata = file_metadata_cache[file_db_id]
+                    # Get the created file metadata (with caching)
+                    if file_db_id not in file_metadata_cache:
+                        file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
+                    file_metadata = file_metadata_cache[file_db_id]
 
-                # Save the db call of fetching content again
-                file_metadata.content = file_schema.content
+                    # Save the db call of fetching content again
+                    file_metadata.content = file_schema.content
 
-                # Process the file for chunking/embedding
-                passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_db_id)
-                imported_count += len(passages)
+                    # Create background task for file processing
+                    # TODO: This can be moved to celery or RQ or something
+                    task = asyncio.create_task(
+                        self._process_file_async(
+                            file_metadata=file_metadata, source_id=source_db_id, file_processor=file_processor, actor=actor
+                        )
+                    )
+                    background_tasks.append(task)
+                    logger.info(f"Started background processing for file {file_metadata.file_name} (ID: {file_db_id})")
 
         # 6. Create agents with empty message history
         for agent_schema in schema.agents:
@@ -696,9 +706,19 @@ class AgentSerializationManager:
                 file_to_db_ids[group.id] = created_group.id
                 imported_count += 1
 
+        # prepare result message
+        num_background_tasks = len(background_tasks)
+        if num_background_tasks > 0:
+            message = (
+                f"Import completed successfully. Imported {imported_count} entities. "
+                f"{num_background_tasks} file(s) are being processed in the background for embeddings."
+            )
+        else:
+            message = f"Import completed successfully. Imported {imported_count} entities."
+
         return ImportResult(
             success=True,
-            message=f"Import completed successfully. Imported {imported_count} entities.",
+            message=message,
             imported_count=imported_count,
             imported_agent_ids=imported_agent_ids,
             id_mappings=file_to_db_ids,
@@ -876,3 +896,44 @@ class AgentSerializationManager:
         except AttributeError:
             allowed = model_cls.__fields__.keys()  # Pydantic v1
         return {k: v for k, v in data.items() if k in allowed}
+
+    async def _process_file_async(self, file_metadata: FileMetadata, source_id: str, file_processor: FileProcessor, actor: User):
+        """
+        Process a file asynchronously in the background.
+
+        This method handles chunking and embedding of file content without blocking
+        the main import process.
+
+        Args:
+            file_metadata: The file metadata with content
+            source_id: The database ID of the source
+            file_processor: The file processor instance to use
+            actor: The user performing the action
+        """
+        file_id = file_metadata.id
+        file_name = file_metadata.file_name
+
+        try:
+            logger.info(f"Starting background processing for file {file_name} (ID: {file_id})")
+
+            # process the file for chunking/embedding
+            passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_id)
+
+            logger.info(f"Successfully processed file {file_name} with {len(passages)} passages")
+
+            # file status is automatically updated to COMPLETED by process_imported_file
+            return passages
+
+        except Exception as e:
+            logger.error(f"Failed to process file {file_name} (ID: {file_id}) in background: {e}")
+
+            # update file status to ERROR
+            try:
+                await self.file_manager.update_file_status(
+                    file_id=file_id, actor=actor, processing_status=FileProcessingStatus.ERROR, error_message=str(e)
+                )
+            except Exception as update_error:
+                logger.error(f"Failed to update file status to ERROR for {file_id}: {update_error}")
+
+            # we don't re-raise here since this is a background task
+            # the file will be marked as ERROR and the import can continue
diff --git a/tests/test_managers.py b/tests/test_managers.py
index adeadd41..f8acc485 100644
--- a/tests/test_managers.py
+++ b/tests/test_managers.py
@@ -10874,3 +10874,10 @@ FAILED tests/test_managers.py::test_high_concurrency_stress_test - AssertionErro
 #         # Clean up
 #         for block in blocks:
 #             await server.block_manager.delete_block_async(block.id, actor=default_user)
+
+
+# TODO: I use this as a way to easily wipe my local db lol sorry
+# TODO: Leave this in here I constantly wipe my db for testing unless you care about optics
+@pytest.mark.asyncio
+async def test_wipe():
+    assert True
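
Note on the pattern: the hunks above move the embedding step off the import path by scheduling it with asyncio.create_task() and recording failures on the file record instead of re-raising. Below is a minimal, runnable sketch of that fire-and-forget pattern under assumed names: embed_file stands in for FileProcessor.process_imported_file, embed_file_background for _process_file_async, and import_files for the import routine; none of these are the real Letta APIs.

import asyncio


async def embed_file(file_name: str) -> int:
    """Stand-in for chunking + embedding; returns the passage count."""
    await asyncio.sleep(0.1)  # simulate embedding latency
    return 3


async def embed_file_background(file_name: str, status: dict) -> None:
    """Stand-in for _process_file_async: record the outcome, never re-raise."""
    try:
        passages = await embed_file(file_name)
        status[file_name] = f"COMPLETED ({passages} passages)"
    except Exception as e:
        # mirror the patch: mark the file ERROR instead of failing the import
        status[file_name] = f"ERROR: {e}"


async def import_files(file_names: list) -> dict:
    status: dict = {}
    # keep strong references: the event loop only weakly references tasks
    tasks = [asyncio.create_task(embed_file_background(n, status)) for n in file_names]
    # a real import would return to the caller here; the demo gathers so the
    # event loop stays alive long enough to report the final statuses
    await asyncio.gather(*tasks)
    return status


print(asyncio.run(import_files(["a.txt", "b.txt"])))

Holding the task handles in a list (background_tasks in the patch) is load-bearing: asyncio keeps only weak references to scheduled tasks, so an untracked task can be garbage-collected before it finishes.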