feat: Make embedding async (#4091)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -569,33 +570,42 @@ class AgentSerializationManager:
|
||||
imported_count += 1
|
||||
|
||||
# 5. Process files for chunking/embedding (depends on files and sources)
|
||||
if should_use_pinecone():
|
||||
embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
|
||||
file_processor = FileProcessor(
|
||||
file_parser=self.file_parser,
|
||||
embedder=embedder,
|
||||
actor=actor,
|
||||
using_pinecone=self.using_pinecone,
|
||||
)
|
||||
# Start background tasks for file processing
|
||||
background_tasks = []
|
||||
if schema.files and any(f.content for f in schema.files):
|
||||
if should_use_pinecone():
|
||||
embedder = PineconeEmbedder(embedding_config=schema.agents[0].embedding_config)
|
||||
else:
|
||||
embedder = OpenAIEmbedder(embedding_config=schema.agents[0].embedding_config)
|
||||
file_processor = FileProcessor(
|
||||
file_parser=self.file_parser,
|
||||
embedder=embedder,
|
||||
actor=actor,
|
||||
using_pinecone=self.using_pinecone,
|
||||
)
|
||||
|
||||
for file_schema in schema.files:
|
||||
if file_schema.content: # Only process files with content
|
||||
file_db_id = file_to_db_ids[file_schema.id]
|
||||
source_db_id = file_to_db_ids[file_schema.source_id]
|
||||
for file_schema in schema.files:
|
||||
if file_schema.content: # Only process files with content
|
||||
file_db_id = file_to_db_ids[file_schema.id]
|
||||
source_db_id = file_to_db_ids[file_schema.source_id]
|
||||
|
||||
# Get the created file metadata (with caching)
|
||||
if file_db_id not in file_metadata_cache:
|
||||
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
|
||||
file_metadata = file_metadata_cache[file_db_id]
|
||||
# Get the created file metadata (with caching)
|
||||
if file_db_id not in file_metadata_cache:
|
||||
file_metadata_cache[file_db_id] = await self.file_manager.get_file_by_id(file_db_id, actor)
|
||||
file_metadata = file_metadata_cache[file_db_id]
|
||||
|
||||
# Save the db call of fetching content again
|
||||
file_metadata.content = file_schema.content
|
||||
# Save the db call of fetching content again
|
||||
file_metadata.content = file_schema.content
|
||||
|
||||
# Process the file for chunking/embedding
|
||||
passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_db_id)
|
||||
imported_count += len(passages)
|
||||
# Create background task for file processing
|
||||
# TODO: This can be moved to celery or RQ or something
|
||||
task = asyncio.create_task(
|
||||
self._process_file_async(
|
||||
file_metadata=file_metadata, source_id=source_db_id, file_processor=file_processor, actor=actor
|
||||
)
|
||||
)
|
||||
background_tasks.append(task)
|
||||
logger.info(f"Started background processing for file {file_metadata.file_name} (ID: {file_db_id})")
|
||||
|
||||
# 6. Create agents with empty message history
|
||||
for agent_schema in schema.agents:
|
||||
@@ -696,9 +706,19 @@ class AgentSerializationManager:
|
||||
file_to_db_ids[group.id] = created_group.id
|
||||
imported_count += 1
|
||||
|
||||
# prepare result message
|
||||
num_background_tasks = len(background_tasks)
|
||||
if num_background_tasks > 0:
|
||||
message = (
|
||||
f"Import completed successfully. Imported {imported_count} entities. "
|
||||
f"{num_background_tasks} file(s) are being processed in the background for embeddings."
|
||||
)
|
||||
else:
|
||||
message = f"Import completed successfully. Imported {imported_count} entities."
|
||||
|
||||
return ImportResult(
|
||||
success=True,
|
||||
message=f"Import completed successfully. Imported {imported_count} entities.",
|
||||
message=message,
|
||||
imported_count=imported_count,
|
||||
imported_agent_ids=imported_agent_ids,
|
||||
id_mappings=file_to_db_ids,
|
||||
@@ -876,3 +896,44 @@ class AgentSerializationManager:
|
||||
except AttributeError:
|
||||
allowed = model_cls.__fields__.keys() # Pydantic v1
|
||||
return {k: v for k, v in data.items() if k in allowed}
|
||||
|
||||
async def _process_file_async(self, file_metadata: FileMetadata, source_id: str, file_processor: FileProcessor, actor: User):
|
||||
"""
|
||||
Process a file asynchronously in the background.
|
||||
|
||||
This method handles chunking and embedding of file content without blocking
|
||||
the main import process.
|
||||
|
||||
Args:
|
||||
file_metadata: The file metadata with content
|
||||
source_id: The database ID of the source
|
||||
file_processor: The file processor instance to use
|
||||
actor: The user performing the action
|
||||
"""
|
||||
file_id = file_metadata.id
|
||||
file_name = file_metadata.file_name
|
||||
|
||||
try:
|
||||
logger.info(f"Starting background processing for file {file_name} (ID: {file_id})")
|
||||
|
||||
# process the file for chunking/embedding
|
||||
passages = await file_processor.process_imported_file(file_metadata=file_metadata, source_id=source_id)
|
||||
|
||||
logger.info(f"Successfully processed file {file_name} with {len(passages)} passages")
|
||||
|
||||
# file status is automatically updated to COMPLETED by process_imported_file
|
||||
return passages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process file {file_name} (ID: {file_id}) in background: {e}")
|
||||
|
||||
# update file status to ERROR
|
||||
try:
|
||||
await self.file_manager.update_file_status(
|
||||
file_id=file_id, actor=actor, processing_status=FileProcessingStatus.ERROR, error_message=str(e)
|
||||
)
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update file status to ERROR for {file_id}: {update_error}")
|
||||
|
||||
# we don't re-raise here since this is a background task
|
||||
# the file will be marked as ERROR and the import can continue
|
||||
|
||||
@@ -10874,3 +10874,10 @@ FAILED tests/test_managers.py::test_high_concurrency_stress_test - AssertionErro
|
||||
# # Clean up
|
||||
# for block in blocks:
|
||||
# await server.block_manager.delete_block_async(block.id, actor=default_user)
|
||||
|
||||
|
||||
# TODO: I use this as a way to easily wipe my local db lol sorry
|
||||
# TODO: Leave this in here I constantly wipe my db for testing unless you care about optics
|
||||
@pytest.mark.asyncio
|
||||
async def test_wipe():
|
||||
assert True
|
||||
|
||||
Reference in New Issue
Block a user