feat: Improve attach/detach missing file tools performance (#3486)
This commit is contained in:
@@ -293,7 +293,9 @@ async def attach_tool(
|
||||
Attach a tool to an agent.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
await server.agent_manager.attach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
# TODO: Unfortunately we need this to preserve our current API behavior
|
||||
return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/tools/detach/{tool_id}", response_model=AgentState, operation_id="detach_tool")
|
||||
@@ -307,7 +309,9 @@ async def detach_tool(
|
||||
Detach a tool from an agent.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
await server.agent_manager.detach_tool_async(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
# TODO: Unfortunately we need this to preserve our current API behavior
|
||||
return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")
|
||||
@@ -327,7 +331,8 @@ async def attach_source(
|
||||
agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=actor)
|
||||
|
||||
files = await server.file_manager.list_files(source_id, actor, include_content=True)
|
||||
await server.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
|
||||
if files:
|
||||
await server.insert_files_into_context_window(agent_state=agent_state, file_metadata_with_content=files, actor=actor)
|
||||
|
||||
if agent_state.enable_sleeptime:
|
||||
source = await server.source_manager.get_source_by_id(source_id=source_id)
|
||||
|
||||
@@ -2584,7 +2584,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Attaches a tool to an agent.
|
||||
|
||||
@@ -2601,22 +2601,112 @@ class AgentManager:
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify the agent exists and user has permission to access it
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
await validate_agent_exists_async(session, agent_id, actor)
|
||||
|
||||
# Use the _process_relationship helper to attach the tool
|
||||
await _process_relationship_async(
|
||||
session=session,
|
||||
agent=agent,
|
||||
relationship_name="tools",
|
||||
model_class=ToolModel,
|
||||
item_ids=[tool_id],
|
||||
allow_partial=False, # Ensure the tool exists
|
||||
replace=False, # Extend the existing tools
|
||||
# verify tool exists and belongs to organization in a single query with the insert
|
||||
# first, check if tool exists with correct organization
|
||||
tool_check_query = select(func.count(ToolModel.id)).where(
|
||||
ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id
|
||||
)
|
||||
tool_result = await session.execute(tool_check_query)
|
||||
if tool_result.scalar() == 0:
|
||||
raise NoResultFound(f"Tool with id={tool_id} not found in organization={actor.organization_id}")
|
||||
|
||||
# Commit and refresh the agent
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
# use postgresql on conflict or mysql on duplicate key update for atomic operation
|
||||
if settings.letta_pg_uri_no_default:
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
insert_stmt = pg_insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id)
|
||||
# on conflict do nothing - silently ignore if already exists
|
||||
insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["agent_id", "tool_id"])
|
||||
result = await session.execute(insert_stmt)
|
||||
if result.rowcount == 0:
|
||||
logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
|
||||
else:
|
||||
# for sqlite/mysql, check then insert
|
||||
existing_query = (
|
||||
select(func.count()).select_from(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id)
|
||||
)
|
||||
existing_result = await session.execute(existing_query)
|
||||
if existing_result.scalar() == 0:
|
||||
insert_stmt = insert(ToolsAgents).values(agent_id=agent_id, tool_id=tool_id)
|
||||
await session.execute(insert_stmt)
|
||||
else:
|
||||
logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
|
||||
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
|
||||
"""
|
||||
Efficiently attaches multiple tools to an agent in a single operation.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to attach the tools to.
|
||||
tool_ids: List of tool IDs to attach.
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the agent or any tool is not found.
|
||||
"""
|
||||
if not tool_ids:
|
||||
# no tools to attach, nothing to do
|
||||
return
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify the agent exists and user has permission to access it
|
||||
await validate_agent_exists_async(session, agent_id, actor)
|
||||
|
||||
# verify all tools exist and belong to organization in a single query
|
||||
tool_check_query = select(func.count(ToolModel.id)).where(
|
||||
ToolModel.id.in_(tool_ids), ToolModel.organization_id == actor.organization_id
|
||||
)
|
||||
tool_result = await session.execute(tool_check_query)
|
||||
found_count = tool_result.scalar()
|
||||
|
||||
if found_count != len(tool_ids):
|
||||
# find which tools are missing for better error message
|
||||
existing_query = select(ToolModel.id).where(ToolModel.id.in_(tool_ids), ToolModel.organization_id == actor.organization_id)
|
||||
existing_result = await session.execute(existing_query)
|
||||
existing_ids = {row[0] for row in existing_result}
|
||||
missing_ids = set(tool_ids) - existing_ids
|
||||
raise NoResultFound(f"Tools with ids={missing_ids} not found in organization={actor.organization_id}")
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
# prepare bulk values
|
||||
values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in tool_ids]
|
||||
|
||||
# bulk insert with on conflict do nothing
|
||||
insert_stmt = pg_insert(ToolsAgents).values(values)
|
||||
insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["agent_id", "tool_id"])
|
||||
result = await session.execute(insert_stmt)
|
||||
logger.info(
|
||||
f"Attached {result.rowcount} new tools to agent {agent_id} (skipped {len(tool_ids) - result.rowcount} already attached)"
|
||||
)
|
||||
else:
|
||||
# for sqlite/mysql, first check which tools are already attached
|
||||
existing_query = select(ToolsAgents.tool_id).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(tool_ids))
|
||||
existing_result = await session.execute(existing_query)
|
||||
already_attached = {row[0] for row in existing_result}
|
||||
|
||||
# only insert tools that aren't already attached
|
||||
new_tool_ids = [tid for tid in tool_ids if tid not in already_attached]
|
||||
|
||||
if new_tool_ids:
|
||||
# bulk insert new attachments
|
||||
values = [{"agent_id": agent_id, "tool_id": tool_id} for tool_id in new_tool_ids]
|
||||
insert_stmt = insert(ToolsAgents).values(values)
|
||||
await session.execute(insert_stmt)
|
||||
logger.info(
|
||||
f"Attached {len(new_tool_ids)} new tools to agent {agent_id} (skipped {len(already_attached)} already attached)"
|
||||
)
|
||||
else:
|
||||
logger.info(f"All {len(tool_ids)} tools already attached to agent {agent_id}")
|
||||
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -2625,7 +2715,7 @@ class AgentManager:
|
||||
Attaches missing core file tools to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to attach the tools to.
|
||||
agent_state: The current agent state with tools already loaded.
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
@@ -2634,21 +2724,50 @@ class AgentManager:
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state.
|
||||
"""
|
||||
# Check if the agent is missing any files tools
|
||||
core_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
|
||||
missing_tool_names = set(FILES_TOOLS).difference(core_tool_names)
|
||||
# get current file tools attached to the agent
|
||||
attached_file_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
|
||||
|
||||
for tool_name in missing_tool_names:
|
||||
tool_id = await self.tool_manager.get_tool_id_by_name_async(tool_name=tool_name, actor=actor)
|
||||
# determine which file tools are missing
|
||||
missing_tool_names = set(FILES_TOOLS) - attached_file_tool_names
|
||||
|
||||
# TODO: This is hacky and deserves a rethink - how do we keep all the base tools available in every org always?
|
||||
if not tool_id:
|
||||
await self.tool_manager.upsert_base_tools_async(actor=actor, allowed_types={ToolType.LETTA_FILES_CORE})
|
||||
if not missing_tool_names:
|
||||
# agent already has all file tools
|
||||
return agent_state
|
||||
|
||||
# TODO: Inefficient - I think this re-retrieves the agent_state?
|
||||
agent_state = await self.attach_tool_async(agent_id=agent_state.id, tool_id=tool_id, actor=actor)
|
||||
# get full tool objects for all missing file tools in one query
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(ToolModel).where(
|
||||
ToolModel.name.in_(missing_tool_names),
|
||||
ToolModel.organization_id == actor.organization_id,
|
||||
ToolModel.tool_type == ToolType.LETTA_FILES_CORE,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
found_tool_models = result.scalars().all()
|
||||
|
||||
return agent_state
|
||||
if not found_tool_models:
|
||||
logger.warning(f"No file tools found for organization {actor.organization_id}. Expected tools: {missing_tool_names}")
|
||||
return agent_state
|
||||
|
||||
# convert to pydantic tools
|
||||
found_tools = [tool.to_pydantic() for tool in found_tool_models]
|
||||
found_tool_names = {tool.name for tool in found_tools}
|
||||
|
||||
# log if any expected tools weren't found
|
||||
still_missing = missing_tool_names - found_tool_names
|
||||
if still_missing:
|
||||
logger.warning(f"File tools {still_missing} not found in organization {actor.organization_id}")
|
||||
|
||||
# extract tool IDs for bulk attach
|
||||
tool_ids_to_attach = [tool.id for tool in found_tools]
|
||||
|
||||
# bulk attach all found file tools
|
||||
await self.bulk_attach_tools_async(agent_id=agent_state.id, tool_ids=tool_ids_to_attach, actor=actor)
|
||||
|
||||
# create a shallow copy with updated tools list to avoid modifying input
|
||||
agent_state_dict = agent_state.model_dump()
|
||||
agent_state_dict["tools"] = agent_state.tools + found_tools
|
||||
|
||||
return PydanticAgentState(**agent_state_dict)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -2657,25 +2776,30 @@ class AgentManager:
|
||||
Detach all core file tools from an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to detach the tools from.
|
||||
agent_state: The current agent state with tools already loaded.
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the agent or tool is not found.
|
||||
NoResultFound: If the agent is not found.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state.
|
||||
"""
|
||||
# Check if the agent is missing any files tools
|
||||
core_tool_names = {tool.name for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
|
||||
# extract file tool IDs directly from agent_state.tools
|
||||
file_tool_ids = [tool.id for tool in agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE]
|
||||
|
||||
for tool_name in core_tool_names:
|
||||
tool_id = await self.tool_manager.get_tool_id_by_name_async(tool_name=tool_name, actor=actor)
|
||||
if not file_tool_ids:
|
||||
# no file tools to detach
|
||||
return agent_state
|
||||
|
||||
# TODO: Inefficient - I think this re-retrieves the agent_state?
|
||||
agent_state = await self.detach_tool_async(agent_id=agent_state.id, tool_id=tool_id, actor=actor)
|
||||
# bulk detach all file tools in one operation
|
||||
await self.bulk_detach_tools_async(agent_id=agent_state.id, tool_ids=file_tool_ids, actor=actor)
|
||||
|
||||
return agent_state
|
||||
# create a shallow copy with updated tools list to avoid modifying input
|
||||
agent_state_dict = agent_state.model_dump()
|
||||
agent_state_dict["tools"] = [tool for tool in agent_state.tools if tool.tool_type != ToolType.LETTA_FILES_CORE]
|
||||
|
||||
return PydanticAgentState(**agent_state_dict)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -2713,7 +2837,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Detaches a tool from an agent.
|
||||
|
||||
@@ -2723,27 +2847,58 @@ class AgentManager:
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the agent or tool is not found.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent state.
|
||||
NoResultFound: If the agent is not found.
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify the agent exists and user has permission to access it
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
await validate_agent_exists_async(session, agent_id, actor)
|
||||
|
||||
# Filter out the tool to be detached
|
||||
remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]
|
||||
# Delete the association directly - if it doesn't exist, rowcount will be 0
|
||||
delete_query = delete(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id == tool_id)
|
||||
result = await session.execute(delete_query)
|
||||
|
||||
if len(remaining_tools) == len(agent.tools): # Tool ID was not in the relationship
|
||||
if result.rowcount == 0:
|
||||
logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")
|
||||
else:
|
||||
logger.debug(f"Detached tool id={tool_id} from agent id={agent_id}")
|
||||
|
||||
# Update the tools relationship
|
||||
agent.tools = remaining_tools
|
||||
await session.commit()
|
||||
|
||||
# Commit and refresh the agent
|
||||
await agent.update_async(session, actor=actor)
|
||||
return await agent.to_pydantic_async()
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
|
||||
"""
|
||||
Efficiently detaches multiple tools from an agent in a single operation.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to detach the tools from.
|
||||
tool_ids: List of tool IDs to detach.
|
||||
actor: User performing the action.
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the agent is not found.
|
||||
"""
|
||||
if not tool_ids:
|
||||
# no tools to detach, nothing to do
|
||||
return
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify the agent exists and user has permission to access it
|
||||
await validate_agent_exists_async(session, agent_id, actor)
|
||||
|
||||
# Delete all associations in a single query
|
||||
delete_query = delete(ToolsAgents).where(ToolsAgents.agent_id == agent_id, ToolsAgents.tool_id.in_(tool_ids))
|
||||
result = await session.execute(delete_query)
|
||||
|
||||
detached_count = result.rowcount
|
||||
if detached_count == 0:
|
||||
logger.warning(f"No tools from list {tool_ids} were attached to agent id={agent_id}")
|
||||
elif detached_count < len(tool_ids):
|
||||
logger.info(f"Detached {detached_count} tools from agent {agent_id} ({len(tool_ids) - detached_count} were not attached)")
|
||||
else:
|
||||
logger.info(f"Detached all {detached_count} tools from agent {agent_id}")
|
||||
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -3,6 +3,8 @@ import os
|
||||
import warnings
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from letta.constants import (
|
||||
BASE_FUNCTION_RETURN_CHAR_LIMIT,
|
||||
BASE_MEMORY_TOOLS,
|
||||
@@ -290,6 +292,16 @@ class ToolManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def tool_exists_async(self, tool_id: str, actor: PydanticUser) -> bool:
|
||||
"""Check if a tool exists and belongs to the user's organization (lightweight check)."""
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(func.count(ToolModel.id)).where(ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(query)
|
||||
count = result.scalar()
|
||||
return count > 0
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_tools_async(
|
||||
|
||||
@@ -1406,14 +1406,14 @@ async def test_list_agents_ordering_and_pagination(server: SyncServer, default_u
|
||||
async def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
|
||||
"""Test attaching a tool to an agent."""
|
||||
# Attach the tool
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Verify attachment through get_agent_by_id
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id in [t.id for t in agent.tools]
|
||||
|
||||
# Verify that attaching the same tool again doesn't cause duplication
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len([t for t in agent.tools if t.id == print_tool.id]) == 1
|
||||
|
||||
@@ -1422,39 +1422,125 @@ async def test_attach_tool(server: SyncServer, sarah_agent, print_tool, default_
|
||||
async def test_detach_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
|
||||
"""Test detaching a tool from an agent."""
|
||||
# Attach the tool first
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Verify it's attached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id in [t.id for t in agent.tools]
|
||||
|
||||
# Detach the tool
|
||||
server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.detach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Verify it's detached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id not in [t.id for t in agent.tools]
|
||||
|
||||
# Verify that detaching an already detached tool doesn't cause issues
|
||||
server.agent_manager.detach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.detach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_detach_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
|
||||
"""Test bulk detaching multiple tools from an agent."""
|
||||
# First attach both tools
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Verify both tools are attached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id in [t.id for t in agent.tools]
|
||||
assert other_tool.id in [t.id for t in agent.tools]
|
||||
|
||||
# Bulk detach both tools
|
||||
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Verify both tools are detached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id not in [t.id for t in agent.tools]
|
||||
assert other_tool.id not in [t.id for t in agent.tools]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_detach_tools_partial(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
|
||||
"""Test bulk detaching tools when some are not attached."""
|
||||
# Only attach one tool
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Try to bulk detach both tools (one attached, one not)
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Verify the attached tool was detached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id not in [t.id for t in agent.tools]
|
||||
assert other_tool.id not in [t.id for t in agent.tools]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_detach_tools_empty_list(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
|
||||
"""Test bulk detaching empty list of tools."""
|
||||
# Attach a tool first
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Bulk detach empty list
|
||||
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=[], actor=default_user)
|
||||
|
||||
# Verify the tool is still attached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert print_tool.id in [t.id for t in agent.tools]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_detach_tools_idempotent(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
|
||||
"""Test that bulk detach is idempotent."""
|
||||
# Attach both tools
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Bulk detach once
|
||||
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Verify tools are detached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(agent.tools) == 0
|
||||
|
||||
# Bulk detach again (should be no-op, no errors)
|
||||
await server.agent_manager.bulk_detach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Verify still no tools
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(agent.tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user, event_loop):
|
||||
"""Test bulk detaching tools from a nonexistent agent."""
|
||||
nonexistent_agent_id = "nonexistent-agent-id"
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
await server.agent_manager.bulk_detach_tools_async(agent_id=nonexistent_agent_id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
|
||||
"""Test attaching a tool to a nonexistent agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.attach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
|
||||
def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test attaching a nonexistent tool to an agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user)
|
||||
|
||||
|
||||
def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
|
||||
"""Test detaching a tool from a nonexistent agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.detach_tool(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.detach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1465,8 +1551,8 @@ async def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool,
|
||||
assert len(agent.tools) == 0
|
||||
|
||||
# Attach tools
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
server.agent_manager.attach_tool(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=other_tool.id, actor=default_user)
|
||||
|
||||
# List tools and verify
|
||||
agent = await server.agent_manager.get_agent_by_id_async(sarah_agent.id, actor=default_user)
|
||||
@@ -1476,6 +1562,251 @@ async def test_list_attached_tools(server: SyncServer, sarah_agent, print_tool,
|
||||
assert other_tool.id in attached_tool_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_attach_tools(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
|
||||
"""Test bulk attaching multiple tools to an agent."""
|
||||
# Bulk attach both tools
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Verify both tools are attached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
attached_tool_ids = [t.id for t in agent.tools]
|
||||
assert print_tool.id in attached_tool_ids
|
||||
assert other_tool.id in attached_tool_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_attach_tools_with_duplicates(server: SyncServer, sarah_agent, print_tool, other_tool, default_user, event_loop):
|
||||
"""Test bulk attaching tools handles duplicates correctly."""
|
||||
# First attach one tool
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Bulk attach both tools (one already attached)
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Verify both tools are attached and no duplicates
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
attached_tool_ids = [t.id for t in agent.tools]
|
||||
assert len(attached_tool_ids) == 2
|
||||
assert print_tool.id in attached_tool_ids
|
||||
assert other_tool.id in attached_tool_ids
|
||||
# Ensure no duplicates
|
||||
assert len(set(attached_tool_ids)) == len(attached_tool_ids)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_attach_tools_empty_list(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""Test bulk attaching empty list of tools."""
|
||||
# Bulk attach empty list
|
||||
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=[], actor=default_user)
|
||||
|
||||
# Verify no tools are attached
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(agent.tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_attach_tools_nonexistent_tool(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
|
||||
"""Test bulk attaching tools with a nonexistent tool ID."""
|
||||
# Try to bulk attach with one valid and one invalid tool ID
|
||||
nonexistent_id = "nonexistent-tool-id"
|
||||
tool_ids = [print_tool.id, nonexistent_id]
|
||||
|
||||
with pytest.raises(NoResultFound) as exc_info:
|
||||
await server.agent_manager.bulk_attach_tools_async(agent_id=sarah_agent.id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
# Verify error message contains the missing tool ID
|
||||
assert nonexistent_id in str(exc_info.value)
|
||||
|
||||
# Verify no tools were attached (transaction should have rolled back)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(agent.tools) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_attach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user, event_loop):
|
||||
"""Test bulk attaching tools to a nonexistent agent."""
|
||||
nonexistent_agent_id = "nonexistent-agent-id"
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
await server.agent_manager.bulk_attach_tools_async(agent_id=nonexistent_agent_id, tool_ids=tool_ids, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_missing_files_tools_async(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""Test attaching missing file tools to an agent."""
|
||||
# First ensure file tools exist in the system
|
||||
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
|
||||
|
||||
# Get initial agent state (should have no file tools)
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
initial_tool_count = len(agent_state.tools)
|
||||
|
||||
# Attach missing file tools
|
||||
updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify all file tools are now attached
|
||||
file_tool_names = {tool.name for tool in updated_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
|
||||
assert file_tool_names == set(FILES_TOOLS)
|
||||
|
||||
# Verify the total tool count increased by the number of file tools
|
||||
assert len(updated_agent_state.tools) == initial_tool_count + len(FILES_TOOLS)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_missing_files_tools_async_partial(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""Test attaching missing file tools when some are already attached."""
|
||||
# First ensure file tools exist in the system
|
||||
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
|
||||
|
||||
# Get file tool IDs
|
||||
all_tools = await server.tool_manager.list_tools_async(actor=default_user)
|
||||
file_tools = [tool for tool in all_tools if tool.tool_type == ToolType.LETTA_FILES_CORE and tool.name in FILES_TOOLS]
|
||||
|
||||
# Manually attach just the first file tool
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=file_tools[0].id, actor=default_user)
|
||||
|
||||
# Get agent state with one file tool already attached
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 1
|
||||
|
||||
# Attach missing file tools
|
||||
updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify all file tools are now attached
|
||||
file_tool_names = {tool.name for tool in updated_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
|
||||
assert file_tool_names == set(FILES_TOOLS)
|
||||
|
||||
# Verify no duplicates
|
||||
all_tool_ids = [tool.id for tool in updated_agent_state.tools]
|
||||
assert len(all_tool_ids) == len(set(all_tool_ids))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_missing_files_tools_async_idempotent(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""Test that attach_missing_files_tools is idempotent."""
|
||||
# First ensure file tools exist in the system
|
||||
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
|
||||
|
||||
# Get initial agent state
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
|
||||
# Attach missing file tools the first time
|
||||
updated_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
first_tool_count = len(updated_agent_state.tools)
|
||||
|
||||
# Call attach_missing_files_tools again (should be no-op)
|
||||
final_agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=updated_agent_state, actor=default_user)
|
||||
|
||||
# Verify tool count didn't change
|
||||
assert len(final_agent_state.tools) == first_tool_count
|
||||
|
||||
# Verify still have all file tools
|
||||
file_tool_names = {tool.name for tool in final_agent_state.tools if tool.tool_type == ToolType.LETTA_FILES_CORE}
|
||||
assert file_tool_names == set(FILES_TOOLS)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detach_all_files_tools_async(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""Test detaching all file tools from an agent."""
|
||||
# First ensure file tools exist and attach them
|
||||
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
|
||||
|
||||
# Get initial agent state and attach file tools
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify file tools are attached
|
||||
file_tool_count_before = len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE])
|
||||
assert file_tool_count_before == len(FILES_TOOLS)
|
||||
|
||||
# Detach all file tools
|
||||
updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify all file tools are detached
|
||||
file_tool_count_after = len([t for t in updated_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE])
|
||||
assert file_tool_count_after == 0
|
||||
|
||||
# Verify the returned state was modified in-place (no DB reload)
|
||||
assert updated_agent_state.id == agent_state.id
|
||||
assert len(updated_agent_state.tools) == len(agent_state.tools) - file_tool_count_before
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detach_all_files_tools_async_empty(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""Test detaching all file tools when no file tools are attached."""
|
||||
# Get agent state (should have no file tools initially)
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
initial_tool_count = len(agent_state.tools)
|
||||
|
||||
# Verify no file tools attached
|
||||
file_tool_count = len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE])
|
||||
assert file_tool_count == 0
|
||||
|
||||
# Call detach_all_files_tools (should be no-op)
|
||||
updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify nothing changed
|
||||
assert len(updated_agent_state.tools) == initial_tool_count
|
||||
assert updated_agent_state == agent_state # Should be the same object since no changes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detach_all_files_tools_async_with_other_tools(server: SyncServer, sarah_agent, print_tool, default_user, event_loop):
|
||||
"""Test detaching all file tools preserves non-file tools."""
|
||||
# First ensure file tools exist
|
||||
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
|
||||
|
||||
# Attach a non-file tool
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
# Get agent state and attach file tools
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify both file tools and print tool are attached
|
||||
file_tools = [t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]
|
||||
assert len(file_tools) == len(FILES_TOOLS)
|
||||
assert print_tool.id in [t.id for t in agent_state.tools]
|
||||
|
||||
# Detach all file tools
|
||||
updated_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify only file tools were removed, print tool remains
|
||||
remaining_file_tools = [t for t in updated_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]
|
||||
assert len(remaining_file_tools) == 0
|
||||
assert print_tool.id in [t.id for t in updated_agent_state.tools]
|
||||
assert len(updated_agent_state.tools) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detach_all_files_tools_async_idempotent(server: SyncServer, sarah_agent, default_user, event_loop):
|
||||
"""Test that detach_all_files_tools is idempotent."""
|
||||
# First ensure file tools exist and attach them
|
||||
await server.tool_manager.upsert_base_tools_async(actor=default_user, allowed_types={ToolType.LETTA_FILES_CORE})
|
||||
|
||||
# Get initial agent state and attach file tools
|
||||
agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
agent_state = await server.agent_manager.attach_missing_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Detach all file tools once
|
||||
agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify no file tools
|
||||
assert len([t for t in agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 0
|
||||
tool_count_after_first = len(agent_state.tools)
|
||||
|
||||
# Detach all file tools again (should be no-op)
|
||||
final_agent_state = await server.agent_manager.detach_all_files_tools_async(agent_state=agent_state, actor=default_user)
|
||||
|
||||
# Verify still no file tools and same tool count
|
||||
assert len([t for t in final_agent_state.tools if t.tool_type == ToolType.LETTA_FILES_CORE]) == 0
|
||||
assert len(final_agent_state.tools) == tool_count_after_first
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Sources Relationship
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -686,9 +686,13 @@ def test_include_return_message_types(client: LettaSDKClient, agent: AgentState,
|
||||
include_return_message_types=message_types,
|
||||
)
|
||||
# wait to finish
|
||||
while response.status != "completed":
|
||||
while response.status not in {"failed", "completed", "cancelled", "expired"}:
|
||||
time.sleep(1)
|
||||
response = client.runs.retrieve(run_id=response.id)
|
||||
|
||||
if response.status != "completed":
|
||||
pytest.fail(f"Response status was NOT completed: {response}")
|
||||
|
||||
messages = client.runs.messages.list(run_id=response.id)
|
||||
verify_message_types(messages, message_types)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user