feat: Add bulk source upsert mechanism (#3456)
This commit is contained in:
@@ -386,14 +386,27 @@ class AgentFileManager:
|
||||
file_to_db_ids[block_schema.id] = created_block.id
|
||||
imported_count += 1
|
||||
|
||||
# 3. Create sources (no dependencies)
|
||||
for source_schema in schema.sources:
|
||||
# Convert SourceSchema back to Source
|
||||
source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
|
||||
source = Source(**source_data)
|
||||
created_source = await self.source_manager.create_source(source, actor)
|
||||
file_to_db_ids[source_schema.id] = created_source.id
|
||||
imported_count += 1
|
||||
# 3. Create sources (no dependencies) - using bulk upsert for efficiency
|
||||
if schema.sources:
|
||||
# convert source schemas to pydantic sources
|
||||
pydantic_sources = []
|
||||
for source_schema in schema.sources:
|
||||
source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
|
||||
pydantic_sources.append(Source(**source_data))
|
||||
|
||||
# bulk upsert all sources at once
|
||||
created_sources = await self.source_manager.bulk_upsert_sources_async(pydantic_sources, actor)
|
||||
|
||||
# map file ids to database ids
|
||||
# note: sources are matched by name during upsert, so we need to match by name here too
|
||||
created_sources_by_name = {source.name: source for source in created_sources}
|
||||
for source_schema in schema.sources:
|
||||
created_source = created_sources_by_name.get(source_schema.name)
|
||||
if created_source:
|
||||
file_to_db_ids[source_schema.id] = created_source.id
|
||||
imported_count += 1
|
||||
else:
|
||||
logger.warning(f"Source {source_schema.name} was not created during bulk upsert")
|
||||
|
||||
# 4. Create files (depends on sources)
|
||||
for file_schema in schema.files:
|
||||
|
||||
@@ -58,6 +58,123 @@ class SourceManager:
|
||||
await source.create_async(session, actor=actor)
|
||||
return source.to_pydantic()
|
||||
|
||||
@enforce_types
@trace_method
async def bulk_upsert_sources_async(self, pydantic_sources: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
    """
    Create or update a batch of sources, keyed by name within the actor's organization.

    Backend selection:
    - If ``settings.letta_pg_uri_no_default`` is set, the whole batch is passed to
      ``_bulk_upsert_postgresql`` inside a single async session, where it is written
      with one ``INSERT ... ON CONFLICT (name, organization_id) DO UPDATE`` statement.
    - Otherwise (the SQLite path) the batch is passed to
      ``_upsert_sources_individually``, which looks up, updates or creates each
      source one at a time.
      NOTE(review): whether those per-source calls share a transaction is not
      visible from here - confirm before relying on atomicity of the fallback.

    Matching rules:
    - Sources are matched by name within the actor's organization, NOT by ID.
    - When a source with the same name already exists it is updated in place and
      keeps its existing ID; any ID carried by the input source is ignored.
    - Sources whose name is not yet present are created.

    Args:
        pydantic_sources: Sources to create or update. An empty list short-circuits
            and returns an empty list without touching the database.
        actor: User performing the action; supplies the organization scope.

    Returns:
        The created/updated sources as pydantic objects.
    """
    # nothing to write: skip the settings lookup and the database entirely
    if not pydantic_sources:
        return []

    from letta.settings import settings

    postgres_configured = settings.letta_pg_uri_no_default
    if not postgres_configured:
        # sqlite: no native bulk upsert, go source by source
        return await self._upsert_sources_individually(pydantic_sources, actor)

    # postgresql: one session, one statement for the whole batch
    async with db_registry.async_session() as db_session:
        upserted_sources = await self._bulk_upsert_postgresql(db_session, pydantic_sources, actor)
    return upserted_sources
|
||||
|
||||
@trace_method
async def _bulk_upsert_postgresql(self, session, source_data_list: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
    """
    Upsert a batch of sources with a single PostgreSQL ``INSERT ... ON CONFLICT DO UPDATE``.

    Rows conflict on ``(name, organization_id)``. On conflict every column is
    overwritten with the incoming value except ``id``, ``created_at`` and
    ``_created_by_id`` (which keep the existing row's values) and ``updated_at``
    (which is set to the database's ``now()``).

    Duplicate names inside ``source_data_list`` are collapsed before the insert,
    keeping the LAST occurrence. PostgreSQL refuses to let one
    ``ON CONFLICT DO UPDATE`` statement touch the same row twice, so sending
    duplicates would abort the whole batch.

    Args:
        session: Open async database session; the statement is committed on it.
        source_data_list: Sources to create or update.
        actor: User performing the action; supplies the audit fields and the
            organization scope.

    Returns:
        One pydantic source per distinct input name, in order of first appearance
        in ``source_data_list``.
    """
    from sqlalchemy import func, select
    from sqlalchemy.dialects.postgresql import insert

    # an INSERT with no rows is not meaningful; mirror the public entry point's guard
    if not source_data_list:
        return []

    # collapse duplicate names: key order follows first appearance, value is the last occurrence
    sources_by_name = {source.name: source for source in source_data_list}

    # prepare data for bulk insert
    table = SourceModel.__table__
    valid_columns = {col.name for col in table.columns}

    insert_data = []
    for source in sources_by_name.values():
        source_dict = source.model_dump(to_orm=True)

        # set created/updated by fields
        if actor:
            source_dict["_created_by_id"] = actor.id
            source_dict["_last_updated_by_id"] = actor.id
            source_dict["organization_id"] = actor.organization_id

        # filter to only include columns that exist in the table
        insert_data.append({key: value for key, value in source_dict.items() if key in valid_columns})

    # use postgresql's native bulk upsert
    stmt = insert(table).values(insert_data)

    # on conflict, update all columns except id, created_at, and _created_by_id
    preserved_columns = {"id", "created_at", "_created_by_id"}
    excluded = stmt.excluded
    update_dict = {
        col.name: func.now() if col.name == "updated_at" else excluded[col.name]
        for col in table.columns
        if col.name not in preserved_columns
    }

    upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)

    await session.execute(upsert_stmt)
    await session.commit()

    # fetch results by name, since freshly inserted and updated rows are both needed
    result_query = select(SourceModel).where(
        SourceModel.name.in_(list(sources_by_name)),
        SourceModel.organization_id == actor.organization_id,
        SourceModel.is_deleted == False,  # noqa: E712 - SQL expression, not a Python comparison
    )
    result = await session.execute(result_query)
    fetched_by_name = {row.name: row.to_pydantic() for row in result.scalars()}

    # return in input order so this path agrees with the per-source fallback
    return [fetched_by_name[name] for name in sources_by_name if name in fetched_by_name]
|
||||
|
||||
@trace_method
async def _upsert_sources_individually(self, source_data_list: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
    """
    Upsert sources one at a time; used when the PostgreSQL bulk path is unavailable (SQLite).

    For each input source, an existing source with the same name is looked up via
    ``get_source_by_name``. If one is found it is updated through ``update_source``
    (keeping the existing ID); otherwise the input is passed to ``create_source``.

    Only fields that are not ``None`` on the input are sent with an update
    (``exclude_none=True``), so an update cannot reset a field to ``None``.

    NOTE(review): each source goes through separate manager calls, so a failure
    part-way through presumably leaves earlier sources written - confirm if
    all-or-nothing behaviour is required.

    Args:
        source_data_list: Sources to create or update.
        actor: User performing the action.

    Returns:
        The created/updated sources, in the same order as ``source_data_list``.
    """
    # imported once up front instead of on every loop iteration
    from letta.schemas.source import SourceUpdate

    sources = []
    for source in source_data_list:
        # try to get existing source by name
        existing_source = await self.get_source_by_name(source.name, actor)

        if not existing_source:
            # create new source
            sources.append(await self.create_source(source, actor))
            continue

        # update existing source; the incoming id is dropped so the stored id is preserved
        update_data = source.model_dump(exclude={"id"}, exclude_none=True)
        updated_source = await self.update_source(existing_source.id, SourceUpdate(**update_data), actor)
        sources.append(updated_source)

    return sources
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
@@ -176,7 +175,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
app = AsyncFirecrawlApp(api_key=firecrawl_api_key)
|
||||
|
||||
# Process all search tasks in parallel
|
||||
search_task_coroutines = [self._process_single_search_task(app, task, limit, return_raw, api_key_source) for task in search_tasks]
|
||||
search_task_coroutines = [
|
||||
self._process_single_search_task(app, task, limit, return_raw, api_key_source, agent_state) for task in search_tasks
|
||||
]
|
||||
|
||||
# Execute all searches concurrently
|
||||
search_results = await asyncio.gather(*search_task_coroutines, return_exceptions=True)
|
||||
@@ -205,7 +206,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
|
||||
@trace_method
|
||||
async def _process_single_search_task(
|
||||
self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str
|
||||
self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str, agent_state: "AgentState"
|
||||
) -> Dict[str, Any]:
|
||||
"""Process a single search task."""
|
||||
from firecrawl import ScrapeOptions
|
||||
@@ -246,7 +247,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
for result in search_result.get("data"):
|
||||
if result.get("markdown"):
|
||||
# Create async task for OpenAI analysis
|
||||
analysis_task = self._analyze_document_with_openai(client, result["markdown"], task.query, task.question)
|
||||
analysis_task = self._analyze_document_with_openai(
|
||||
client, result["markdown"], task.query, task.question, agent_state
|
||||
)
|
||||
analysis_tasks.append(analysis_task)
|
||||
results_with_markdown.append(result)
|
||||
else:
|
||||
@@ -300,7 +303,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
return {"query": task.query, "question": task.question, "raw_results": search_result}
|
||||
|
||||
@trace_method
|
||||
async def _analyze_document_with_openai(self, client, markdown_content: str, query: str, question: str) -> Optional[DocumentAnalysis]:
|
||||
async def _analyze_document_with_openai(
|
||||
self, client, markdown_content: str, query: str, question: str, agent_state: "AgentState"
|
||||
) -> Optional[DocumentAnalysis]:
|
||||
"""Use OpenAI to analyze a document and extract relevant passages using line numbers."""
|
||||
original_length = len(markdown_content)
|
||||
|
||||
@@ -324,7 +329,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
# Time the OpenAI request
|
||||
start_time = time.time()
|
||||
|
||||
model = os.getenv(WEB_SEARCH_MODEL_ENV_VAR_NAME, WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE)
|
||||
# Check agent state env vars first, then fall back to os.getenv
|
||||
agent_state_tool_env_vars = agent_state.get_agent_env_vars_as_dict()
|
||||
model = agent_state_tool_env_vars.get(WEB_SEARCH_MODEL_ENV_VAR_NAME) or WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE
|
||||
logger.info(f"Using model {model} for web search result parsing")
|
||||
response = await client.beta.chat.completions.parse(
|
||||
model=model,
|
||||
|
||||
@@ -321,6 +321,7 @@ async def test_web_search_uses_agent_env_var_model():
|
||||
patch("openai.AsyncOpenAI") as mock_openai_class,
|
||||
patch("letta.services.tool_executor.builtin_tool_executor.model_settings") as mock_model_settings,
|
||||
patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}),
|
||||
patch("firecrawl.AsyncFirecrawlApp") as mock_firecrawl_class,
|
||||
):
|
||||
|
||||
# setup mocks
|
||||
@@ -330,6 +331,23 @@ async def test_web_search_uses_agent_env_var_model():
|
||||
mock_openai_class.return_value = mock_openai_client
|
||||
mock_openai_client.beta.chat.completions.parse.return_value = mock_openai_response
|
||||
|
||||
# Mock Firecrawl
|
||||
mock_firecrawl_app = AsyncMock()
|
||||
mock_firecrawl_class.return_value = mock_firecrawl_app
|
||||
|
||||
# Mock search results with markdown content
|
||||
mock_search_result = {
|
||||
"data": [
|
||||
{
|
||||
"url": "https://example.com/test",
|
||||
"title": "Test Result",
|
||||
"description": "Test description",
|
||||
"markdown": "This is test markdown content for the search result.",
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_firecrawl_app.search.return_value = mock_search_result
|
||||
|
||||
# create executor with mock dependencies
|
||||
executor = LettaBuiltinToolExecutor(
|
||||
message_manager=MagicMock(),
|
||||
|
||||
@@ -5262,6 +5262,164 @@ async def test_update_source_no_changes(server: SyncServer, default_user):
|
||||
assert updated_source.description == source.description
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bulk_upsert_sources_async(server: SyncServer, default_user):
    """Bulk upserting three brand-new sources creates all of them in the actor's organization."""
    # (name, description) pairs for the sources to create
    source_specs = [
        ("Bulk Source 1", "First bulk source"),
        ("Bulk Source 2", "Second bulk source"),
        ("Bulk Source 3", "Third bulk source"),
    ]
    sources_data = [
        PydanticSource(name=name, description=description, embedding_config=DEFAULT_EMBEDDING_CONFIG)
        for name, description in source_specs
    ]

    # Bulk upsert sources
    created_sources = await server.source_manager.bulk_upsert_sources_async(sources_data, default_user)

    # Every input source must come back exactly once
    assert len(created_sources) == 3

    # Names are compared as sets because the return order is not asserted here
    returned_names = {created.name for created in created_sources}
    assert returned_names == {"Bulk Source 1", "Bulk Source 2", "Bulk Source 3"}

    # All sources must belong to the acting user's organization
    for created in created_sources:
        assert created.organization_id == default_user.organization_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bulk_upsert_sources_name_conflict(server: SyncServer, default_user):
    """A bulk upsert whose batch reuses an existing name updates that source in place and creates the rest."""
    source_manager = server.source_manager

    # Seed one source whose name the batch will collide with
    seed_source = PydanticSource(
        name="Existing Source",
        description="Already exists",
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    existing_source = await source_manager.create_source(seed_source, default_user)

    # One colliding source (same name, new fields) plus one genuinely new source
    colliding_source = PydanticSource(
        name="Existing Source",
        description="Updated description",
        metadata={"updated": True},
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    fresh_source = PydanticSource(
        name="New Bulk Source",
        description="Completely new",
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )

    # Bulk upsert should update existing and create new
    result_sources = await source_manager.bulk_upsert_sources_async([colliding_source, fresh_source], default_user)

    # Should return 2 sources
    assert len(result_sources) == 2

    results_by_name = {result.name: result for result in result_sources}

    # The existing source was updated, not replaced: same ID, new field values
    updated_source = results_by_name["Existing Source"]
    assert updated_source.id == existing_source.id
    assert updated_source.description == "Updated description"
    assert updated_source.metadata == {"updated": True}

    # The new source was created
    new_source = results_by_name["New Bulk Source"]
    assert new_source.description == "Completely new"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bulk_upsert_sources_mixed_create_update(server: SyncServer, default_user):
    """A single bulk upsert can interleave updates of existing sources with creation of new ones."""
    source_manager = server.source_manager

    # Two sources that already exist before the bulk call
    existing1 = await source_manager.create_source(
        PydanticSource(name="Mixed Source 1", description="Original 1", embedding_config=DEFAULT_EMBEDDING_CONFIG),
        default_user,
    )
    existing2 = await source_manager.create_source(
        PydanticSource(name="Mixed Source 2", description="Original 2", embedding_config=DEFAULT_EMBEDDING_CONFIG),
        default_user,
    )

    # Batch deliberately alternates update / create / update / create
    update_first = PydanticSource(
        name="Mixed Source 1",
        description="Updated 1",
        instructions="New instructions 1",
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    create_third = PydanticSource(
        name="Mixed Source 3",
        description="New 3",
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    update_second = PydanticSource(
        name="Mixed Source 2",
        description="Updated 2",
        metadata={"version": 2},
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    create_fourth = PydanticSource(
        name="Mixed Source 4",
        description="New 4",
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    sources_data = [update_first, create_third, update_second, create_fourth]

    # Perform bulk upsert
    result_sources = await source_manager.bulk_upsert_sources_async(sources_data, default_user)

    # Should return 4 sources
    assert len(result_sources) == 4

    results_by_name = {result.name: result for result in result_sources}
    pre_existing_ids = (existing1.id, existing2.id)

    # Updates keep the original IDs and carry the new field values
    source1 = results_by_name["Mixed Source 1"]
    assert source1.id == existing1.id
    assert source1.description == "Updated 1"
    assert source1.instructions == "New instructions 1"

    source2 = results_by_name["Mixed Source 2"]
    assert source2.id == existing2.id
    assert source2.description == "Updated 2"
    assert source2.metadata == {"version": 2}

    # New sources get IDs distinct from both pre-existing sources
    source3 = results_by_name["Mixed Source 3"]
    assert source3.description == "New 3"
    assert source3.id not in pre_existing_ids

    source4 = results_by_name["Mixed Source 4"]
    assert source4.description == "New 4"
    assert source4.id not in pre_existing_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Source Manager Tests - Files
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user