feat: Add get agents for source_id endpoint (#3381)

This commit is contained in:
Matthew Zhou
2025-07-17 11:47:41 -07:00
committed by GitHub
parent 181006b3ea
commit f4854e95cd
2 changed files with 67 additions and 5 deletions

View File

@@ -324,6 +324,19 @@ async def upload_file_to_source(
return file_metadata
@router.get("/{source_id}/agents", response_model=List[str], operation_id="get_agents_for_source")
async def get_agents_for_source(
source_id: str,
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Get all agent IDs that have the specified source attached.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.source_manager.get_agents_for_source_id(source_id=source_id, actor=actor)
@router.get("/{source_id}/passages", response_model=List[Passage], operation_id="list_source_passages")
async def list_source_passages(
source_id: str,

View File

@@ -1,7 +1,7 @@
import asyncio
from typing import List, Optional
from sqlalchemy import select
from sqlalchemy import and_, exists, select
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
@@ -19,6 +19,30 @@ from letta.utils import enforce_types, printd
class SourceManager:
"""Manager class to handle business logic related to Sources."""
@trace_method
async def _validate_source_exists_async(self, session, source_id: str, actor: PydanticUser) -> None:
"""
Validate that a source exists and user has access to it using raw SQL for efficiency.
Args:
session: Database session
source_id: ID of the source to validate
actor: User performing the action
Raises:
NoResultFound: If source doesn't exist or user doesn't have access
"""
source_exists_query = select(
exists().where(
and_(SourceModel.id == source_id, SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False)
)
)
result = await session.execute(source_exists_query)
if not result.scalar():
raise NoResultFound(f"Source with ID {source_id} not found")
@enforce_types
@trace_method
async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
@@ -93,20 +117,20 @@ class SourceManager:
@enforce_types
@trace_method
async def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
async def list_attached_agents(self, source_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
"""
Lists all agents that have the specified source attached.
Args:
source_id: ID of the source to find attached agents for
actor: User performing the action (optional for now, following existing pattern)
actor: User performing the action
Returns:
List[PydanticAgentState]: List of agents that have this source attached
"""
async with db_registry.async_session() as session:
# Verify source exists and user has permission to access it
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
await self._validate_source_exists_async(session, source_id, actor)
# Use junction table query instead of relationship to avoid performance issues
query = (
@@ -114,7 +138,7 @@ class SourceManager:
.join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
.where(
SourcesAgents.source_id == source_id,
AgentModel.organization_id == actor.organization_id if actor else True,
AgentModel.organization_id == actor.organization_id,
AgentModel.is_deleted == False,
)
.order_by(AgentModel.created_at.desc(), AgentModel.id)
@@ -125,6 +149,31 @@ class SourceManager:
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
@enforce_types
@trace_method
async def get_agents_for_source_id(self, source_id: str, actor: PydanticUser) -> List[str]:
"""
Get all agent IDs associated with a given source ID.
Args:
source_id: ID of the source to find agents for
actor: User performing the action
Returns:
List[str]: List of agent IDs that have this source attached
"""
async with db_registry.async_session() as session:
# Verify source exists and user has permission to access it
await self._validate_source_exists_async(session, source_id, actor)
# Query the junction table directly for performance
query = select(SourcesAgents.agent_id).where(SourcesAgents.source_id == source_id)
result = await session.execute(query)
agent_ids = result.scalars().all()
return list(agent_ids)
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@trace_method