feat: Add get agents for source_id endpoint (#3381)
This commit is contained in:
@@ -324,6 +324,19 @@ async def upload_file_to_source(
|
||||
return file_metadata
|
||||
|
||||
|
||||
@router.get("/{source_id}/agents", response_model=List[str], operation_id="get_agents_for_source")
|
||||
async def get_agents_for_source(
|
||||
source_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get all agent IDs that have the specified source attached.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.source_manager.get_agents_for_source_id(source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{source_id}/passages", response_model=List[Passage], operation_id="list_source_passages")
|
||||
async def list_source_passages(
|
||||
source_id: str,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import and_, exists, select
|
||||
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
@@ -19,6 +19,30 @@ from letta.utils import enforce_types, printd
|
||||
class SourceManager:
|
||||
"""Manager class to handle business logic related to Sources."""
|
||||
|
||||
@trace_method
|
||||
async def _validate_source_exists_async(self, session, source_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Validate that a source exists and user has access to it using raw SQL for efficiency.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
source_id: ID of the source to validate
|
||||
actor: User performing the action
|
||||
|
||||
Raises:
|
||||
NoResultFound: If source doesn't exist or user doesn't have access
|
||||
"""
|
||||
source_exists_query = select(
|
||||
exists().where(
|
||||
and_(SourceModel.id == source_id, SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False)
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.execute(source_exists_query)
|
||||
|
||||
if not result.scalar():
|
||||
raise NoResultFound(f"Source with ID {source_id} not found")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
|
||||
@@ -93,20 +117,20 @@ class SourceManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
|
||||
async def list_attached_agents(self, source_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
|
||||
"""
|
||||
Lists all agents that have the specified source attached.
|
||||
|
||||
Args:
|
||||
source_id: ID of the source to find attached agents for
|
||||
actor: User performing the action (optional for now, following existing pattern)
|
||||
actor: User performing the action
|
||||
|
||||
Returns:
|
||||
List[PydanticAgentState]: List of agents that have this source attached
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify source exists and user has permission to access it
|
||||
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
|
||||
await self._validate_source_exists_async(session, source_id, actor)
|
||||
|
||||
# Use junction table query instead of relationship to avoid performance issues
|
||||
query = (
|
||||
@@ -114,7 +138,7 @@ class SourceManager:
|
||||
.join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
|
||||
.where(
|
||||
SourcesAgents.source_id == source_id,
|
||||
AgentModel.organization_id == actor.organization_id if actor else True,
|
||||
AgentModel.organization_id == actor.organization_id,
|
||||
AgentModel.is_deleted == False,
|
||||
)
|
||||
.order_by(AgentModel.created_at.desc(), AgentModel.id)
|
||||
@@ -125,6 +149,31 @@ class SourceManager:
|
||||
|
||||
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_agents_for_source_id(self, source_id: str, actor: PydanticUser) -> List[str]:
|
||||
"""
|
||||
Get all agent IDs associated with a given source ID.
|
||||
|
||||
Args:
|
||||
source_id: ID of the source to find agents for
|
||||
actor: User performing the action
|
||||
|
||||
Returns:
|
||||
List[str]: List of agent IDs that have this source attached
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Verify source exists and user has permission to access it
|
||||
await self._validate_source_exists_async(session, source_id, actor)
|
||||
|
||||
# Query the junction table directly for performance
|
||||
query = select(SourcesAgents.agent_id).where(SourcesAgents.source_id == source_id)
|
||||
|
||||
result = await session.execute(query)
|
||||
agent_ids = result.scalars().all()
|
||||
|
||||
return list(agent_ids)
|
||||
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
Reference in New Issue
Block a user