diff --git a/letta/services/agent_file_manager.py b/letta/services/agent_file_manager.py index 37a6affd..78f34b89 100644 --- a/letta/services/agent_file_manager.py +++ b/letta/services/agent_file_manager.py @@ -386,14 +386,27 @@ class AgentFileManager: file_to_db_ids[block_schema.id] = created_block.id imported_count += 1 - # 3. Create sources (no dependencies) - for source_schema in schema.sources: - # Convert SourceSchema back to Source - source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"}) - source = Source(**source_data) - created_source = await self.source_manager.create_source(source, actor) - file_to_db_ids[source_schema.id] = created_source.id - imported_count += 1 + # 3. Create sources (no dependencies) - using bulk upsert for efficiency + if schema.sources: + # convert source schemas to pydantic sources + pydantic_sources = [] + for source_schema in schema.sources: + source_data = source_schema.model_dump(exclude={"id", "embedding", "embedding_chunk_size"}) + pydantic_sources.append(Source(**source_data)) + + # bulk upsert all sources at once + created_sources = await self.source_manager.bulk_upsert_sources_async(pydantic_sources, actor) + + # map file ids to database ids + # note: sources are matched by name during upsert, so we need to match by name here too + created_sources_by_name = {source.name: source for source in created_sources} + for source_schema in schema.sources: + created_source = created_sources_by_name.get(source_schema.name) + if created_source: + file_to_db_ids[source_schema.id] = created_source.id + imported_count += 1 + else: + logger.warning(f"Source {source_schema.name} was not created during bulk upsert") # 4. 
Create files (depends on sources) for file_schema in schema.files: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index e3287130..96786919 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -58,6 +58,123 @@ class SourceManager: await source.create_async(session, actor=actor) return source.to_pydantic() + @enforce_types + @trace_method + async def bulk_upsert_sources_async(self, pydantic_sources: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]: + """ + Bulk create or update multiple sources in a single database transaction. + + Uses optimized PostgreSQL bulk upsert when available, falls back to individual + upserts for SQLite. This is much more efficient than calling create_source + in a loop. + + IMPORTANT BEHAVIOR NOTES: + - Sources are matched by (name, organization_id) unique constraint, NOT by ID + - If a source with the same name already exists for the organization, it will be updated + regardless of any ID provided in the input source + - The existing source's ID is preserved during updates + - If you provide a source with an explicit ID but a name that matches an existing source, + the existing source will be updated and the provided ID will be ignored + - Note: this differs from create_source, which first checks for an existing source by ID + + PostgreSQL optimization: + - Uses native ON CONFLICT (name, organization_id) DO UPDATE for atomic upserts + - All sources are processed in a single SQL statement for maximum efficiency + + SQLite fallback: + - Falls back to individual create_source calls + - Each source is created or updated in its own transaction + + Args: + pydantic_sources: List of sources to create or update + actor: User performing the action + + Returns: + List of created/updated sources + """ + if not pydantic_sources: + return [] + + from letta.settings import settings + + if settings.letta_pg_uri_no_default: + # use optimized postgresql bulk upsert + async with 
db_registry.async_session() as session: + return await self._bulk_upsert_postgresql(session, pydantic_sources, actor) + else: + # fallback to individual upserts for sqlite + return await self._upsert_sources_individually(pydantic_sources, actor) + + @trace_method + async def _bulk_upsert_postgresql(self, session, source_data_list: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]: + """Hyper-optimized PostgreSQL bulk upsert using ON CONFLICT DO UPDATE.""" + from sqlalchemy import func, select + from sqlalchemy.dialects.postgresql import insert + + # prepare data for bulk insert + table = SourceModel.__table__ + valid_columns = {col.name for col in table.columns} + + insert_data = [] + for source in source_data_list: + source_dict = source.model_dump(to_orm=True) + # set created/updated by fields + + if actor: + source_dict["_created_by_id"] = actor.id + source_dict["_last_updated_by_id"] = actor.id + source_dict["organization_id"] = actor.organization_id + + # filter to only include columns that exist in the table + filtered_dict = {k: v for k, v in source_dict.items() if k in valid_columns} + insert_data.append(filtered_dict) + + # use postgresql's native bulk upsert + stmt = insert(table).values(insert_data) + + # on conflict, update all columns except id, created_at, and _created_by_id + excluded = stmt.excluded + update_dict = {} + for col in table.columns: + if col.name not in ("id", "created_at", "_created_by_id"): + if col.name == "updated_at": + update_dict[col.name] = func.now() + else: + update_dict[col.name] = excluded[col.name] + + upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict) + + await session.execute(upsert_stmt) + await session.commit() + + # fetch results + source_names = [source.name for source in source_data_list] + result_query = select(SourceModel).where( + SourceModel.name.in_(source_names), SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == 
False + ) + result = await session.execute(result_query) + return [source.to_pydantic() for source in result.scalars()] + + @trace_method + async def _upsert_sources_individually(self, source_data_list: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]: + """Fallback to individual upserts for SQLite.""" + sources = [] + for source in source_data_list: + # try to get existing source by name + existing_source = await self.get_source_by_name(source.name, actor) + if existing_source: + # update existing source + from letta.schemas.source import SourceUpdate + + update_data = source.model_dump(exclude={"id"}, exclude_none=True) + updated_source = await self.update_source(existing_source.id, SourceUpdate(**update_data), actor) + sources.append(updated_source) + else: + # create new source + created_source = await self.create_source(source, actor) + sources.append(created_source) + return sources + @enforce_types @trace_method async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource: diff --git a/letta/services/tool_executor/builtin_tool_executor.py b/letta/services/tool_executor/builtin_tool_executor.py index d8cd24a6..e7415cac 100644 --- a/letta/services/tool_executor/builtin_tool_executor.py +++ b/letta/services/tool_executor/builtin_tool_executor.py @@ -1,6 +1,5 @@ import asyncio import json -import os import time from typing import Any, Dict, List, Literal, Optional @@ -176,7 +175,9 @@ class LettaBuiltinToolExecutor(ToolExecutor): app = AsyncFirecrawlApp(api_key=firecrawl_api_key) # Process all search tasks in parallel - search_task_coroutines = [self._process_single_search_task(app, task, limit, return_raw, api_key_source) for task in search_tasks] + search_task_coroutines = [ + self._process_single_search_task(app, task, limit, return_raw, api_key_source, agent_state) for task in search_tasks + ] # Execute all searches concurrently search_results = await 
asyncio.gather(*search_task_coroutines, return_exceptions=True) @@ -205,7 +206,7 @@ class LettaBuiltinToolExecutor(ToolExecutor): @trace_method async def _process_single_search_task( - self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str + self, app: "AsyncFirecrawlApp", task: SearchTask, limit: int, return_raw: bool, api_key_source: str, agent_state: "AgentState" ) -> Dict[str, Any]: """Process a single search task.""" from firecrawl import ScrapeOptions @@ -246,7 +247,9 @@ class LettaBuiltinToolExecutor(ToolExecutor): for result in search_result.get("data"): if result.get("markdown"): # Create async task for OpenAI analysis - analysis_task = self._analyze_document_with_openai(client, result["markdown"], task.query, task.question) + analysis_task = self._analyze_document_with_openai( + client, result["markdown"], task.query, task.question, agent_state + ) analysis_tasks.append(analysis_task) results_with_markdown.append(result) else: @@ -300,7 +303,9 @@ class LettaBuiltinToolExecutor(ToolExecutor): return {"query": task.query, "question": task.question, "raw_results": search_result} @trace_method - async def _analyze_document_with_openai(self, client, markdown_content: str, query: str, question: str) -> Optional[DocumentAnalysis]: + async def _analyze_document_with_openai( + self, client, markdown_content: str, query: str, question: str, agent_state: "AgentState" + ) -> Optional[DocumentAnalysis]: """Use OpenAI to analyze a document and extract relevant passages using line numbers.""" original_length = len(markdown_content) @@ -324,7 +329,9 @@ class LettaBuiltinToolExecutor(ToolExecutor): # Time the OpenAI request start_time = time.time() - model = os.getenv(WEB_SEARCH_MODEL_ENV_VAR_NAME, WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE) + # Check agent state env vars first, then fall back to the default model (os.getenv is no longer consulted) + agent_state_tool_env_vars = agent_state.get_agent_env_vars_as_dict() + model = 
agent_state_tool_env_vars.get(WEB_SEARCH_MODEL_ENV_VAR_NAME) or WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE logger.info(f"Using model {model} for web search result parsing") response = await client.beta.chat.completions.parse( model=model, diff --git a/tests/integration_test_builtin_tools.py b/tests/integration_test_builtin_tools.py index 4e853825..83af3608 100644 --- a/tests/integration_test_builtin_tools.py +++ b/tests/integration_test_builtin_tools.py @@ -321,6 +321,7 @@ async def test_web_search_uses_agent_env_var_model(): patch("openai.AsyncOpenAI") as mock_openai_class, patch("letta.services.tool_executor.builtin_tool_executor.model_settings") as mock_model_settings, patch.dict(os.environ, {WEB_SEARCH_MODEL_ENV_VAR_NAME: "gpt-4o"}), + patch("firecrawl.AsyncFirecrawlApp") as mock_firecrawl_class, ): # setup mocks @@ -330,6 +331,23 @@ async def test_web_search_uses_agent_env_var_model(): mock_openai_class.return_value = mock_openai_client mock_openai_client.beta.chat.completions.parse.return_value = mock_openai_response + # Mock Firecrawl + mock_firecrawl_app = AsyncMock() + mock_firecrawl_class.return_value = mock_firecrawl_app + + # Mock search results with markdown content + mock_search_result = { + "data": [ + { + "url": "https://example.com/test", + "title": "Test Result", + "description": "Test description", + "markdown": "This is test markdown content for the search result.", + } + ] + } + mock_firecrawl_app.search.return_value = mock_search_result + # create executor with mock dependencies executor = LettaBuiltinToolExecutor( message_manager=MagicMock(), diff --git a/tests/test_managers.py b/tests/test_managers.py index 5caf25e4..ffc583c0 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -5262,6 +5262,164 @@ async def test_update_source_no_changes(server: SyncServer, default_user): assert updated_source.description == source.description +@pytest.mark.asyncio +async def test_bulk_upsert_sources_async(server: SyncServer, default_user): + 
"""Test bulk upserting sources.""" + sources_data = [ + PydanticSource( + name="Bulk Source 1", + description="First bulk source", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + PydanticSource( + name="Bulk Source 2", + description="Second bulk source", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + PydanticSource( + name="Bulk Source 3", + description="Third bulk source", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + ] + + # Bulk upsert sources + created_sources = await server.source_manager.bulk_upsert_sources_async(sources_data, default_user) + + # Verify all sources were created + assert len(created_sources) == 3 + + # Verify source details + created_names = {source.name for source in created_sources} + expected_names = {"Bulk Source 1", "Bulk Source 2", "Bulk Source 3"} + assert created_names == expected_names + + # Verify organization assignment + for source in created_sources: + assert source.organization_id == default_user.organization_id + + +@pytest.mark.asyncio +async def test_bulk_upsert_sources_name_conflict(server: SyncServer, default_user): + """Test bulk upserting sources with name conflicts.""" + # Create an existing source + existing_source = await server.source_manager.create_source( + PydanticSource( + name="Existing Source", + description="Already exists", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + default_user, + ) + + # Try to bulk upsert with the same name + sources_data = [ + PydanticSource( + name="Existing Source", # Same name as existing + description="Updated description", + metadata={"updated": True}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + PydanticSource( + name="New Bulk Source", + description="Completely new", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + ] + + # Bulk upsert should update existing and create new + result_sources = await server.source_manager.bulk_upsert_sources_async(sources_data, default_user) + + # Should return 2 sources + assert len(result_sources) == 2 + + # Find the 
updated source + updated_source = next(s for s in result_sources if s.name == "Existing Source") + + # Verify the existing source was updated, not replaced + assert updated_source.id == existing_source.id # ID should be preserved + assert updated_source.description == "Updated description" + assert updated_source.metadata == {"updated": True} + + # Verify new source was created + new_source = next(s for s in result_sources if s.name == "New Bulk Source") + assert new_source.description == "Completely new" + + +@pytest.mark.asyncio +async def test_bulk_upsert_sources_mixed_create_update(server: SyncServer, default_user): + """Test bulk upserting with a mix of creates and updates.""" + # Create some existing sources + existing1 = await server.source_manager.create_source( + PydanticSource( + name="Mixed Source 1", + description="Original 1", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + default_user, + ) + existing2 = await server.source_manager.create_source( + PydanticSource( + name="Mixed Source 2", + description="Original 2", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + default_user, + ) + + # Bulk upsert with updates and new sources + sources_data = [ + PydanticSource( + name="Mixed Source 1", # Update existing + description="Updated 1", + instructions="New instructions 1", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + PydanticSource( + name="Mixed Source 3", # Create new + description="New 3", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + PydanticSource( + name="Mixed Source 2", # Update existing + description="Updated 2", + metadata={"version": 2}, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + PydanticSource( + name="Mixed Source 4", # Create new + description="New 4", + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + ] + + # Perform bulk upsert + result_sources = await server.source_manager.bulk_upsert_sources_async(sources_data, default_user) + + # Should return 4 sources + assert len(result_sources) == 4 + + # Verify updates 
preserved IDs + source1 = next(s for s in result_sources if s.name == "Mixed Source 1") + assert source1.id == existing1.id + assert source1.description == "Updated 1" + assert source1.instructions == "New instructions 1" + + source2 = next(s for s in result_sources if s.name == "Mixed Source 2") + assert source2.id == existing2.id + assert source2.description == "Updated 2" + assert source2.metadata == {"version": 2} + + # Verify new sources were created + source3 = next(s for s in result_sources if s.name == "Mixed Source 3") + assert source3.description == "New 3" + assert source3.id != existing1.id and source3.id != existing2.id + + source4 = next(s for s in result_sources if s.name == "Mixed Source 4") + assert source4.description == "New 4" + assert source4.id != existing1.id and source4.id != existing2.id + + # ====================================================================================================================== # Source Manager Tests - Files # ======================================================================================================================