diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index c3df88b7..bdfbb6d1 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -260,7 +260,9 @@ async def import_agent( raise HTTPException(status_code=422, detail=f"Invalid agent file schema: {e!s}") try: - import_result = await server.agent_serialization_manager.import_file(schema=agent_schema, actor=actor) + import_result = await server.agent_serialization_manager.import_file( + schema=agent_schema, actor=actor, append_copy_suffix=append_copy_suffix, override_existing_tools=override_existing_tools + ) if not import_result.success: raise HTTPException( diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index 9155049d..55b50907 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -428,6 +428,8 @@ class AgentSerializationManager: self, schema: AgentFileSchema, actor: User, + append_copy_suffix: bool = False, + override_existing_tools: bool = True, dry_run: bool = False, env_vars: Optional[Dict[str, Any]] = None, ) -> ImportResult: @@ -489,7 +491,9 @@ class AgentSerializationManager: pydantic_tools.append(Tool(**tool_schema.model_dump(exclude={"id"}))) # bulk upsert all tools at once - created_tools = await self.tool_manager.bulk_upsert_tools_async(pydantic_tools, actor) + created_tools = await self.tool_manager.bulk_upsert_tools_async( + pydantic_tools, actor, override_existing_tools=override_existing_tools + ) # map file ids to database ids # note: tools are matched by name during upsert, so we need to match by name here too @@ -611,6 +615,8 @@ class AgentSerializationManager: for agent_schema in schema.agents: # Convert AgentSchema back to CreateAgent, remapping tool/block IDs agent_data = agent_schema.model_dump(exclude={"id", "in_context_message_ids", "messages"}) + if append_copy_suffix: + agent_data["name"] = agent_data.get("name") + "_copy" # Remap tool_ids from file IDs to database IDs if agent_data.get("tool_ids"): diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 96602921..d990ff92 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -184,7 +184,9 @@ class ToolManager: @enforce_types @trace_method - async def bulk_upsert_tools_async(self, pydantic_tools: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]: + async def bulk_upsert_tools_async( + self, pydantic_tools: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True + ) -> List[PydanticTool]: """ Bulk create or update multiple tools in a single database transaction. @@ -227,10 +229,10 @@ class ToolManager: if settings.letta_pg_uri_no_default: # use optimized postgresql bulk upsert async with db_registry.async_session() as session: - return await self._bulk_upsert_postgresql(session, pydantic_tools, actor) + return await self._bulk_upsert_postgresql(session, pydantic_tools, actor, override_existing_tools) else: # fallback to individual upserts for sqlite - return await self._upsert_tools_individually(pydantic_tools, actor) + return await self._upsert_tools_individually(pydantic_tools, actor, override_existing_tools) @enforce_types @trace_method @@ -784,8 +786,10 @@ class ToolManager: return await self._upsert_tools_individually(tool_data_list, actor) @trace_method - async def _bulk_upsert_postgresql(self, session, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]: - """hyper-optimized postgresql bulk upsert using on_conflict_do_update.""" + async def _bulk_upsert_postgresql( + self, session, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True + ) -> List[PydanticTool]: + """hyper-optimized postgresql bulk upsert using on_conflict_do_update or on_conflict_do_nothing.""" from sqlalchemy import func, select from sqlalchemy.dialects.postgresql import insert @@ -809,32 +813,51 @@ class ToolManager: # use postgresql's native bulk upsert stmt = insert(table).values(insert_data) - # on conflict, update all columns except id, created_at, and _created_by_id - excluded = stmt.excluded - update_dict = {} - for col in table.columns: - if col.name not in ("id", "created_at", "_created_by_id"): - if col.name == "updated_at": - update_dict[col.name] = func.now() - else: - update_dict[col.name] = excluded[col.name] + if override_existing_tools: + # on conflict, update all columns except id, created_at, and _created_by_id + excluded = stmt.excluded + update_dict = {} + for col in table.columns: + if col.name not in ("id", "created_at", "_created_by_id"): + if col.name == "updated_at": + update_dict[col.name] = func.now() + else: + update_dict[col.name] = excluded[col.name] - upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict) + upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict) + else: + # on conflict, do nothing (skip existing tools) + upsert_stmt = stmt.on_conflict_do_nothing(index_elements=["name", "organization_id"]) await session.execute(upsert_stmt) await session.commit() - # fetch results + # fetch results (includes both inserted and skipped tools) tool_names = [tool.name for tool in tool_data_list] result_query = select(ToolModel).where(ToolModel.name.in_(tool_names), ToolModel.organization_id == actor.organization_id) result = await session.execute(result_query) return [tool.to_pydantic() for tool in result.scalars()] @trace_method - async def _upsert_tools_individually(self, tool_data_list: List[PydanticTool], actor: PydanticUser) -> List[PydanticTool]: + async def _upsert_tools_individually( + self, tool_data_list: List[PydanticTool], actor: PydanticUser, override_existing_tools: bool = True + ) -> List[PydanticTool]: """fallback to individual upserts for sqlite (original approach).""" tools = [] for tool in tool_data_list: - upserted_tool = await self.create_or_update_tool_async(tool, actor) - tools.append(upserted_tool) + if override_existing_tools: + # update existing tools if they exist + upserted_tool = await self.create_or_update_tool_async(tool, actor) + tools.append(upserted_tool) + else: + # skip existing tools, only create new ones + existing_tool_id = await self.get_tool_id_by_name_async(tool_name=tool.name, actor=actor) + if existing_tool_id: + # tool exists, fetch and return it without updating + existing_tool = await self.get_tool_by_id_async(existing_tool_id, actor=actor) + tools.append(existing_tool) + else: + # tool doesn't exist, create it + created_tool = await self.create_tool_async(tool, actor=actor) + tools.append(created_tool) return tools diff --git a/tests/test_managers.py b/tests/test_managers.py index f8acc485..7774c727 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3969,7 +3969,7 @@ async def test_bulk_upsert_tools_name_conflict(server: SyncServer, default_user) name="unique_name_tool", description="Original description", tags=["original"], - source_code="def unique_name_tool():\n '''Original function'''\n return 'original'", + source_code="def unique_name_tool():\n '''Original function'''\n return 'original'`", source_type="python", ) @@ -4087,6 +4087,146 @@ async def test_bulk_upsert_tools_mixed_create_update(server: SyncServer, default assert tool_2.tags == ["existing"] +@pytest.mark.asyncio +async def test_bulk_upsert_tools_override_existing_true(server: SyncServer, default_user): + """Test bulk_upsert_tools_async with override_existing_tools=True (default behavior)""" + + # create some existing tools + existing_tool = PydanticTool( + name="test_override_tool", + description="Original description", + tags=["original"], + source_code="def test_override_tool():\n '''Original'''\n return 'original'", + source_type="python", + ) + created = await server.tool_manager.create_tool_async(existing_tool, default_user) + original_id = created.id + + # prepare updated version of the tool + updated_tool = PydanticTool( + name="test_override_tool", + description="Updated description", + tags=["updated"], + source_code="def test_override_tool():\n '''Updated'''\n return 'updated'", + source_type="python", + ) + + # bulk upsert with override_existing_tools=True (default) + result = await server.tool_manager.bulk_upsert_tools_async([updated_tool], default_user, override_existing_tools=True) + + assert len(result) == 1 + assert result[0].id == original_id # id should remain the same + assert result[0].description == "Updated description" # description should be updated + assert result[0].tags == ["updated"] # tags should be updated + + # verify the tool was actually updated in the database + fetched = await server.tool_manager.get_tool_by_id_async(original_id, default_user) + assert fetched.description == "Updated description" + assert fetched.tags == ["updated"] + + +@pytest.mark.asyncio +async def test_bulk_upsert_tools_override_existing_false(server: SyncServer, default_user): + """Test bulk_upsert_tools_async with override_existing_tools=False (skip existing)""" + + # create some existing tools + existing_tool = PydanticTool( + name="test_no_override_tool", + description="Original description", + tags=["original"], + source_code="def test_no_override_tool():\n '''Original'''\n return 'original'", + source_type="python", + ) + created = await server.tool_manager.create_tool_async(existing_tool, default_user) + original_id = created.id + + # prepare updated version of the tool + updated_tool = PydanticTool( + name="test_no_override_tool", + description="Should not be updated", + tags=["should_not_update"], + source_code="def test_no_override_tool():\n '''Should not update'''\n return 'should_not_update'", + source_type="python", + ) + + # bulk upsert with override_existing_tools=False + result = await server.tool_manager.bulk_upsert_tools_async([updated_tool], default_user, override_existing_tools=False) + + assert len(result) == 1 + assert result[0].id == original_id # id should remain the same + assert result[0].description == "Original description" # description should NOT be updated + assert result[0].tags == ["original"] # tags should NOT be updated + + # verify the tool was NOT updated in the database + fetched = await server.tool_manager.get_tool_by_id_async(original_id, default_user) + assert fetched.description == "Original description" + assert fetched.tags == ["original"] + + +@pytest.mark.asyncio +async def test_bulk_upsert_tools_override_mixed_scenario(server: SyncServer, default_user): + """Test bulk_upsert_tools_async with override_existing_tools=False in mixed create/update scenario""" + + # create some existing tools + existing_tools = [] + for i in range(2): + tool = PydanticTool( + name=f"mixed_existing_{i}", + description=f"Original {i}", + tags=["original"], + source_code=f"def mixed_existing_{i}():\n '''Original {i}'''\n return 'original_{i}'", + source_type="python", + ) + created = await server.tool_manager.create_tool_async(tool, default_user) + existing_tools.append(created) + + # prepare bulk tools: 2 updates (that should be skipped) + 3 new creations + bulk_tools = [] + + # these should be skipped when override_existing_tools=False + for i in range(2): + bulk_tools.append( + PydanticTool( + name=f"mixed_existing_{i}", + description=f"Should not update {i}", + tags=["should_not_update"], + source_code=f"def mixed_existing_{i}():\n '''Should not update {i}'''\n return 'should_not_update_{i}'", + source_type="python", + ) + ) + + # these should be created + for i in range(3): + bulk_tools.append( + PydanticTool( + name=f"mixed_new_{i}", + description=f"New tool {i}", + tags=["new"], + source_code=f"def mixed_new_{i}():\n '''New {i}'''\n return 'new_{i}'", + source_type="python", + ) + ) + + # bulk upsert with override_existing_tools=False + result = await server.tool_manager.bulk_upsert_tools_async(bulk_tools, default_user, override_existing_tools=False) + + assert len(result) == 5 # 2 existing (not updated) + 3 new + + # verify existing tools were NOT updated + for i in range(2): + tool = await server.tool_manager.get_tool_by_name_async(f"mixed_existing_{i}", default_user) + assert tool.description == f"Original {i}" # should remain original + assert tool.tags == ["original"] # should remain original + assert tool.id == existing_tools[i].id # id should remain same + + # verify new tools were created + for i in range(3): + new_tool = await server.tool_manager.get_tool_by_name_async(f"mixed_new_{i}", default_user) + assert new_tool is not None + assert new_tool.description == f"New tool {i}" + assert new_tool.tags == ["new"] + + @pytest.mark.asyncio async def test_create_tool_with_pip_requirements(server: SyncServer, default_user, default_organization): def test_tool_with_deps():