feat: Add bulk source upsert mechanism (#3456)

Author: Matthew Zhou
Date: 2025-07-21 14:52:20 -07:00
Committed by: GitHub
Parent: dc8b8bf4e4
Commit: 3cf9580d9b
5 changed files with 327 additions and 14 deletions


@@ -386,14 +386,27 @@ class AgentFileManager:
             file_to_db_ids[block_schema.id] = created_block.id
             imported_count += 1

-        # 3. Create sources (no dependencies)
-        for source_schema in schema.sources:
-            # Convert SourceSchema back to Source
-            source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
-            source = Source(**source_data)
-            created_source = await self.source_manager.create_source(source, actor)
-            file_to_db_ids[source_schema.id] = created_source.id
-            imported_count += 1
+        # 3. Create sources (no dependencies) - using bulk upsert for efficiency
+        if schema.sources:
+            # convert source schemas to pydantic sources
+            pydantic_sources = []
+            for source_schema in schema.sources:
+                source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"})
+                pydantic_sources.append(Source(**source_data))
+
+            # bulk upsert all sources at once
+            created_sources = await self.source_manager.bulk_upsert_sources_async(pydantic_sources, actor)
+
+            # map file ids to database ids
+            # note: sources are matched by name during upsert, so we need to match by name here too
+            created_sources_by_name = {source.name: source for source in created_sources}
+            for source_schema in schema.sources:
+                created_source = created_sources_by_name.get(source_schema.name)
+                if created_source:
+                    file_to_db_ids[source_schema.id] = created_source.id
+                    imported_count += 1
+                else:
+                    logger.warning(f"Source {source_schema.name} was not created during bulk upsert")

         # 4. Create files (depends on sources)
         for file_schema in schema.files:

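Since the upsert matches on name rather than ID, the importer has to rebuild its file-ID-to-database-ID map by name as well. A minimal sketch of that remapping step in isolation, using a hypothetical SourceStub dataclass in place of Letta's real schema types:

from dataclasses import dataclass, field
from uuid import uuid4

@dataclass
class SourceStub:
    """Hypothetical stand-in for both the file-scoped schema and the DB record."""
    name: str
    id: str = field(default_factory=lambda: f"source-{uuid4().hex[:8]}")

def remap_by_name(file_schemas: list[SourceStub], created: list[SourceStub]) -> dict[str, str]:
    """Mirror the importer's loop: match created sources back to schemas by name."""
    created_by_name = {s.name: s for s in created}
    file_to_db_ids: dict[str, str] = {}
    for schema in file_schemas:
        hit = created_by_name.get(schema.name)
        if hit is not None:
            file_to_db_ids[schema.id] = hit.id
    return file_to_db_ids

schemas = [SourceStub("docs"), SourceStub("wiki")]  # IDs scoped to the agent file
db_rows = [SourceStub("docs"), SourceStub("wiki")]  # IDs assigned by the database
print(remap_by_name(schemas, db_rows))

Running it prints a mapping from file-scoped IDs to database-assigned IDs; a name missing from the result is what triggers the warning in the diff above.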

@@ -58,6 +58,123 @@ class SourceManager:
             await source.create_async(session, actor=actor)
             return source.to_pydantic()

+    @enforce_types
+    @trace_method
+    async def bulk_upsert_sources_async(self, pydantic_sources: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
+        """
+        Bulk create or update multiple sources in a single database transaction.
+
+        Uses an optimized PostgreSQL bulk upsert when available and falls back to
+        individual upserts for SQLite. This is much more efficient than calling
+        create_source in a loop.
+
+        IMPORTANT BEHAVIOR NOTES:
+        - Sources are matched by the (name, organization_id) unique constraint, NOT by ID
+        - If a source with the same name already exists for the organization, it is
+          updated regardless of any ID provided in the input source
+        - The existing source's ID is preserved during updates
+        - If you provide a source with an explicit ID but a name that matches an
+          existing source, the existing source is updated and the provided ID is
+          ignored (note: this differs from create_source, which checks by ID first)
+
+        PostgreSQL optimization:
+        - Uses native ON CONFLICT (name, organization_id) DO UPDATE for atomic upserts
+        - All sources are processed in a single SQL statement for maximum efficiency
+
+        SQLite fallback:
+        - Falls back to individual create_source calls
+        - Still benefits from batched transaction handling
+
+        Args:
+            pydantic_sources: List of sources to create or update
+            actor: User performing the action
+
+        Returns:
+            List of created/updated sources
+        """
+        if not pydantic_sources:
+            return []
+
+        from letta.settings import settings
+
+        if settings.letta_pg_uri_no_default:
+            # use optimized postgresql bulk upsert
+            async with db_registry.async_session() as session:
+                return await self._bulk_upsert_postgresql(session, pydantic_sources, actor)
+        else:
+            # fall back to individual upserts for sqlite
+            return await self._upsert_sources_individually(pydantic_sources, actor)
+
+    @trace_method
+    async def _bulk_upsert_postgresql(self, session, source_data_list: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
+        """Optimized PostgreSQL bulk upsert using ON CONFLICT DO UPDATE."""
+        from sqlalchemy import func, select
+        from sqlalchemy.dialects.postgresql import insert
+
+        # prepare data for bulk insert
+        table = SourceModel.__table__
+        valid_columns = {col.name for col in table.columns}
+
+        insert_data = []
+        for source in source_data_list:
+            source_dict = source.model_dump(to_orm=True)
+            # set created/updated-by fields
+            if actor:
+                source_dict["_created_by_id"] = actor.id
+                source_dict["_last_updated_by_id"] = actor.id
+                source_dict["organization_id"] = actor.organization_id
+            # filter to only the columns that exist in the table
+            filtered_dict = {k: v for k, v in source_dict.items() if k in valid_columns}
+            insert_data.append(filtered_dict)
+
+        # use postgresql's native bulk upsert
+        stmt = insert(table).values(insert_data)
+
+        # on conflict, update all columns except id, created_at, and _created_by_id
+        excluded = stmt.excluded
+        update_dict = {}
+        for col in table.columns:
+            if col.name not in ("id", "created_at", "_created_by_id"):
+                if col.name == "updated_at":
+                    update_dict[col.name] = func.now()
+                else:
+                    update_dict[col.name] = excluded[col.name]
+
+        upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
+        await session.execute(upsert_stmt)
+        await session.commit()
+
+        # fetch the resulting rows (existing IDs are preserved on conflict)
+        source_names = [source.name for source in source_data_list]
+        result_query = select(SourceModel).where(
+            SourceModel.name.in_(source_names), SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False
+        )
+        result = await session.execute(result_query)
+        return [source.to_pydantic() for source in result.scalars()]
+
+    @trace_method
+    async def _upsert_sources_individually(self, source_data_list: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
+        """Fallback to individual upserts for SQLite."""
+        sources = []
+        for source in source_data_list:
+            # try to get an existing source by name
+            existing_source = await self.get_source_by_name(source.name, actor)
+
+            if existing_source:
+                # update the existing source
+                from letta.schemas.source import SourceUpdate
+
+                update_data = source.model_dump(exclude={"id"}, exclude_none=True)
+                updated_source = await self.update_source(existing_source.id, SourceUpdate(**update_data), actor)
+                sources.append(updated_source)
+            else:
+                # create a new source
+                created_source = await self.create_source(source, actor)
+                sources.append(created_source)
+
+        return sources
+
     @enforce_types
     @trace_method
     async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:

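The ON CONFLICT construct used by _bulk_upsert_postgresql is plain SQLAlchemy and can be previewed without a database. A self-contained sketch of the same statement shape, assuming a toy sources table rather than Letta's actual SourceModel:

from sqlalchemy import Column, MetaData, String, Table, UniqueConstraint
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert

metadata = MetaData()
# toy table standing in for SourceModel.__table__
sources = Table(
    "sources",
    metadata,
    Column("id", String, primary_key=True),
    Column("name", String, nullable=False),
    Column("organization_id", String, nullable=False),
    Column("description", String),
    UniqueConstraint("name", "organization_id"),
)

stmt = insert(sources).values(
    [
        {"id": "s-1", "name": "docs", "organization_id": "org-1", "description": "v2"},
        {"id": "s-2", "name": "wiki", "organization_id": "org-1", "description": "new"},
    ]
)
# on conflict with (name, organization_id), the existing row keeps its id;
# mutable columns are overwritten from the EXCLUDED pseudo-table
upsert = stmt.on_conflict_do_update(
    index_elements=["name", "organization_id"],
    set_={"description": stmt.excluded.description},
)
print(upsert.compile(dialect=postgresql.dialect()))

Compiling it renders a single INSERT ... ON CONFLICT (name, organization_id) DO UPDATE SET ... statement. This is also why the helper re-selects the rows afterwards: conflicting rows keep their existing IDs, so any caller-supplied IDs on those rows never reach the table.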

@@ -1,6 +1,5 @@
import asyncio
import json
import os
import time
from typing import Any, Dict, List, Literal, Optional
@@ -176,7 +175,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
app = AsyncFirecrawlApp(api_key=firecrawl_api_key)
# Process all search tasks in parallel
search_task_coroutines = [self._process_single_search_task(app, task, limit, return_raw, api_key_source) for task in search_tasks]
search_task_coroutines = [
self._process_single_search_task(app, task, limit, return_raw, api_key_source, agent_state) for task in search_tasks
]
# Execute all searches concurrently
search_results = await asyncio.gather(*search_task_coroutines, return_exceptions=True)
@@ -205,7 +206,7 @@ class LettaBuiltinToolExecutor(ToolExecutor):
@trace_method
async def _process_single_search_task(
self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str
self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str, agent_state: "AgentState"
) -> Dict[str, Any]:
"""Process a single search task."""
from firecrawl import ScrapeOptions
@@ -246,7 +247,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
for result in search_result.get("data"):
if result.get("markdown"):
# Create async task for OpenAI analysis
analysis_task = self._analyze_document_with_openai(client, result["markdown"], task.query, task.question)
analysis_task = self._analyze_document_with_openai(
client, result["markdown"], task.query, task.question, agent_state
)
analysis_tasks.append(analysis_task)
results_with_markdown.append(result)
else:
@@ -300,7 +303,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
return {"query": task.query, "question": task.question, "raw_results": search_result}
@trace_method
async def _analyze_document_with_openai(self, client, markdown_content: str, query: str, question: str) -> Optional[DocumentAnalysis]:
async def _analyze_document_with_openai(
self, client, markdown_content: str, query: str, question: str, agent_state: "AgentState"
) -> Optional[DocumentAnalysis]:
"""Use OpenAI to analyze a document and extract relevant passages using line numbers."""
original_length = len(markdown_content)
@@ -324,7 +329,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
# Time the OpenAI request
start_time = time.time()
model = os.getenv(WEB_SEARCH_MODEL_ENV_VAR_NAME, WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE)
# Check agent state env vars first, then fall back to os.getenv
agent_state_tool_env_vars = agent_state.get_agent_env_vars_as_dict()
model = agent_state_tool_env_vars.get(WEB_SEARCH_MODEL_ENV_VAR_NAME) or WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE
logger.info(f"Using model {model} for web search result parsing")
response = await client.beta.chat.completions.parse(
model=model,

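The model-selection change boils down to a two-level precedence rule: the agent's own tool env vars win, otherwise the shipped default applies. One subtlety is that the diff uses `or` rather than `dict.get(key, default)`, so an env var set to an empty string also falls through to the default. A sketch with hypothetical names (the real env var name and default live in Letta's constants):

WEB_SEARCH_MODEL_DEFAULT = "gpt-4o-mini"  # hypothetical default value

def resolve_model(agent_env_vars: dict[str, str], key: str = "WEB_SEARCH_MODEL") -> str:
    """Agent-scoped env vars take precedence; `or` treats empty strings as unset."""
    return agent_env_vars.get(key) or WEB_SEARCH_MODEL_DEFAULT

assert resolve_model({"WEB_SEARCH_MODEL": "gpt-4o"}) == "gpt-4o"
assert resolve_model({}) == WEB_SEARCH_MODEL_DEFAULT
assert resolve_model({"WEB_SEARCH_MODEL": ""}) == WEB_SEARCH_MODEL_DEFAULT  # empty string -> default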

@@ -321,6 +321,7 @@ async def test_web_search_uses_agent_env_var_model():
         patch("openai.AsyncOpenAI") as mock_openai_class,
         patch("letta.services.tool_executor.builtin_tool_executor.model_settings") as mock_model_settings,
         patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}),
+        patch("firecrawl.AsyncFirecrawlApp") as mock_firecrawl_class,
     ):
         # setup mocks

@@ -330,6 +331,23 @@ async def test_web_search_uses_agent_env_var_model():
         mock_openai_class.return_value = mock_openai_client
         mock_openai_client.beta.chat.completions.parse.return_value = mock_openai_response

+        # Mock Firecrawl
+        mock_firecrawl_app = AsyncMock()
+        mock_firecrawl_class.return_value = mock_firecrawl_app
+
+        # Mock search results with markdown content
+        mock_search_result = {
+            "data": [
+                {
+                    "url": "https://example.com/test",
+                    "title": "Test Result",
+                    "description": "Test description",
+                    "markdown": "This is test markdown content for the search result.",
+                }
+            ]
+        }
+        mock_firecrawl_app.search.return_value = mock_search_result
+
         # create executor with mock dependencies
         executor = LettaBuiltinToolExecutor(
             message_manager=MagicMock(),

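The test's Firecrawl setup is the standard unittest.mock recipe for async clients: patch the class, return an AsyncMock instance, and preload the awaited result. A stripped-down, runnable sketch of the same technique against a hypothetical SearchClient standing in for AsyncFirecrawlApp:

import asyncio
from unittest.mock import AsyncMock, patch

class SearchClient:
    """Hypothetical async client standing in for AsyncFirecrawlApp."""
    async def search(self, query: str) -> dict:
        raise RuntimeError("would hit the network")

async def run_search() -> dict:
    client = SearchClient()
    return await client.search("letta")

def test_search_is_mocked():
    with patch(f"{__name__}.SearchClient") as mock_cls:
        mock_app = AsyncMock()
        mock_app.search.return_value = {"data": [{"markdown": "stub"}]}
        mock_cls.return_value = mock_app
        result = asyncio.run(run_search())
        assert result["data"][0]["markdown"] == "stub"

test_search_is_mocked()

Because an AsyncMock's methods return awaitables, setting search.return_value is enough: the await resolves to that dict with no explicit coroutine wrapping.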

@@ -5262,6 +5262,164 @@ async def test_update_source_no_changes(server: SyncServer, default_user):
     assert updated_source.description == source.description


+@pytest.mark.asyncio
+async def test_bulk_upsert_sources_async(server: SyncServer, default_user):
+    """Test bulk upserting sources."""
+    sources_data = [
+        PydanticSource(
+            name="Bulk Source 1",
+            description="First bulk source",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        PydanticSource(
+            name="Bulk Source 2",
+            description="Second bulk source",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        PydanticSource(
+            name="Bulk Source 3",
+            description="Third bulk source",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+    ]
+
+    # Bulk upsert sources
+    created_sources = await server.source_manager.bulk_upsert_sources_async(sources_data, default_user)
+
+    # Verify all sources were created
+    assert len(created_sources) == 3
+
+    # Verify source details
+    created_names = {source.name for source in created_sources}
+    expected_names = {"Bulk Source 1", "Bulk Source 2", "Bulk Source 3"}
+    assert created_names == expected_names
+
+    # Verify organization assignment
+    for source in created_sources:
+        assert source.organization_id == default_user.organization_id
+
+
+@pytest.mark.asyncio
+async def test_bulk_upsert_sources_name_conflict(server: SyncServer, default_user):
+    """Test bulk upserting sources with name conflicts."""
+    # Create an existing source
+    existing_source = await server.source_manager.create_source(
+        PydanticSource(
+            name="Existing Source",
+            description="Already exists",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        default_user,
+    )
+
+    # Try to bulk upsert with the same name
+    sources_data = [
+        PydanticSource(
+            name="Existing Source",  # Same name as existing
+            description="Updated description",
+            metadata={"updated": True},
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        PydanticSource(
+            name="New Bulk Source",
+            description="Completely new",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+    ]
+
+    # Bulk upsert should update existing and create new
+    result_sources = await server.source_manager.bulk_upsert_sources_async(sources_data, default_user)
+
+    # Should return 2 sources
+    assert len(result_sources) == 2
+
+    # Find the updated source
+    updated_source = next(s for s in result_sources if s.name == "Existing Source")
+
+    # Verify the existing source was updated, not replaced
+    assert updated_source.id == existing_source.id  # ID should be preserved
+    assert updated_source.description == "Updated description"
+    assert updated_source.metadata == {"updated": True}
+
+    # Verify new source was created
+    new_source = next(s for s in result_sources if s.name == "New Bulk Source")
+    assert new_source.description == "Completely new"
+
+
+@pytest.mark.asyncio
+async def test_bulk_upsert_sources_mixed_create_update(server: SyncServer, default_user):
+    """Test bulk upserting with a mix of creates and updates."""
+    # Create some existing sources
+    existing1 = await server.source_manager.create_source(
+        PydanticSource(
+            name="Mixed Source 1",
+            description="Original 1",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        default_user,
+    )
+    existing2 = await server.source_manager.create_source(
+        PydanticSource(
+            name="Mixed Source 2",
+            description="Original 2",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        default_user,
+    )
+
+    # Bulk upsert with updates and new sources
+    sources_data = [
+        PydanticSource(
+            name="Mixed Source 1",  # Update existing
+            description="Updated 1",
+            instructions="New instructions 1",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        PydanticSource(
+            name="Mixed Source 3",  # Create new
+            description="New 3",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        PydanticSource(
+            name="Mixed Source 2",  # Update existing
+            description="Updated 2",
+            metadata={"version": 2},
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+        PydanticSource(
+            name="Mixed Source 4",  # Create new
+            description="New 4",
+            embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        ),
+    ]
+
+    # Perform bulk upsert
+    result_sources = await server.source_manager.bulk_upsert_sources_async(sources_data, default_user)
+
+    # Should return 4 sources
+    assert len(result_sources) == 4
+
+    # Verify updates preserved IDs
+    source1 = next(s for s in result_sources if s.name == "Mixed Source 1")
+    assert source1.id == existing1.id
+    assert source1.description == "Updated 1"
+    assert source1.instructions == "New instructions 1"
+
+    source2 = next(s for s in result_sources if s.name == "Mixed Source 2")
+    assert source2.id == existing2.id
+    assert source2.description == "Updated 2"
+    assert source2.metadata == {"version": 2}
+
+    # Verify new sources were created
+    source3 = next(s for s in result_sources if s.name == "Mixed Source 3")
+    assert source3.description == "New 3"
+    assert source3.id != existing1.id and source3.id != existing2.id
+
+    source4 = next(s for s in result_sources if s.name == "Mixed Source 4")
+    assert source4.description == "New 4"
+    assert source4.id != existing1.id and source4.id != existing2.id
+
+
 # ======================================================================================================================
 # Source Manager Tests - Files
 # ======================================================================================================================