feat: Integrate performant validate_agent_exists_async (#3350)

This commit is contained in:
Matthew Zhou
2025-07-15 17:15:33 -07:00
committed by GitHub
parent 5bdb18d55f
commit c6ff04b4da
4 changed files with 46 additions and 32 deletions

View File

@@ -84,6 +84,7 @@ from letta.services.helpers.agent_manager_helper import (
derive_system_message,
initialize_message_sequence,
package_initial_message_sequence,
validate_agent_exists_async,
)
from letta.services.identity_manager import IdentityManager
from letta.services.message_manager import MessageManager
@@ -107,32 +108,6 @@ class AgentManager:
self.identity_manager = IdentityManager()
self.file_agent_manager = FileAgentManager()
@trace_method
async def _validate_agent_exists_async(self, session, agent_id: str, actor: PydanticUser) -> None:
"""
Validate that an agent exists and user has access to it using raw SQL for efficiency.
Args:
session: Database session
agent_id: ID of the agent to validate
actor: User performing the action
Raises:
NoResultFound: If agent doesn't exist or user doesn't have access
"""
agent_check_query = sa.text(
"""
SELECT 1 FROM agents
WHERE id = :agent_id
AND organization_id = :org_id
AND is_deleted = false
"""
)
agent_exists = await session.execute(agent_check_query, {"agent_id": agent_id, "org_id": actor.organization_id})
if not agent_exists.fetchone():
raise NoResultFound(f"Agent with ID {agent_id} not found")
@staticmethod
def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
"""
@@ -1907,7 +1882,7 @@ class AgentManager:
"""
async with db_registry.async_session() as session:
# Validate agent exists and user has access
await self._validate_agent_exists_async(session, agent_id, actor)
await validate_agent_exists_async(session, agent_id, actor)
# Use raw SQL to efficiently fetch sources - much faster than lazy loading
# Fast query without relationship loading
@@ -1943,7 +1918,7 @@ class AgentManager:
"""
async with db_registry.async_session() as session:
# Validate agent exists and user has access
await self._validate_agent_exists_async(session, agent_id, actor)
await validate_agent_exists_async(session, agent_id, actor)
# Check if the source is actually attached to this agent using junction table
attachment_check_query = select(SourcesAgents).where(SourcesAgents.agent_id == agent_id, SourcesAgents.source_id == source_id)
@@ -2710,7 +2685,7 @@ class AgentManager:
"""
async with db_registry.async_session() as session:
# lightweight check for agent access
await self._validate_agent_exists_async(session, agent_id, actor)
await validate_agent_exists_async(session, agent_id, actor)
# direct query for tools via join - much more performant
query = (

View File

@@ -1070,3 +1070,25 @@ def calculate_multi_agent_tools() -> Set[str]:
return set(MULTI_AGENT_TOOLS) - set(LOCAL_ONLY_MULTI_AGENT_TOOLS)
else:
return set(MULTI_AGENT_TOOLS)
@trace_method
async def validate_agent_exists_async(session, agent_id: str, actor: User) -> None:
"""
Validate that an agent exists and user has access to it using raw SQL for efficiency.
Args:
session: Database session
agent_id: ID of the agent to validate
actor: User performing the action
Raises:
NoResultFound: If agent doesn't exist or user doesn't have access
"""
agent_exists_query = select(
exists().where(and_(AgentModel.id == agent_id, AgentModel.organization_id == actor.organization_id, AgentModel.is_deleted == False))
)
result = await session.execute(agent_exists_query)
if not result.scalar():
raise NoResultFound(f"Agent with ID {agent_id} not found")

View File

@@ -17,6 +17,7 @@ from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.file_manager import FileManager
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.utils import enforce_types
logger = get_logger(__name__)
@@ -541,7 +542,7 @@ class MessageManager:
async with db_registry.async_session() as session:
# Permission check: raise if the agent doesn't exist or actor is not allowed.
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
await validate_agent_exists_async(session, agent_id, actor)
# Build a query that directly filters the Message table by agent_id.
query = select(MessageModel).where(MessageModel.agent_id == agent_id)
@@ -611,7 +612,7 @@ class MessageManager:
"""
async with db_registry.async_session() as session:
# 1) verify the agent exists and the actor has access
await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
await validate_agent_exists_async(session, agent_id, actor)
# 2) issue a CORE DELETE against the mapped class
stmt = (

View File

@@ -85,7 +85,7 @@ from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools
from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools, validate_agent_exists_async
from letta.services.step_manager import FeedbackType
from letta.settings import tool_settings
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
@@ -693,6 +693,22 @@ async def another_file(server, default_source, default_user, default_organizatio
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@pytest.mark.asyncio
async def test_validate_agent_exists_async(server: SyncServer, comprehensive_test_agent_fixture, default_user):
"""Test the validate_agent_exists_async helper function"""
created_agent, _ = comprehensive_test_agent_fixture
# test with valid agent
async with db_registry.async_session() as session:
# should not raise exception
await validate_agent_exists_async(session, created_agent.id, default_user)
# test with non-existent agent
async with db_registry.async_session() as session:
with pytest.raises(NoResultFound):
await validate_agent_exists_async(session, "non-existent-id", default_user)
@pytest.mark.asyncio
async def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixture, default_user, event_loop):
# Test agent creation