diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 082fe6e1..fb41bf80 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -324,6 +324,19 @@ async def upload_file_to_source( return file_metadata +@router.get("/{source_id}/agents", response_model=List[str], operation_id="get_agents_for_source") +async def get_agents_for_source( + source_id: str, + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Get all agent IDs that have the specified source attached. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.source_manager.get_agents_for_source_id(source_id=source_id, actor=actor) + + @router.get("/{source_id}/passages", response_model=List[Passage], operation_id="list_source_passages") async def list_source_passages( source_id: str, diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index b3cd2c04..02357c06 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -1,7 +1,7 @@ import asyncio from typing import List, Optional -from sqlalchemy import select +from sqlalchemy import and_, exists, select from letta.orm import Agent as AgentModel from letta.orm.errors import NoResultFound @@ -19,6 +19,30 @@ from letta.utils import enforce_types, printd class SourceManager: """Manager class to handle business logic related to Sources.""" + @trace_method + async def _validate_source_exists_async(self, session, source_id: str, actor: PydanticUser) -> None: + """ + Validate that a source exists and user has access to it using raw SQL for efficiency. + + Args: + session: Database session + source_id: ID of the source to validate + actor: User performing the action + + Raises: + NoResultFound: If source doesn't exist or user doesn't have access + """ + source_exists_query = select( + exists().where( + and_(SourceModel.id == source_id, SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False) + ) + ) + + result = await session.execute(source_exists_query) + + if not result.scalar(): + raise NoResultFound(f"Source with ID {source_id} not found") + @enforce_types @trace_method async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource: @@ -93,20 +117,20 @@ class SourceManager: @enforce_types @trace_method - async def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]: + async def list_attached_agents(self, source_id: str, actor: PydanticUser) -> List[PydanticAgentState]: """ Lists all agents that have the specified source attached. Args: source_id: ID of the source to find attached agents for - actor: User performing the action (optional for now, following existing pattern) + actor: User performing the action Returns: List[PydanticAgentState]: List of agents that have this source attached """ async with db_registry.async_session() as session: # Verify source exists and user has permission to access it - source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor) + await self._validate_source_exists_async(session, source_id, actor) # Use junction table query instead of relationship to avoid performance issues query = ( @@ -114,7 +138,7 @@ class SourceManager: .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id) .where( SourcesAgents.source_id == source_id, - AgentModel.organization_id == actor.organization_id if actor else True, + AgentModel.organization_id == actor.organization_id, AgentModel.is_deleted == False, ) .order_by(AgentModel.created_at.desc(), AgentModel.id) @@ -125,6 +149,31 @@ class SourceManager: return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm]) + @enforce_types + @trace_method + async def get_agents_for_source_id(self, source_id: str, actor: PydanticUser) -> List[str]: + """ + Get all agent IDs associated with a given source ID. + + Args: + source_id: ID of the source to find agents for + actor: User performing the action + + Returns: + List[str]: List of agent IDs that have this source attached + """ + async with db_registry.async_session() as session: + # Verify source exists and user has permission to access it + await self._validate_source_exists_async(session, source_id, actor) + + # Query the junction table directly for performance + query = select(SourcesAgents.agent_id).where(SourcesAgents.source_id == source_id) + + result = await session.execute(query) + agent_ids = result.scalars().all() + + return list(agent_ids) + # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types @trace_method