feat: Improve attach/detach missing file tools performance (#3486)

This commit is contained in:
Matthew Zhou
2025-07-22 15:35:42 -07:00
committed by GitHub
parent 1590e4e3e6
commit fa58214a99
5 changed files with 573 additions and 66 deletions

View File

@@ -293,7 +293,9 @@ async def attach_tool(
Attach a tool to an agent.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
# TODO: Unfortunately we need this to preserve our current API behavior
return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
@router.patch("/{agent_id}/tools/detach/{tool_id}", response_model=AgentState, operation_id="detach_tool")
@@ -307,7 +309,9 @@ async def detach_tool(
Detach a tool from an agent.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
# TODO: Unfortunately we need this to preserve our current API behavior
return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")
@@ -327,7 +331,8 @@ async def attach_source(
agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=actor)
files = await server.file_manager.list_files(source_id, actor, include_content=True)
await server.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
if files:
await server.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
if agent_state.enable_sleeptime:
source = await server.source_manager.get_source_by_id(source_id=source_id)

View File

@@ -2584,7 +2584,7 @@ class AgentManager:
@enforce_types
@trace_method
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
"""
Attaches a tool to an agent.
@@ -2601,22 +2601,112 @@ class AgentManager:
"""
async with db_registry.async_session() as session:
# Verify the agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
await validate_agent_exists_async(session, agent_id, actor)
# Use the _process_relationship helper to attach the tool
await _process_relationship_async(
session=session,
agent=agent,
relationship_name="tools",
model_class=ToolModel,
item_ids=[tool_id],
allow_partial=False, # Ensure the tool exists
replace=False, # Extend the existing tools
# verify tool exists and belongs to organization in a single query with the insert
# first, check if tool exists with correct organization
tool_check_query = select(func.count(ToolModel.id)).where(
ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id
)
tool_result = await session.execute(tool_check_query)
if tool_result.scalar() == 0:
raise NoResultFound(f"Tool with id={tool_id} not found in organization={actor.organization_id}")
# Commit and refresh the agent
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
# use postgresql on conflict or mysql on duplicate key update for atomic operation
if settings.letta_pg_uri_no_default:
from sqlalchemy.dialects.postgresql import insert as pg_insert
insert_stmt = pg_insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id)
# on conflict do nothing - silently ignore if already exists
insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["agent_id", "tool_id"])
result = await session.execute(insert_stmt)
if result.rowcount == 0:
logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
else:
# for sqlite/mysql, check then insert
existing_query = (
select(func.count()).select_from(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id)
)
existing_result = await session.execute(existing_query)
if existing_result.scalar() == 0:
insert_stmt = insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id)
await session.execute(insert_stmt)
else:
logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
await session.commit()
@enforce_types
@trace_method
async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
"""
Efficiently attaches multiple tools to an agent in a single operation.
Args:
agent_id: ID of the agent to attach the tools to.
tool_ids: List of tool IDs to attach.
actor: User performing the action.
Raises:
NoResultFound: If the agent or any tool is not found.
"""
if not tool_ids:
# no tools to attach, nothing to do
return
async with db_registry.async_session() as session:
# Verify the agent exists and user has permission to access it
await validate_agent_exists_async(session, agent_id, actor)
# verify all tools exist and belong to organization in a single query
tool_check_query = select(func.count(ToolModel.id)).where(
ToolModel.id.in_(tool_ids), ToolModel.organization_id == actor.organization_id
)
tool_result = await session.execute(tool_check_query)
found_count = tool_result.scalar()
if found_count != len(tool_ids):
# find which tools are missing for better error message
existing_query = select(ToolModel.id).where(ToolModel.id.in_(tool_ids), ToolModel.organization_id == actor.organization_id)
existing_result = await session.execute(existing_query)
existing_ids = {row[0] for row in existing_result}
missing_ids = set(tool_ids) - existing_ids
raise NoResultFound(f"Tools with ids={missing_ids} not found in organization={actor.organization_id}")
if settings.letta_pg_uri_no_default:
from sqlalchemy.dialects.postgresql import insert as pg_insert
# prepare bulk values
values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in tool_ids]
# bulk insert with on conflict do nothing
insert_stmt = pg_insert(ToolsAgents).values(values)
insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["agent_id", "tool_id"])
result = await session.execute(insert_stmt)
logger.info(
f"Attached {result.rowcount} new tools to agent {agent_id} (skipped {len(tool_ids) - result.rowcount} already attached)"
)
else:
# for sqlite/mysql, first check which tools are already attached
existing_query = select(ToolsAgents.tool_id).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(tool_ids))
existing_result = await session.execute(existing_query)
already_attached = {row[0] for row in existing_result}
# only insert tools that aren't already attached
new_tool_ids = [tid for tid in tool_ids if tid not in already_attached]
if new_tool_ids:
# bulk insert new attachments
values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in new_tool_ids]
insert_stmt = insert(ToolsAgents).values(values)
await session.execute(insert_stmt)
logger.info(
f"Attached {len(new_tool_ids)} new tools to agent {agent_id} (skipped {len(already_attached)} already attached)"
)
else:
logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}")
await session.commit()
@enforce_types
@trace_method
@@ -2625,7 +2715,7 @@ class AgentManager:
Attaches missing core file tools to an agent.
Args:
agent_id: ID of the agent to attach the tools to.
agent_state: The current agent state with tools already loaded.
actor: User performing the action.
Raises:
@@ -2634,21 +2724,50 @@ class AgentManager:
Returns:
PydanticAgentState: The updated agent state.
"""
# Check if the agent is missing any files tools
core_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
missing_tool_names = set(FILES_TOOLS).difference(core_tool_names)
# get current file tools attached to the agent
attached_file_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
for tool_name in missing_tool_names:
tool_id = await self.tool_manager.get_tool_id_by_name_async(tool_name=tool_name, actor=actor)
# determine which file tools are missing
missing_tool_names = set(FILES_TOOLS) - attached_file_tool_names
# TODO: This is hacky and deserves a rethink - how do we keep all the base tools available in every org always?
if not tool_id:
await self.tool_manager.upsert_base_tools_async(actor=actor, allowed_types={ToolType.LETTA_FILES_CORE})
if not missing_tool_names:
# agent already has all file tools
return agent_state
# TODO: Inefficient - I think this re-retrieves the agent_state?
agent_state = await self.attach_tool_async(agent_id=agent_state.id, tool_id=tool_id, actor=actor)
# get full tool objects for all missing file tools in one query
async with db_registry.async_session() as session:
query = select(ToolModel).where(
ToolModel.name.in_(missing_tool_names),
ToolModel.organization_id == actor.organization_id,
ToolModel.tool_type == ToolType.LETTA_FILES_CORE,
)
result = await session.execute(query)
found_tool_models = result.scalars().all()
return agent_state
if not found_tool_models:
logger.warning(f"No file tools found for organization {actor.organization_id}. Expected tools: {missing_tool_names}")
return agent_state
# convert to pydantic tools
found_tools = [tool.to_pydantic() for tool in found_tool_models]
found_tool_names = {tool.name for tool in found_tools}
# log if any expected tools weren't found
still_missing = missing_tool_names - found_tool_names
if still_missing:
logger.warning(f"File tools {still_missing} not found in organization {actor.organization_id}")
# extract tool IDs for bulk attach
tool_ids_to_attach = [tool.id for tool in found_tools]
# bulk attach all found file tools
await self.bulk_attach_tools_async(agent_id=agent_state.id, tool_ids=tool_ids_to_attach, actor=actor)
# create a shallow copy with updated tools list to avoid modifying input
agent_state_dict = agent_state.model_dump()
agent_state_dict["tools"] = agent_state.tools + found_tools
return PydanticAgentState(**agent_state_dict)
@enforce_types
@trace_method
@@ -2657,25 +2776,30 @@ class AgentManager:
Detach all core file tools from an agent.
Args:
agent_id: ID of the agent to detach the tools from.
agent_state: The current agent state with tools already loaded.
actor: User performing the action.
Raises:
NoResultFound: If the agent or tool is not found.
NoResultFound: If the agent is not found.
Returns:
PydanticAgentState: The updated agent state.
"""
# Check if the agent is missing any files tools
core_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
# extract file tool IDs directly from agent_state.tools
file_tool_ids = [tool.id for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE]
for tool_name in core_tool_names:
tool_id = await self.tool_manager.get_tool_id_by_name_async(tool_name=tool_name, actor=actor)
if not file_tool_ids:
# no file tools to detach
return agent_state
# TODO: Inefficient - I think this re-retrieves the agent_state?
agent_state = await self.detach_tool_async(agent_id=agent_state.id, tool_id=tool_id, actor=actor)
# bulk detach all file tools in one operation
await self.bulk_detach_tools_async(agent_id=agent_state.id, tool_ids=file_tool_ids, actor=actor)
return agent_state
# create a shallow copy with updated tools list to avoid modifying input
agent_state_dict = agent_state.model_dump()
agent_state_dict["tools"] = [tool for tool in agent_state.tools if tool.tool_type != ToolType.LETTA_FILES_CORE]
return PydanticAgentState(**agent_state_dict)
@enforce_types
@trace_method
@@ -2713,7 +2837,7 @@ class AgentManager:
@enforce_types
@trace_method
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
"""
Detaches a tool from an agent.
@@ -2723,27 +2847,58 @@ class AgentManager:
actor: User performing the action.
Raises:
NoResultFound: If the agent or tool is not found.
Returns:
PydanticAgentState: The updated agent state.
NoResultFound: If the agent is not found.
"""
async with db_registry.async_session() as session:
# Verify the agent exists and user has permission to access it
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
await validate_agent_exists_async(session, agent_id, actor)
# Filter out the tool to be detached
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
# Delete the association directly - if it doesn't exist, rowcount will be 0
delete_query = delete(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id)
result = await session.execute(delete_query)
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
if result.rowcount == 0:
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
else:
logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}")
# Update the tools relationship
agent.tools = remaining_tools
await session.commit()
# Commit and refresh the agent
await agent.update_async(session, actor=actor)
return await agent.to_pydantic_async()
@enforce_types
@trace_method
async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
"""
Efficiently detaches multiple tools from an agent in a single operation.
Args:
agent_id: ID of the agent to detach the tools from.
tool_ids: List of tool IDs to detach.
actor: User performing the action.
Raises:
NoResultFound: If the agent is not found.
"""
if not tool_ids:
# no tools to detach, nothing to do
return
async with db_registry.async_session() as session:
# Verify the agent exists and user has permission to access it
await validate_agent_exists_async(session, agent_id, actor)
# Delete all associations in a single query
delete_query = delete(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(tool_ids))
result = await session.execute(delete_query)
detached_count = result.rowcount
if detached_count == 0:
logger.warning(f"No tools from list {tool_ids} were attached to agent id={agent_id}")
elif detached_count < len(tool_ids):
logger.info(f"Detached {detached_count} tools from agent {agent_id} ({len(tool_ids) - detached_count} were not attached)")
else:
logger.info(f"Detached all {detached_count} tools from agent {agent_id}")
await session.commit()
@enforce_types
@trace_method

View File

@@ -3,6 +3,8 @@ import os
import warnings
from typing import List, Optional, Set, Union
from sqlalchemy import func, select
from letta.constants import (
BASE_FUNCTION_RETURN_CHAR_LIMIT,
BASE_MEMORY_TOOLS,
@@ -290,6 +292,16 @@ class ToolManager:
except NoResultFound:
return None
@enforce_types
@trace_method
async def tool_exists_async(self, tool_id: str, actor: PydanticUser) -> bool:
"""Check if a tool exists and belongs to the user's organization (lightweight check)."""
async with db_registry.async_session() as session:
query = select(func.count(ToolModel.id)).where(ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id)
result = await session.execute(query)
count = result.scalar()
return count > 0
@enforce_types
@trace_method
async def list_tools_async(

View File

@@ -1406,14 +1406,14 @@ async def test_list_agents_ordering_and_pagination(server: SyncServer, default_u
async def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
"""Test attaching a tool to an agent."""
# Attach the tool
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Verify attachment through get_agent_by_id
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert print_tool.id in [t.id for t in agent.tools]
# Verify that attaching the same tool again doesn't cause duplication
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert len([t for t in agent.tools if t.id == print_tool.id]) == 1
@@ -1422,39 +1422,125 @@ async def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_
async def test_detach_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
"""Test detaching a tool from an agent."""
# Attach the tool first
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Verify it's attached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert print_tool.id in [t.id for t in agent.tools]
# Detach the tool
server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
await server.agent_manager.detach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Verify it's detached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert print_tool.id not in [t.id for t in agent.tools]
# Verify that detaching an already detached tool doesn't cause issues
server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
await server.agent_manager.detach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
@pytest.mark.asyncio
async def test_bulk_detach_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
"""Test bulk detaching multiple tools from an agent."""
# First attach both tools
tool_ids = [print_tool.id, other_tool.id]
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Verify both tools are attached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert print_tool.id in [t.id for t in agent.tools]
assert other_tool.id in [t.id for t in agent.tools]
# Bulk detach both tools
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Verify both tools are detached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert print_tool.id not in [t.id for t in agent.tools]
assert other_tool.id not in [t.id for t in agent.tools]
@pytest.mark.asyncio
async def test_bulk_detach_tools_partial(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
"""Test bulk detaching tools when some are not attached."""
# Only attach one tool
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Try to bulk detach both tools (one attached, one not)
tool_ids = [print_tool.id, other_tool.id]
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Verify the attached tool was detached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert print_tool.id not in [t.id for t in agent.tools]
assert other_tool.id not in [t.id for t in agent.tools]
@pytest.mark.asyncio
async def test_bulk_detach_tools_empty_list(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
"""Test bulk detaching empty list of tools."""
# Attach a tool first
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Bulk detach empty list
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=[], actor=default_user)
# Verify the tool is still attached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert print_tool.id in [t.id for t in agent.tools]
@pytest.mark.asyncio
async def test_bulk_detach_tools_idempotent(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
"""Test that bulk detach is idempotent."""
# Attach both tools
tool_ids = [print_tool.id, other_tool.id]
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Bulk detach once
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Verify tools are detached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert len(agent.tools) == 0
# Bulk detach again (should be no-op, no errors)
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Verify still no tools
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert len(agent.tools) == 0
@pytest.mark.asyncio
async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user, event_loop):
"""Test bulk detaching tools from a nonexistent agent."""
nonexistent_agent_id = "nonexistent-agent-id"
tool_ids = [print_tool.id, other_tool.id]
with pytest.raises(NoResultFound):
await server.agent_manager.bulk_detach_tools_async(agent_id=nonexistent_agent_id, tool_ids=tool_ids, actor=default_user)
@pytest.mark.asyncio
async def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
"""Test attaching a tool to a nonexistent agent."""
with pytest.raises(NoResultFound):
server.agent_manager.attach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
await server.agent_manager.attach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user):
@pytest.mark.asyncio
async def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user):
"""Test attaching a nonexistent tool to an agent."""
with pytest.raises(NoResultFound):
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user)
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user)
def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
@pytest.mark.asyncio
async def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
"""Test detaching a tool from a nonexistent agent."""
with pytest.raises(NoResultFound):
server.agent_manager.detach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
await server.agent_manager.detach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
@pytest.mark.asyncio
@@ -1465,8 +1551,8 @@ async def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool,
assert len(agent.tools) == 0
# Attach tools
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user)
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user)
# List tools and verify
agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user)
@@ -1476,6 +1562,251 @@ async def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool,
assert other_tool.id in attached_tool_ids
@pytest.mark.asyncio
async def test_bulk_attach_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
"""Test bulk attaching multiple tools to an agent."""
# Bulk attach both tools
tool_ids = [print_tool.id, other_tool.id]
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Verify both tools are attached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
attached_tool_ids = [t.id for t in agent.tools]
assert print_tool.id in attached_tool_ids
assert other_tool.id in attached_tool_ids
@pytest.mark.asyncio
async def test_bulk_attach_tools_with_duplicates(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
"""Test bulk attaching tools handles duplicates correctly."""
# First attach one tool
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Bulk attach both tools (one already attached)
tool_ids = [print_tool.id, other_tool.id]
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Verify both tools are attached and no duplicates
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
attached_tool_ids = [t.id for t in agent.tools]
assert len(attached_tool_ids) == 2
assert print_tool.id in attached_tool_ids
assert other_tool.id in attached_tool_ids
# Ensure no duplicates
assert len(set(attached_tool_ids)) == len(attached_tool_ids)
@pytest.mark.asyncio
async def test_bulk_attach_tools_empty_list(server: SyncServer, sarah_agent, default_user, event_loop):
"""Test bulk attaching empty list of tools."""
# Bulk attach empty list
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=[], actor=default_user)
# Verify no tools are attached
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert len(agent.tools) == 0
@pytest.mark.asyncio
async def test_bulk_attach_tools_nonexistent_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
"""Test bulk attaching tools with a nonexistent tool ID."""
# Try to bulk attach with one valid and one invalid tool ID
nonexistent_id = "nonexistent-tool-id"
tool_ids = [print_tool.id, nonexistent_id]
with pytest.raises(NoResultFound) as exc_info:
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
# Verify error message contains the missing tool ID
assert nonexistent_id in str(exc_info.value)
# Verify no tools were attached (transaction should have rolled back)
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert len(agent.tools) == 0
@pytest.mark.asyncio
async def test_bulk_attach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user, event_loop):
"""Test bulk attaching tools to a nonexistent agent."""
nonexistent_agent_id = "nonexistent-agent-id"
tool_ids = [print_tool.id, other_tool.id]
with pytest.raises(NoResultFound):
await server.agent_manager.bulk_attach_tools_async(agent_id=nonexistent_agent_id, tool_ids=tool_ids, actor=default_user)
@pytest.mark.asyncio
async def test_attach_missing_files_tools_async(server: SyncServer, sarah_agent, default_user, event_loop):
"""Test attaching missing file tools to an agent."""
# First ensure file tools exist in the system
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
# Get initial agent state (should have no file tools)
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
initial_tool_count = len(agent_state.tools)
# Attach missing file tools
updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify all file tools are now attached
file_tool_names = {tool.name for tool in updated_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
assert file_tool_names == set(FILES_TOOLS)
# Verify the total tool count increased by the number of file tools
assert len(updated_agent_state.tools) == initial_tool_count + len(FILES_TOOLS)
@pytest.mark.asyncio
async def test_attach_missing_files_tools_async_partial(server: SyncServer, sarah_agent, default_user, event_loop):
"""Test attaching missing file tools when some are already attached."""
# First ensure file tools exist in the system
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
# Get file tool IDs
all_tools = await server.tool_manager.list_tools_async(actor=default_user)
file_tools = [tool for tool in all_tools if tool.tool_type == ToolType.LETTA_FILES_CORE and tool.name in FILES_TOOLS]
# Manually attach just the first file tool
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=file_tools[0].id, actor=default_user)
# Get agent state with one file tool already attached
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
assert len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 1
# Attach missing file tools
updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify all file tools are now attached
file_tool_names = {tool.name for tool in updated_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
assert file_tool_names == set(FILES_TOOLS)
# Verify no duplicates
all_tool_ids = [tool.id for tool in updated_agent_state.tools]
assert len(all_tool_ids) == len(set(all_tool_ids))
@pytest.mark.asyncio
async def test_attach_missing_files_tools_async_idempotent(server: SyncServer, sarah_agent, default_user, event_loop):
"""Test that attach_missing_files_tools is idempotent."""
# First ensure file tools exist in the system
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
# Get initial agent state
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
# Attach missing file tools the first time
updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
first_tool_count = len(updated_agent_state.tools)
# Call attach_missing_files_tools again (should be no-op)
final_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=updated_agent_state, actor=default_user)
# Verify tool count didn't change
assert len(final_agent_state.tools) == first_tool_count
# Verify still have all file tools
file_tool_names = {tool.name for tool in final_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
assert file_tool_names == set(FILES_TOOLS)
@pytest.mark.asyncio
async def test_detach_all_files_tools_async(server: SyncServer, sarah_agent, default_user, event_loop):
"""Test detaching all file tools from an agent."""
# First ensure file tools exist and attach them
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
# Get initial agent state and attach file tools
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify file tools are attached
file_tool_count_before = len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE])
assert file_tool_count_before == len(FILES_TOOLS)
# Detach all file tools
updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify all file tools are detached
file_tool_count_after = len([t for t in updated_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE])
assert file_tool_count_after == 0
# Verify the returned state was modified in-place (no DB reload)
assert updated_agent_state.id == agent_state.id
assert len(updated_agent_state.tools) == len(agent_state.tools) - file_tool_count_before
@pytest.mark.asyncio
async def test_detach_all_files_tools_async_empty(server: SyncServer, sarah_agent, default_user, event_loop):
"""Test detaching all file tools when no file tools are attached."""
# Get agent state (should have no file tools initially)
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
initial_tool_count = len(agent_state.tools)
# Verify no file tools attached
file_tool_count = len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE])
assert file_tool_count == 0
# Call detach_all_files_tools (should be no-op)
updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify nothing changed
assert len(updated_agent_state.tools) == initial_tool_count
assert updated_agent_state == agent_state # Should be the same object since no changes
@pytest.mark.asyncio
async def test_detach_all_files_tools_async_with_other_tools(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
"""Test detaching all file tools preserves non-file tools."""
# First ensure file tools exist
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
# Attach a non-file tool
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
# Get agent state and attach file tools
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify both file tools and print tool are attached
file_tools = [t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]
assert len(file_tools) == len(FILES_TOOLS)
assert print_tool.id in [t.id for t in agent_state.tools]
# Detach all file tools
updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify only file tools were removed, print tool remains
remaining_file_tools = [t for t in updated_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]
assert len(remaining_file_tools) == 0
assert print_tool.id in [t.id for t in updated_agent_state.tools]
assert len(updated_agent_state.tools) == 1
@pytest.mark.asyncio
async def test_detach_all_files_tools_async_idempotent(server: SyncServer, sarah_agent, default_user, event_loop):
"""Test that detach_all_files_tools is idempotent."""
# First ensure file tools exist and attach them
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
# Get initial agent state and attach file tools
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
# Detach all file tools once
agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify no file tools
assert len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 0
tool_count_after_first = len(agent_state.tools)
# Detach all file tools again (should be no-op)
final_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
# Verify still no file tools and same tool count
assert len([t for t in final_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 0
assert len(final_agent_state.tools) == tool_count_after_first
# ======================================================================================================================
# AgentManager Tests - Sources Relationship
# ======================================================================================================================

View File

@@ -686,9 +686,13 @@ def test_include_return_message_types(client: LettaSDKClient, agent: AgentState,
include_return_message_types=message_types,
)
# wait to finish
while response.status != "completed":
while response.status not in {"failed", "completed", "cancelled", "expired"}:
time.sleep(1)
response = client.runs.retrieve(run_id=response.id)
if response.status != "completed":
pytest.fail(f"Response status was NOT completed: {response}")
messages = client.runs.messages.list(run_id=response.id)
verify_message_types(messages, message_types)