feat: only load specified relationships for agents (#3438)

This commit is contained in:
cthomas
2025-07-20 22:11:46 -07:00
committed by GitHub
parent 873791659c
commit b530527a91
2 changed files with 56 additions and 10 deletions

View File

@@ -71,6 +71,7 @@ from letta.services.helpers.agent_manager_helper import (
_apply_identity_filters,
_apply_pagination,
_apply_pagination_async,
_apply_relationship_filters,
_apply_tag_filter,
_process_relationship,
_process_relationship_async,
@@ -1005,7 +1006,8 @@ class AgentManager:
identity_id (Optional[str]): Filter by identifier ID.
identifier_keys (Optional[List[str]]): Search agents by identifier keys.
include_relationships (Optional[List[str]]): List of fields to load for performance optimization.
ascending
ascending (bool): Sort agents in ascending order.
sort_by (Optional[str]): Sort agents by this field.
Returns:
List[PydanticAgentState]: The filtered list of matching agents.
@@ -1018,6 +1020,7 @@ class AgentManager:
query = _apply_filters(query, name, query_text, project_id, template_id, base_template_id)
query = _apply_identity_filters(query, identity_id, identifier_keys)
query = _apply_tag_filter(query, tags, match_all_tags)
query = _apply_relationship_filters(query, include_relationships)
query = await _apply_pagination_async(query, before, after, session, ascending=ascending, sort_by=sort_by)
if limit:
@@ -1153,10 +1156,23 @@ class AgentManager:
include_relationships: Optional[List[str]] = None,
) -> PydanticAgentState:
"""Fetch an agent by its ID."""
async with db_registry.async_session() as session:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
return await agent.to_pydantic_async(include_relationships=include_relationships)
try:
query = select(AgentModel)
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
query = query.where(AgentModel.id == agent_id)
query = _apply_relationship_filters(query, include_relationships)
result = await session.execute(query)
agent = result.scalar_one_or_none()
if agent is None:
raise NoResultFound(f"Agent with ID {agent_id} not found")
return await agent.to_pydantic_async(include_relationships=include_relationships)
except Exception as e:
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
raise
@enforce_types
@trace_method
@@ -1168,12 +1184,23 @@ class AgentManager:
) -> list[PydanticAgentState]:
"""Fetch a list of agents by their IDs."""
async with db_registry.async_session() as session:
agents = await AgentModel.read_multiple_async(
db_session=session,
identifiers=agent_ids,
actor=actor,
)
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
try:
query = select(AgentModel)
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
query = query.where(AgentModel.id.in_(agent_ids))
query = _apply_relationship_filters(query, include_relationships)
result = await session.execute(query)
agents = result.scalars().all()
if not agents:
logger.warning(f"No agents found with IDs: {agent_ids}")
return []
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
except Exception as e:
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
raise
@enforce_types
@trace_method

View File

@@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Set
import numpy as np
from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all
from sqlalchemy.orm import noload
from sqlalchemy.sql.expression import exists
from letta import system
@@ -669,6 +670,24 @@ def _apply_filters(
return query
def _apply_relationship_filters(query, include_relationships: Optional[List[str]] = None):
if include_relationships is None:
return query
if "memory" not in include_relationships:
query = query.options(noload(AgentModel.core_memory), noload(AgentModel.file_agents))
if "identity_ids" not in include_relationships:
query = query.options(noload(AgentModel.identities))
relationships = ["tool_exec_environment_variables", "tools", "sources", "tags", "multi_agent_group"]
for rel in relationships:
if rel not in include_relationships:
query = query.options(noload(getattr(AgentModel, rel)))
return query
def build_passage_query(
actor: User,
agent_id: Optional[str] = None,