feat: Make embedding async (#4091)

Matthew Zhou
2025-08-21 16:52:23 -07:00
committed by GitHub
parent a2f4ca5f89
commit 223c883205
2 changed files with 92 additions and 24 deletions


@@ -1,3 +1,4 @@
+import asyncio
 import uuid
 from datetime import datetime, timezone
 from typing import Any, Dict, List, Optional
@@ -569,33 +570,42 @@ class AgentSerializationManager:
             imported_count += 1

         # 5. Process files for chunking/embedding (depends on files and sources)
-        if should_use_pinecone():
-            embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
-        else:
-            embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
-        file_processor = FileProcessor(
-            file_parser=self.file_parser,
-            embedder=embedder,
-            actor=actor,
-            using_pinecone=self.using_pinecone,
-        )
+        # Start background tasks for file processing
+        background_tasks = []
+        if schema.files and any(f.content for f in schema.files):
+            if should_use_pinecone():
+                embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
+            else:
+                embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
+            file_processor = FileProcessor(
+                file_parser=self.file_parser,
+                embedder=embedder,
+                actor=actor,
+                using_pinecone=self.using_pinecone,
+            )

-        for file_schema in schema.files:
-            if file_schema.content:  # Only process files with content
-                file_db_id = file_to_db_ids[file_schema.id]
-                source_db_id = file_to_db_ids[file_schema.source_id]
+            for file_schema in schema.files:
+                if file_schema.content:  # Only process files with content
+                    file_db_id = file_to_db_ids[file_schema.id]
+                    source_db_id = file_to_db_ids[file_schema.source_id]

-                # Get the created file metadata (with caching)
-                if file_db_id not in file_metadata_cache:
-                    file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
-                file_metadata = file_metadata_cache[file_db_id]
+                    # Get the created file metadata (with caching)
+                    if file_db_id not in file_metadata_cache:
+                        file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
+                    file_metadata = file_metadata_cache[file_db_id]

-                # Save the db call of fetching content again
-                file_metadata.content = file_schema.content
+                    # Save the db call of fetching content again
+                    file_metadata.content = file_schema.content

-                # Process the file for chunking/embedding
-                passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_db_id)
-                imported_count += len(passages)
+                    # Create background task for file processing
+                    # TODO: This can be moved to celery or RQ or something
+                    task = asyncio.create_task(
+                        self._process_file_async(
+                            file_metadata=file_metadata, source_id=source_db_id, file_processor=file_processor, actor=actor
+                        )
+                    )
+                    background_tasks.append(task)
+                    logger.info(f"Started background processing for file {file_metadata.file_name} (ID: {file_db_id})")

         # 6. Create agents with empty message history
         for agent_schema in schema.agents:
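
The hunk above swaps the inline `await file_processor.process_imported_file(...)` for a fire-and-forget `asyncio.create_task(...)`, collecting the tasks in `background_tasks`. Keeping those references matters: the event loop only holds weak references to tasks, so an unreferenced task can be garbage-collected before it finishes. A minimal, self-contained sketch of the same pattern (the `embed_file` coroutine and the file names below are illustrative, not from this codebase):

import asyncio
import logging

logger = logging.getLogger(__name__)


async def embed_file(name: str) -> int:
    # Stand-in for chunking + embedding a single file.
    await asyncio.sleep(0.1)
    return len(name)


async def import_files(names: list[str]) -> list[asyncio.Task]:
    background_tasks: list[asyncio.Task] = []
    for name in names:
        # Schedule the work on the running event loop without awaiting it,
        # so the import itself is not blocked by embedding.
        task = asyncio.create_task(embed_file(name))
        # Keep a strong reference; otherwise the task may be garbage-collected
        # before it completes.
        background_tasks.append(task)
        logger.info("Started background processing for %s", name)
    return background_tasks


async def main() -> None:
    tasks = await import_files(["notes.pdf", "handbook.md"])
    # In a long-running server the loop stays up and these finish on their own;
    # here we wait explicitly so the demo can observe the results.
    print(await asyncio.gather(*tasks))


if __name__ == "__main__":
    asyncio.run(main())

As the TODO in the diff notes, an external queue (Celery, RQ) would be the longer-term home for this work, since in-process tasks do not survive a worker restart.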
@@ -696,9 +706,19 @@ class AgentSerializationManager:
             file_to_db_ids[group.id] = created_group.id
             imported_count += 1

+        # prepare result message
+        num_background_tasks = len(background_tasks)
+        if num_background_tasks > 0:
+            message = (
+                f"Import completed successfully. Imported {imported_count} entities. "
+                f"{num_background_tasks} file(s) are being processed in the background for embeddings."
+            )
+        else:
+            message = f"Import completed successfully. Imported {imported_count} entities."
+
         return ImportResult(
             success=True,
-            message=f"Import completed successfully. Imported {imported_count} entities.",
+            message=message,
             imported_count=imported_count,
             imported_agent_ids=imported_agent_ids,
             id_mappings=file_to_db_ids,
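
Because the import now returns before embeddings are finished, the result message only reports how many files are still being processed; a caller that needs the embeddings immediately has to wait for them separately, for example by polling each file's processing status until it reaches a terminal state. The sketch below is purely illustrative: `get_processing_status`, the status strings, and the timing constants are assumptions, not APIs from this repository.

import asyncio

# Simulated store of file processing statuses; in the real system this would
# come from file metadata in the database.
_STATUS: dict[str, str] = {"file-123": "EMBEDDING"}


async def get_processing_status(file_id: str) -> str:
    # Hypothetical lookup standing in for a metadata query.
    return _STATUS.get(file_id, "ERROR")


async def wait_for_file(file_id: str, timeout: float = 30.0, interval: float = 0.5) -> str:
    """Poll until the file reaches a terminal status or the timeout elapses."""
    deadline = asyncio.get_running_loop().time() + timeout
    while True:
        status = await get_processing_status(file_id)
        if status in ("COMPLETED", "ERROR"):
            return status
        if asyncio.get_running_loop().time() >= deadline:
            raise TimeoutError(f"{file_id} still processing after {timeout}s")
        await asyncio.sleep(interval)


async def main() -> None:
    # Simulate the background task finishing shortly after the import returns.
    async def finish_later() -> None:
        await asyncio.sleep(1.0)
        _STATUS["file-123"] = "COMPLETED"

    asyncio.create_task(finish_later())
    print(await wait_for_file("file-123"))


if __name__ == "__main__":
    asyncio.run(main())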
@@ -876,3 +896,44 @@ class AgentSerializationManager:
         except AttributeError:
             allowed = model_cls.__fields__.keys()  # Pydantic v1
         return {k: v for k, v in data.items() if k in allowed}
+
+    async def _process_file_async(self, file_metadata: FileMetadata, source_id: str, file_processor: FileProcessor, actor: User):
+        """
+        Process a file asynchronously in the background.
+
+        This method handles chunking and embedding of file content without blocking
+        the main import process.
+
+        Args:
+            file_metadata: The file metadata with content
+            source_id: The database ID of the source
+            file_processor: The file processor instance to use
+            actor: The user performing the action
+        """
+        file_id = file_metadata.id
+        file_name = file_metadata.file_name
+
+        try:
+            logger.info(f"Starting background processing for file {file_name} (ID: {file_id})")
+
+            # process the file for chunking/embedding
+            passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_id)
+
+            logger.info(f"Successfully processed file {file_name} with {len(passages)} passages")
+
+            # file status is automatically updated to COMPLETED by process_imported_file
+            return passages
+
+        except Exception as e:
+            logger.error(f"Failed to process file {file_name} (ID: {file_id}) in background: {e}")
+
+            # update file status to ERROR
+            try:
+                await self.file_manager.update_file_status(
+                    file_id=file_id, actor=actor, processing_status=FileProcessingStatus.ERROR, error_message=str(e)
+                )
+            except Exception as update_error:
+                logger.error(f"Failed to update file status to ERROR for {file_id}: {update_error}")
+
+            # we don't re-raise here since this is a background task
+            # the file will be marked as ERROR and the import can continue
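
`_process_file_async` catches every exception and records an ERROR status rather than re-raising. For a task that nothing awaits, that is the sensible choice: an exception escaping a fire-and-forget task is only reported as "Task exception was never retrieved" when the task is eventually garbage-collected, which is easy to miss. A small sketch of the same swallow-and-record shape, with a hypothetical `mark_failed` callback standing in for the status update:

import asyncio
import logging
from typing import Awaitable, Callable

logger = logging.getLogger(__name__)


async def run_in_background(
    work: Callable[[], Awaitable[None]],
    mark_failed: Callable[[str], Awaitable[None]],
) -> None:
    """Wrapper for fire-and-forget work: never lets an exception escape."""
    try:
        await work()
    except Exception as exc:  # background task must not raise
        logger.error("background work failed: %s", exc)
        try:
            # Record the failure somewhere callers can see it (e.g. a status column).
            await mark_failed(str(exc))
        except Exception as update_error:
            # Even the bookkeeping can fail; log and move on rather than crash.
            logger.error("failed to record error status: %s", update_error)


async def _demo() -> None:
    async def boom() -> None:
        raise RuntimeError("embedding backend unavailable")

    async def record(message: str) -> None:
        print("marked failed:", message)

    task = asyncio.create_task(run_in_background(boom, record))
    # In production nothing would await this; the demo waits so the output is visible.
    await task


if __name__ == "__main__":
    asyncio.run(_demo())

Because the wrapper always returns normally, callers can schedule it with `asyncio.create_task` without tracking the result, while tests that need determinism can still hold the task and await it.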


@@ -10874,3 +10874,10 @@ FAILED tests/test_managers.py::test_high_concurrency_stress_test - AssertionErro
 # # Clean up
 # for block in blocks:
 #     await server.block_manager.delete_block_async(block.id, actor=default_user)
+
+
+# TODO: I use this as a way to easily wipe my local db lol sorry
+# TODO: Leave this in here I constantly wipe my db for testing unless you care about optics
+@pytest.mark.asyncio
+async def test_wipe():
+    assert True