From c6ff04b4dabc7b061fe75e2e9ac322d7a1880ebe Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 15 Jul 2025 17:15:33 -0700 Subject: [PATCH] feat: Integrate performant `validate_agent_exists_async` (#3350) --- letta/services/agent_manager.py | 33 +++---------------- .../services/helpers/agent_manager_helper.py | 22 +++++++++++++ letta/services/message_manager.py | 5 +-- tests/test_managers.py | 18 +++++++++- 4 files changed, 46 insertions(+), 32 deletions(-) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 005a78fc..333af4a7 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -84,6 +84,7 @@ from letta.services.helpers.agent_manager_helper import ( derive_system_message, initialize_message_sequence, package_initial_message_sequence, + validate_agent_exists_async, ) from letta.services.identity_manager import IdentityManager from letta.services.message_manager import MessageManager @@ -107,32 +108,6 @@ class AgentManager: self.identity_manager = IdentityManager() self.file_agent_manager = FileAgentManager() - @trace_method - async def _validate_agent_exists_async(self, session, agent_id: str, actor: PydanticUser) -> None: - """ - Validate that an agent exists and user has access to it using raw SQL for efficiency. - - Args: - session: Database session - agent_id: ID of the agent to validate - actor: User performing the action - - Raises: - NoResultFound: If agent doesn't exist or user doesn't have access - """ - agent_check_query = sa.text( - """ - SELECT 1 FROM agents - WHERE id = :agent_id - AND organization_id = :org_id - AND is_deleted = false - """ - ) - agent_exists = await session.execute(agent_check_query, {"agent_id": agent_id, "org_id": actor.organization_id}) - - if not agent_exists.fetchone(): - raise NoResultFound(f"Agent with ID {agent_id} not found") - @staticmethod def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]: """ @@ -1907,7 +1882,7 @@ class AgentManager: """ async with db_registry.async_session() as session: # Validate agent exists and user has access - await self._validate_agent_exists_async(session, agent_id, actor) + await validate_agent_exists_async(session, agent_id, actor) # Use raw SQL to efficiently fetch sources - much faster than lazy loading # Fast query without relationship loading @@ -1943,7 +1918,7 @@ class AgentManager: """ async with db_registry.async_session() as session: # Validate agent exists and user has access - await self._validate_agent_exists_async(session, agent_id, actor) + await validate_agent_exists_async(session, agent_id, actor) # Check if the source is actually attached to this agent using junction table attachment_check_query = select(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id) @@ -2710,7 +2685,7 @@ class AgentManager: """ async with db_registry.async_session() as session: # lightweight check for agent access - await self._validate_agent_exists_async(session, agent_id, actor) + await validate_agent_exists_async(session, agent_id, actor) # direct query for tools via join - much more performant query = ( diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 4e262de0..09178419 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -1070,3 +1070,25 @@ def calculate_multi_agent_tools() -> Set[str]: return set(MULTI_AGENT_TOOLS) - set(LOCAL_ONLY_MULTI_AGENT_TOOLS) else: return set(MULTI_AGENT_TOOLS) + + +@trace_method +async def validate_agent_exists_async(session, agent_id: str, actor: User) -> None: + """ + Validate that an agent exists and user has access to it using raw SQL for efficiency. + + Args: + session: Database session + agent_id: ID of the agent to validate + actor: User performing the action + + Raises: + NoResultFound: If agent doesn't exist or user doesn't have access + """ + agent_exists_query = select( + exists().where(and_(AgentModel.id == agent_id, AgentModel.organization_id == actor.organization_id, AgentModel.is_deleted == False)) + ) + result = await session.execute(agent_exists_query) + + if not result.scalar(): + raise NoResultFound(f"Agent with ID {agent_id} not found") diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 5e1620cf..8ce15610 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -17,6 +17,7 @@ from letta.schemas.message import MessageUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.file_manager import FileManager +from letta.services.helpers.agent_manager_helper import validate_agent_exists_async from letta.utils import enforce_types logger = get_logger(__name__) @@ -541,7 +542,7 @@ class MessageManager: async with db_registry.async_session() as session: # Permission check: raise if the agent doesn't exist or actor is not allowed. - await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + await validate_agent_exists_async(session, agent_id, actor) # Build a query that directly filters the Message table by agent_id. query = select(MessageModel).where(MessageModel.agent_id == agent_id) @@ -611,7 +612,7 @@ class MessageManager: """ async with db_registry.async_session() as session: # 1) verify the agent exists and the actor has access - await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + await validate_agent_exists_async(session, agent_id, actor) # 2) issue a CORE DELETE against the mapped class stmt = ( diff --git a/tests/test_managers.py b/tests/test_managers.py index f8833f8f..8585c409 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -85,7 +85,7 @@ from letta.schemas.user import UserUpdate from letta.server.db import db_registry from letta.server.server import SyncServer from letta.services.block_manager import BlockManager -from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools +from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools, validate_agent_exists_async from letta.services.step_manager import FeedbackType from letta.settings import tool_settings from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview @@ -693,6 +693,22 @@ async def another_file(server, default_source, default_user, default_organizatio # ====================================================================================================================== # AgentManager Tests - Basic # ====================================================================================================================== +@pytest.mark.asyncio +async def test_validate_agent_exists_async(server: SyncServer, comprehensive_test_agent_fixture, default_user): + """Test the validate_agent_exists_async helper function""" + created_agent, _ = comprehensive_test_agent_fixture + + # test with valid agent + async with db_registry.async_session() as session: + # should not raise exception + await validate_agent_exists_async(session, created_agent.id, default_user) + + # test with non-existent agent + async with db_registry.async_session() as session: + with pytest.raises(NoResultFound): + await validate_agent_exists_async(session, "non-existent-id", default_user) + + @pytest.mark.asyncio async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop): # Test agent creation