feat: only load specified relationships for agents (#3438)
This commit is contained in:
@@ -71,6 +71,7 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
_apply_identity_filters,
|
||||
_apply_pagination,
|
||||
_apply_pagination_async,
|
||||
_apply_relationship_filters,
|
||||
_apply_tag_filter,
|
||||
_process_relationship,
|
||||
_process_relationship_async,
|
||||
@@ -1005,7 +1006,8 @@ class AgentManager:
|
||||
identity_id (Optional[str]): Filter by identifier ID.
|
||||
identifier_keys (Optional[List[str]]): Search agents by identifier keys.
|
||||
include_relationships (Optional[List[str]]): List of fields to load for performance optimization.
|
||||
ascending
|
||||
ascending (bool): Sort agents in ascending order.
|
||||
sort_by (Optional[str]): Sort agents by this field.
|
||||
|
||||
Returns:
|
||||
List[PydanticAgentState]: The filtered list of matching agents.
|
||||
@@ -1018,6 +1020,7 @@ class AgentManager:
|
||||
query = _apply_filters(query, name, query_text, project_id, template_id, base_template_id)
|
||||
query = _apply_identity_filters(query, identity_id, identifier_keys)
|
||||
query = _apply_tag_filter(query, tags, match_all_tags)
|
||||
query = _apply_relationship_filters(query, include_relationships)
|
||||
query = await _apply_pagination_async(query, before, after, session, ascending=ascending, sort_by=sort_by)
|
||||
|
||||
if limit:
|
||||
@@ -1153,10 +1156,23 @@ class AgentManager:
|
||||
include_relationships: Optional[List[str]] = None,
|
||||
) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
return await agent.to_pydantic_async(include_relationships=include_relationships)
|
||||
try:
|
||||
query = select(AgentModel)
|
||||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
query = query.where(AgentModel.id == agent_id)
|
||||
query = _apply_relationship_filters(query, include_relationships)
|
||||
|
||||
result = await session.execute(query)
|
||||
agent = result.scalar_one_or_none()
|
||||
|
||||
if agent is None:
|
||||
raise NoResultFound(f"Agent with ID {agent_id} not found")
|
||||
|
||||
return await agent.to_pydantic_async(include_relationships=include_relationships)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent {agent_id}: {str(e)}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -1168,12 +1184,23 @@ class AgentManager:
|
||||
) -> list[PydanticAgentState]:
|
||||
"""Fetch a list of agents by their IDs."""
|
||||
async with db_registry.async_session() as session:
|
||||
agents = await AgentModel.read_multiple_async(
|
||||
db_session=session,
|
||||
identifiers=agent_ids,
|
||||
actor=actor,
|
||||
)
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
|
||||
try:
|
||||
query = select(AgentModel)
|
||||
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
|
||||
query = query.where(AgentModel.id.in_(agent_ids))
|
||||
query = _apply_relationship_filters(query, include_relationships)
|
||||
|
||||
result = await session.execute(query)
|
||||
agents = result.scalars().all()
|
||||
|
||||
if not agents:
|
||||
logger.warning(f"No agents found with IDs: {agent_ids}")
|
||||
return []
|
||||
|
||||
return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agents with IDs {agent_ids}: {str(e)}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import List, Literal, Optional, Set
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy import Select, and_, asc, desc, func, literal, nulls_last, or_, select, union_all
|
||||
from sqlalchemy.orm import noload
|
||||
from sqlalchemy.sql.expression import exists
|
||||
|
||||
from letta import system
|
||||
@@ -669,6 +670,24 @@ def _apply_filters(
|
||||
return query
|
||||
|
||||
|
||||
def _apply_relationship_filters(query, include_relationships: Optional[List[str]] = None):
|
||||
if include_relationships is None:
|
||||
return query
|
||||
|
||||
if "memory" not in include_relationships:
|
||||
query = query.options(noload(AgentModel.core_memory), noload(AgentModel.file_agents))
|
||||
if "identity_ids" not in include_relationships:
|
||||
query = query.options(noload(AgentModel.identities))
|
||||
|
||||
relationships = ["tool_exec_environment_variables", "tools", "sources", "tags", "multi_agent_group"]
|
||||
|
||||
for rel in relationships:
|
||||
if rel not in include_relationships:
|
||||
query = query.options(noload(getattr(AgentModel, rel)))
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def build_passage_query(
|
||||
actor: User,
|
||||
agent_id: Optional[str] = None,
|
||||
|
||||
Reference in New Issue
Block a user