diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 7ae0a40b..b6ea04a7 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -293,7 +293,9 @@ async def attach_tool( Attach a tool to an agent. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor) + await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor) + # TODO: Unfortunately we need this to preserve our current API behavior + return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) @router.patch("/{agent_id}/tools/detach/{tool_id}", response_model=AgentState, operation_id="detach_tool") @@ -307,7 +309,9 @@ async def detach_tool( Detach a tool from an agent. """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor) + await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor) + # TODO: Unfortunately we need this to preserve our current API behavior + return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) @router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent") @@ -327,7 +331,8 @@ async def attach_source( agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=actor) files = await server.file_manager.list_files(source_id, actor, include_content=True) - await server.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor) + if files: + await server.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor) if agent_state.enable_sleeptime: source = await server.source_manager.get_source_by_id(source_id=source_id) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index dec9cbfb..8bef6e4a 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -2584,7 +2584,7 @@ class AgentManager: @enforce_types @trace_method - async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: + async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None: """ Attaches a tool to an agent. @@ -2601,22 +2601,112 @@ class AgentManager: """ async with db_registry.async_session() as session: # Verify the agent exists and user has permission to access it - agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + await validate_agent_exists_async(session, agent_id, actor) - # Use the _process_relationship helper to attach the tool - await _process_relationship_async( - session=session, - agent=agent, - relationship_name="tools", - model_class=ToolModel, - item_ids=[tool_id], - allow_partial=False, # Ensure the tool exists - replace=False, # Extend the existing tools + # verify tool exists and belongs to organization in a single query with the insert + # first, check if tool exists with correct organization + tool_check_query = select(func.count(ToolModel.id)).where( + ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id ) + tool_result = await session.execute(tool_check_query) + if tool_result.scalar() == 0: + raise NoResultFound(f"Tool with id={tool_id} not found in organization={actor.organization_id}") - # Commit and refresh the agent - await agent.update_async(session, actor=actor) - return await agent.to_pydantic_async() + # use postgresql on conflict or mysql on duplicate key update for atomic operation + if settings.letta_pg_uri_no_default: + from sqlalchemy.dialects.postgresql import insert as pg_insert + + insert_stmt = pg_insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id) + # on conflict do nothing - silently ignore if already exists + insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["agent_id", "tool_id"]) + result = await session.execute(insert_stmt) + if result.rowcount == 0: + logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}") + else: + # for sqlite/mysql, check then insert + existing_query = ( + select(func.count()).select_from(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id) + ) + existing_result = await session.execute(existing_query) + if existing_result.scalar() == 0: + insert_stmt = insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id) + await session.execute(insert_stmt) + else: + logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}") + + await session.commit() + + @enforce_types + @trace_method + async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None: + """ + Efficiently attaches multiple tools to an agent in a single operation. + + Args: + agent_id: ID of the agent to attach the tools to. + tool_ids: List of tool IDs to attach. + actor: User performing the action. + + Raises: + NoResultFound: If the agent or any tool is not found. + """ + if not tool_ids: + # no tools to attach, nothing to do + return + + async with db_registry.async_session() as session: + # Verify the agent exists and user has permission to access it + await validate_agent_exists_async(session, agent_id, actor) + + # verify all tools exist and belong to organization in a single query + tool_check_query = select(func.count(ToolModel.id)).where( + ToolModel.id.in_(tool_ids), ToolModel.organization_id == actor.organization_id + ) + tool_result = await session.execute(tool_check_query) + found_count = tool_result.scalar() + + if found_count != len(tool_ids): + # find which tools are missing for better error message + existing_query = select(ToolModel.id).where(ToolModel.id.in_(tool_ids), ToolModel.organization_id == actor.organization_id) + existing_result = await session.execute(existing_query) + existing_ids = {row[0] for row in existing_result} + missing_ids = set(tool_ids) - existing_ids + raise NoResultFound(f"Tools with ids={missing_ids} not found in organization={actor.organization_id}") + + if settings.letta_pg_uri_no_default: + from sqlalchemy.dialects.postgresql import insert as pg_insert + + # prepare bulk values + values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in tool_ids] + + # bulk insert with on conflict do nothing + insert_stmt = pg_insert(ToolsAgents).values(values) + insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["agent_id", "tool_id"]) + result = await session.execute(insert_stmt) + logger.info( + f"Attached {result.rowcount} new tools to agent {agent_id} (skipped {len(tool_ids) - result.rowcount} already attached)" + ) + else: + # for sqlite/mysql, first check which tools are already attached + existing_query = select(ToolsAgents.tool_id).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(tool_ids)) + existing_result = await session.execute(existing_query) + already_attached = {row[0] for row in existing_result} + + # only insert tools that aren't already attached + new_tool_ids = [tid for tid in tool_ids if tid not in already_attached] + + if new_tool_ids: + # bulk insert new attachments + values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in new_tool_ids] + insert_stmt = insert(ToolsAgents).values(values) + await session.execute(insert_stmt) + logger.info( + f"Attached {len(new_tool_ids)} new tools to agent {agent_id} (skipped {len(already_attached)} already attached)" + ) + else: + logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}") + + await session.commit() @enforce_types @trace_method @@ -2625,7 +2715,7 @@ class AgentManager: Attaches missing core file tools to an agent. Args: - agent_id: ID of the agent to attach the tools to. + agent_state: The current agent state with tools already loaded. actor: User performing the action. Raises: @@ -2634,21 +2724,50 @@ class AgentManager: Returns: PydanticAgentState: The updated agent state. """ - # Check if the agent is missing any files tools - core_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} - missing_tool_names = set(FILES_TOOLS).difference(core_tool_names) + # get current file tools attached to the agent + attached_file_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} - for tool_name in missing_tool_names: - tool_id = await self.tool_manager.get_tool_id_by_name_async(tool_name=tool_name, actor=actor) + # determine which file tools are missing + missing_tool_names = set(FILES_TOOLS) - attached_file_tool_names - # TODO: This is hacky and deserves a rethink - how do we keep all the base tools available in every org always? - if not tool_id: - await self.tool_manager.upsert_base_tools_async(actor=actor, allowed_types={ToolType.LETTA_FILES_CORE}) + if not missing_tool_names: + # agent already has all file tools + return agent_state - # TODO: Inefficient - I think this re-retrieves the agent_state? - agent_state = await self.attach_tool_async(agent_id=agent_state.id, tool_id=tool_id, actor=actor) + # get full tool objects for all missing file tools in one query + async with db_registry.async_session() as session: + query = select(ToolModel).where( + ToolModel.name.in_(missing_tool_names), + ToolModel.organization_id == actor.organization_id, + ToolModel.tool_type == ToolType.LETTA_FILES_CORE, + ) + result = await session.execute(query) + found_tool_models = result.scalars().all() - return agent_state + if not found_tool_models: + logger.warning(f"No file tools found for organization {actor.organization_id}. Expected tools: {missing_tool_names}") + return agent_state + + # convert to pydantic tools + found_tools = [tool.to_pydantic() for tool in found_tool_models] + found_tool_names = {tool.name for tool in found_tools} + + # log if any expected tools weren't found + still_missing = missing_tool_names - found_tool_names + if still_missing: + logger.warning(f"File tools {still_missing} not found in organization {actor.organization_id}") + + # extract tool IDs for bulk attach + tool_ids_to_attach = [tool.id for tool in found_tools] + + # bulk attach all found file tools + await self.bulk_attach_tools_async(agent_id=agent_state.id, tool_ids=tool_ids_to_attach, actor=actor) + + # create a shallow copy with updated tools list to avoid modifying input + agent_state_dict = agent_state.model_dump() + agent_state_dict["tools"] = agent_state.tools + found_tools + + return PydanticAgentState(**agent_state_dict) @enforce_types @trace_method @@ -2657,25 +2776,30 @@ class AgentManager: Detach all core file tools from an agent. Args: - agent_id: ID of the agent to detach the tools from. + agent_state: The current agent state with tools already loaded. actor: User performing the action. Raises: - NoResultFound: If the agent or tool is not found. + NoResultFound: If the agent is not found. Returns: PydanticAgentState: The updated agent state. """ - # Check if the agent is missing any files tools - core_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} + # extract file tool IDs directly from agent_state.tools + file_tool_ids = [tool.id for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE] - for tool_name in core_tool_names: - tool_id = await self.tool_manager.get_tool_id_by_name_async(tool_name=tool_name, actor=actor) + if not file_tool_ids: + # no file tools to detach + return agent_state - # TODO: Inefficient - I think this re-retrieves the agent_state? - agent_state = await self.detach_tool_async(agent_id=agent_state.id, tool_id=tool_id, actor=actor) + # bulk detach all file tools in one operation + await self.bulk_detach_tools_async(agent_id=agent_state.id, tool_ids=file_tool_ids, actor=actor) - return agent_state + # create a shallow copy with updated tools list to avoid modifying input + agent_state_dict = agent_state.model_dump() + agent_state_dict["tools"] = [tool for tool in agent_state.tools if tool.tool_type != ToolType.LETTA_FILES_CORE] + + return PydanticAgentState(**agent_state_dict) @enforce_types @trace_method @@ -2713,7 +2837,7 @@ class AgentManager: @enforce_types @trace_method - async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: + async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None: """ Detaches a tool from an agent. @@ -2723,27 +2847,58 @@ class AgentManager: actor: User performing the action. Raises: - NoResultFound: If the agent or tool is not found. - - Returns: - PydanticAgentState: The updated agent state. + NoResultFound: If the agent is not found. """ async with db_registry.async_session() as session: # Verify the agent exists and user has permission to access it - agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + await validate_agent_exists_async(session, agent_id, actor) - # Filter out the tool to be detached - remaining_tools = [tool for tool in agent.tools if tool.id != tool_id] + # Delete the association directly - if it doesn't exist, rowcount will be 0 + delete_query = delete(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id) + result = await session.execute(delete_query) - if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship + if result.rowcount == 0: logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}") + else: + logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}") - # Update the tools relationship - agent.tools = remaining_tools + await session.commit() - # Commit and refresh the agent - await agent.update_async(session, actor=actor) - return await agent.to_pydantic_async() + @enforce_types + @trace_method + async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None: + """ + Efficiently detaches multiple tools from an agent in a single operation. + + Args: + agent_id: ID of the agent to detach the tools from. + tool_ids: List of tool IDs to detach. + actor: User performing the action. + + Raises: + NoResultFound: If the agent is not found. + """ + if not tool_ids: + # no tools to detach, nothing to do + return + + async with db_registry.async_session() as session: + # Verify the agent exists and user has permission to access it + await validate_agent_exists_async(session, agent_id, actor) + + # Delete all associations in a single query + delete_query = delete(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(tool_ids)) + result = await session.execute(delete_query) + + detached_count = result.rowcount + if detached_count == 0: + logger.warning(f"No tools from list {tool_ids} were attached to agent id={agent_id}") + elif detached_count < len(tool_ids): + logger.info(f"Detached {detached_count} tools from agent {agent_id} ({len(tool_ids) - detached_count} were not attached)") + else: + logger.info(f"Detached all {detached_count} tools from agent {agent_id}") + + await session.commit() @enforce_types @trace_method diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index b4af8d4b..1885fa24 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -3,6 +3,8 @@ import os import warnings from typing import List, Optional, Set, Union +from sqlalchemy import func, select + from letta.constants import ( BASE_FUNCTION_RETURN_CHAR_LIMIT, BASE_MEMORY_TOOLS, @@ -290,6 +292,16 @@ class ToolManager: except NoResultFound: return None + @enforce_types + @trace_method + async def tool_exists_async(self, tool_id: str, actor: PydanticUser) -> bool: + """Check if a tool exists and belongs to the user's organization (lightweight check).""" + async with db_registry.async_session() as session: + query = select(func.count(ToolModel.id)).where(ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id) + result = await session.execute(query) + count = result.scalar() + return count > 0 + @enforce_types @trace_method async def list_tools_async( diff --git a/tests/test_managers.py b/tests/test_managers.py index 7eb5307d..390e2ce9 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1406,14 +1406,14 @@ async def test_list_agents_ordering_and_pagination(server: SyncServer, default_u async def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop): """Test attaching a tool to an agent.""" # Attach the tool - server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) # Verify attachment through get_agent_by_id agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) assert print_tool.id in [t.id for t in agent.tools] # Verify that attaching the same tool again doesn't cause duplication - server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) assert len([t for t in agent.tools if t.id == print_tool.id]) == 1 @@ -1422,39 +1422,125 @@ async def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_ async def test_detach_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop): """Test detaching a tool from an agent.""" # Attach the tool first - server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) # Verify it's attached agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) assert print_tool.id in [t.id for t in agent.tools] # Detach the tool - server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + await server.agent_manager.detach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) # Verify it's detached agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) assert print_tool.id not in [t.id for t in agent.tools] # Verify that detaching an already detached tool doesn't cause issues - server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + await server.agent_manager.detach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) -def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user): +@pytest.mark.asyncio +async def test_bulk_detach_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop): + """Test bulk detaching multiple tools from an agent.""" + # First attach both tools + tool_ids = [print_tool.id, other_tool.id] + await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Verify both tools are attached + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert print_tool.id in [t.id for t in agent.tools] + assert other_tool.id in [t.id for t in agent.tools] + + # Bulk detach both tools + await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Verify both tools are detached + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert print_tool.id not in [t.id for t in agent.tools] + assert other_tool.id not in [t.id for t in agent.tools] + + +@pytest.mark.asyncio +async def test_bulk_detach_tools_partial(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop): + """Test bulk detaching tools when some are not attached.""" + # Only attach one tool + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + + # Try to bulk detach both tools (one attached, one not) + tool_ids = [print_tool.id, other_tool.id] + await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Verify the attached tool was detached + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert print_tool.id not in [t.id for t in agent.tools] + assert other_tool.id not in [t.id for t in agent.tools] + + +@pytest.mark.asyncio +async def test_bulk_detach_tools_empty_list(server: SyncServer, sarah_agent, print_tool, default_user, event_loop): + """Test bulk detaching empty list of tools.""" + # Attach a tool first + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + + # Bulk detach empty list + await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=[], actor=default_user) + + # Verify the tool is still attached + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert print_tool.id in [t.id for t in agent.tools] + + +@pytest.mark.asyncio +async def test_bulk_detach_tools_idempotent(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop): + """Test that bulk detach is idempotent.""" + # Attach both tools + tool_ids = [print_tool.id, other_tool.id] + await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Bulk detach once + await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Verify tools are detached + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert len(agent.tools) == 0 + + # Bulk detach again (should be no-op, no errors) + await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Verify still no tools + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert len(agent.tools) == 0 + + +@pytest.mark.asyncio +async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user, event_loop): + """Test bulk detaching tools from a nonexistent agent.""" + nonexistent_agent_id = "nonexistent-agent-id" + tool_ids = [print_tool.id, other_tool.id] + + with pytest.raises(NoResultFound): + await server.agent_manager.bulk_detach_tools_async(agent_id=nonexistent_agent_id, tool_ids=tool_ids, actor=default_user) + + +@pytest.mark.asyncio +async def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user): """Test attaching a tool to a nonexistent agent.""" with pytest.raises(NoResultFound): - server.agent_manager.attach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user) + await server.agent_manager.attach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user) -def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user): +@pytest.mark.asyncio +async def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user): """Test attaching a nonexistent tool to an agent.""" with pytest.raises(NoResultFound): - server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user) + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user) -def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user): +@pytest.mark.asyncio +async def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user): """Test detaching a tool from a nonexistent agent.""" with pytest.raises(NoResultFound): - server.agent_manager.detach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user) + await server.agent_manager.detach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user) @pytest.mark.asyncio @@ -1465,8 +1551,8 @@ async def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool, assert len(agent.tools) == 0 # Attach tools - server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) - server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user) + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user) # List tools and verify agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user) @@ -1476,6 +1562,251 @@ async def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool, assert other_tool.id in attached_tool_ids +@pytest.mark.asyncio +async def test_bulk_attach_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop): + """Test bulk attaching multiple tools to an agent.""" + # Bulk attach both tools + tool_ids = [print_tool.id, other_tool.id] + await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Verify both tools are attached + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + attached_tool_ids = [t.id for t in agent.tools] + assert print_tool.id in attached_tool_ids + assert other_tool.id in attached_tool_ids + + +@pytest.mark.asyncio +async def test_bulk_attach_tools_with_duplicates(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop): + """Test bulk attaching tools handles duplicates correctly.""" + # First attach one tool + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + + # Bulk attach both tools (one already attached) + tool_ids = [print_tool.id, other_tool.id] + await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Verify both tools are attached and no duplicates + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + attached_tool_ids = [t.id for t in agent.tools] + assert len(attached_tool_ids) == 2 + assert print_tool.id in attached_tool_ids + assert other_tool.id in attached_tool_ids + # Ensure no duplicates + assert len(set(attached_tool_ids)) == len(attached_tool_ids) + + +@pytest.mark.asyncio +async def test_bulk_attach_tools_empty_list(server: SyncServer, sarah_agent, default_user, event_loop): + """Test bulk attaching empty list of tools.""" + # Bulk attach empty list + await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=[], actor=default_user) + + # Verify no tools are attached + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert len(agent.tools) == 0 + + +@pytest.mark.asyncio +async def test_bulk_attach_tools_nonexistent_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop): + """Test bulk attaching tools with a nonexistent tool ID.""" + # Try to bulk attach with one valid and one invalid tool ID + nonexistent_id = "nonexistent-tool-id" + tool_ids = [print_tool.id, nonexistent_id] + + with pytest.raises(NoResultFound) as exc_info: + await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user) + + # Verify error message contains the missing tool ID + assert nonexistent_id in str(exc_info.value) + + # Verify no tools were attached (transaction should have rolled back) + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert len(agent.tools) == 0 + + +@pytest.mark.asyncio +async def test_bulk_attach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user, event_loop): + """Test bulk attaching tools to a nonexistent agent.""" + nonexistent_agent_id = "nonexistent-agent-id" + tool_ids = [print_tool.id, other_tool.id] + + with pytest.raises(NoResultFound): + await server.agent_manager.bulk_attach_tools_async(agent_id=nonexistent_agent_id, tool_ids=tool_ids, actor=default_user) + + +@pytest.mark.asyncio +async def test_attach_missing_files_tools_async(server: SyncServer, sarah_agent, default_user, event_loop): + """Test attaching missing file tools to an agent.""" + # First ensure file tools exist in the system + await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE}) + + # Get initial agent state (should have no file tools) + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + initial_tool_count = len(agent_state.tools) + + # Attach missing file tools + updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify all file tools are now attached + file_tool_names = {tool.name for tool in updated_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} + assert file_tool_names == set(FILES_TOOLS) + + # Verify the total tool count increased by the number of file tools + assert len(updated_agent_state.tools) == initial_tool_count + len(FILES_TOOLS) + + +@pytest.mark.asyncio +async def test_attach_missing_files_tools_async_partial(server: SyncServer, sarah_agent, default_user, event_loop): + """Test attaching missing file tools when some are already attached.""" + # First ensure file tools exist in the system + await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE}) + + # Get file tool IDs + all_tools = await server.tool_manager.list_tools_async(actor=default_user) + file_tools = [tool for tool in all_tools if tool.tool_type == ToolType.LETTA_FILES_CORE and tool.name in FILES_TOOLS] + + # Manually attach just the first file tool + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=file_tools[0].id, actor=default_user) + + # Get agent state with one file tool already attached + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 1 + + # Attach missing file tools + updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify all file tools are now attached + file_tool_names = {tool.name for tool in updated_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} + assert file_tool_names == set(FILES_TOOLS) + + # Verify no duplicates + all_tool_ids = [tool.id for tool in updated_agent_state.tools] + assert len(all_tool_ids) == len(set(all_tool_ids)) + + +@pytest.mark.asyncio +async def test_attach_missing_files_tools_async_idempotent(server: SyncServer, sarah_agent, default_user, event_loop): + """Test that attach_missing_files_tools is idempotent.""" + # First ensure file tools exist in the system + await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE}) + + # Get initial agent state + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + + # Attach missing file tools the first time + updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user) + first_tool_count = len(updated_agent_state.tools) + + # Call attach_missing_files_tools again (should be no-op) + final_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=updated_agent_state, actor=default_user) + + # Verify tool count didn't change + assert len(final_agent_state.tools) == first_tool_count + + # Verify still have all file tools + file_tool_names = {tool.name for tool in final_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE} + assert file_tool_names == set(FILES_TOOLS) + + +@pytest.mark.asyncio +async def test_detach_all_files_tools_async(server: SyncServer, sarah_agent, default_user, event_loop): + """Test detaching all file tools from an agent.""" + # First ensure file tools exist and attach them + await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE}) + + # Get initial agent state and attach file tools + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify file tools are attached + file_tool_count_before = len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) + assert file_tool_count_before == len(FILES_TOOLS) + + # Detach all file tools + updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify all file tools are detached + file_tool_count_after = len([t for t in updated_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) + assert file_tool_count_after == 0 + + # Verify the returned state was modified in-place (no DB reload) + assert updated_agent_state.id == agent_state.id + assert len(updated_agent_state.tools) == len(agent_state.tools) - file_tool_count_before + + +@pytest.mark.asyncio +async def test_detach_all_files_tools_async_empty(server: SyncServer, sarah_agent, default_user, event_loop): + """Test detaching all file tools when no file tools are attached.""" + # Get agent state (should have no file tools initially) + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + initial_tool_count = len(agent_state.tools) + + # Verify no file tools attached + file_tool_count = len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) + assert file_tool_count == 0 + + # Call detach_all_files_tools (should be no-op) + updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify nothing changed + assert len(updated_agent_state.tools) == initial_tool_count + assert updated_agent_state == agent_state # Should be the same object since no changes + + +@pytest.mark.asyncio +async def test_detach_all_files_tools_async_with_other_tools(server: SyncServer, sarah_agent, print_tool, default_user, event_loop): + """Test detaching all file tools preserves non-file tools.""" + # First ensure file tools exist + await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE}) + + # Attach a non-file tool + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user) + + # Get agent state and attach file tools + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify both file tools and print tool are attached + file_tools = [t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE] + assert len(file_tools) == len(FILES_TOOLS) + assert print_tool.id in [t.id for t in agent_state.tools] + + # Detach all file tools + updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify only file tools were removed, print tool remains + remaining_file_tools = [t for t in updated_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE] + assert len(remaining_file_tools) == 0 + assert print_tool.id in [t.id for t in updated_agent_state.tools] + assert len(updated_agent_state.tools) == 1 + + +@pytest.mark.asyncio +async def test_detach_all_files_tools_async_idempotent(server: SyncServer, sarah_agent, default_user, event_loop): + """Test that detach_all_files_tools is idempotent.""" + # First ensure file tools exist and attach them + await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE}) + + # Get initial agent state and attach file tools + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user) + + # Detach all file tools once + agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify no file tools + assert len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 0 + tool_count_after_first = len(agent_state.tools) + + # Detach all file tools again (should be no-op) + final_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user) + + # Verify still no file tools and same tool count + assert len([t for t in final_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 0 + assert len(final_agent_state.tools) == tool_count_after_first + + # ====================================================================================================================== # AgentManager Tests - Sources Relationship # ====================================================================================================================== diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index e9e4fdfb..d4bb7376 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -686,9 +686,13 @@ def test_include_return_message_types(client: LettaSDKClient, agent: AgentState, include_return_message_types=message_types, ) # wait to finish - while response.status != "completed": + while response.status not in {"failed", "completed", "cancelled", "expired"}: time.sleep(1) response = client.runs.retrieve(run_id=response.id) + + if response.status != "completed": + pytest.fail(f"Response status was NOT completed: {response}") + messages = client.runs.messages.list(run_id=response.id) verify_message_types(messages, message_types)