From b530527a919f74e6479463adf8f74a623f404711 Mon Sep 17 00:00:00 2001 From: cthomas Date: Sun, 20 Jul 2025 22:11:46 -0700 Subject: [PATCH] feat: only load specified relationships for agents (#3438) --- letta/services/agent_manager.py | 47 +++++++++++++++---- .../services/helpers/agent_manager_helper.py | 19 ++++++++ 2 files changed, 56 insertions(+), 10 deletions(-) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index ff7abfba..e25d2203 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -71,6 +71,7 @@ from letta.services.helpers.agent_manager_helper import ( _apply_identity_filters, _apply_pagination, _apply_pagination_async, + _apply_relationship_filters, _apply_tag_filter, _process_relationship, _process_relationship_async, @@ -1005,7 +1006,8 @@ class AgentManager: identity_id (Optional[str]): Filter by identifier ID. identifier_keys (Optional[List[str]]): Search agents by identifier keys. include_relationships (Optional[List[str]]): List of fields to load for performance optimization. - ascending + ascending (bool): Sort agents in ascending order. + sort_by (Optional[str]): Sort agents by this field. Returns: List[PydanticAgentState]: The filtered list of matching agents. @@ -1018,6 +1020,7 @@ class AgentManager: query = _apply_filters(query, name, query_text, project_id, template_id, base_template_id) query = _apply_identity_filters(query, identity_id, identifier_keys) query = _apply_tag_filter(query, tags, match_all_tags) + query = _apply_relationship_filters(query, include_relationships) query = await _apply_pagination_async(query, before, after, session, ascending=ascending, sort_by=sort_by) if limit: @@ -1153,10 +1156,23 @@ class AgentManager: include_relationships: Optional[List[str]] = None, ) -> PydanticAgentState: """Fetch an agent by its ID.""" - async with db_registry.async_session() as session: - agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) - return await agent.to_pydantic_async(include_relationships=include_relationships) + try: + query = select(AgentModel) + query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) + query = query.where(AgentModel.id == agent_id) + query = _apply_relationship_filters(query, include_relationships) + + result = await session.execute(query) + agent = result.scalar_one_or_none() + + if agent is None: + raise NoResultFound(f"Agent with ID {agent_id} not found") + + return await agent.to_pydantic_async(include_relationships=include_relationships) + except Exception as e: + logger.error(f"Error fetching agent {agent_id}: {str(e)}") + raise @enforce_types @trace_method @@ -1168,12 +1184,23 @@ class AgentManager: ) -> list[PydanticAgentState]: """Fetch a list of agents by their IDs.""" async with db_registry.async_session() as session: - agents = await AgentModel.read_multiple_async( - db_session=session, - identifiers=agent_ids, - actor=actor, - ) - return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) + try: + query = select(AgentModel) + query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) + query = query.where(AgentModel.id.in_(agent_ids)) + query = _apply_relationship_filters(query, include_relationships) + + result = await session.execute(query) + agents = result.scalars().all() + + if not agents: + logger.warning(f"No agents found with IDs: {agent_ids}") + return [] + + return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) + except Exception as e: + logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}") + raise @enforce_types @trace_method diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 8c70f725..7c46219d 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Set import numpy as np from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all +from sqlalchemy.orm import noload from sqlalchemy.sql.expression import exists from letta import system @@ -669,6 +670,24 @@ def _apply_filters( return query +def _apply_relationship_filters(query, include_relationships: Optional[List[str]] = None): + if include_relationships is None: + return query + + if "memory" not in include_relationships: + query = query.options(noload(AgentModel.core_memory), noload(AgentModel.file_agents)) + if "identity_ids" not in include_relationships: + query = query.options(noload(AgentModel.identities)) + + relationships = ["tool_exec_environment_variables", "tools", "sources", "tags", "multi_agent_group"] + + for rel in relationships: + if rel not in include_relationships: + query = query.options(noload(getattr(AgentModel, rel))) + + return query + + def build_passage_query( actor: User, agent_id: Optional[str] = None,