diff --git a/letta/server/server.py b/letta/server/server.py index cda9f8e0..6f9af1de 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1470,17 +1470,23 @@ class SyncServer(Server): Remove the document from the context window of all agents attached to the given source. """ - # TODO: We probably do NOT need to get the entire agent state, we can just get the IDs - agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor) + # Use the optimized ids_only parameter + agent_ids = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor, ids_only=True) - # Return early - if not agent_states: + # Return early if no agents + if not agent_ids: return logger.info(f"Removing file from context window for source: {source_id}") - logger.info(f"Attached agents: {[a.id for a in agent_states]}") + logger.info(f"Attached agents: {agent_ids}") - await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for agent_state in agent_states)) + # Create agent-file pairs for bulk deletion + agent_file_pairs = [(agent_id, file_id) for agent_id in agent_ids] + + # Bulk delete in a single query + deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor) + + logger.info(f"Removed file {file_id} from {deleted_count} agent context windows") async def remove_files_from_context_window(self, agent_state: AgentState, file_ids: List[str], actor: User) -> None: """ @@ -1490,7 +1496,13 @@ class SyncServer(Server): logger.info(f"Removing files from context window for agent_state: {agent_state.id}") logger.info(f"Files to remove: {file_ids}") - await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for file_id in file_ids)) + # Create agent-file pairs for bulk deletion + agent_file_pairs = [(agent_state.id, file_id) for file_id in file_ids] + + # Bulk delete in a single query + deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor) + + logger.info(f"Removed {deleted_count} files from agent {agent_state.id}") async def create_document_sleeptime_agent_async( self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False diff --git a/letta/services/files_agents_manager.py b/letta/services/files_agents_manager.py index 242ad8f9..e04ccb59 100644 --- a/letta/services/files_agents_manager.py +++ b/letta/services/files_agents_manager.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from typing import List, Optional, Union -from sqlalchemy import and_, func, select, update +from sqlalchemy import and_, delete, func, or_, select, update from letta.log import get_logger from letta.orm.errors import NoResultFound @@ -162,6 +162,36 @@ class FileAgentManager: assoc = await self._get_association_by_file_id(session, agent_id, file_id, actor) await assoc.hard_delete_async(session, actor=actor) + @enforce_types + @trace_method + async def detach_file_bulk(self, *, agent_file_pairs: List, actor: PydanticUser) -> int: # List of (agent_id, file_id) tuples + """ + Bulk delete multiple agent-file associations in a single query. + + Args: + agent_file_pairs: List of (agent_id, file_id) tuples to delete + actor: User performing the action + + Returns: + Number of rows deleted + """ + if not agent_file_pairs: + return 0 + + async with db_registry.async_session() as session: + # Build compound OR conditions for each agent-file pair + conditions = [] + for agent_id, file_id in agent_file_pairs: + conditions.append(and_(FileAgentModel.agent_id == agent_id, FileAgentModel.file_id == file_id)) + + # Create delete statement with all conditions + stmt = delete(FileAgentModel).where(and_(or_(*conditions), FileAgentModel.organization_id == actor.organization_id)) + + result = await session.execute(stmt) + await session.commit() + + return result.rowcount + @enforce_types @trace_method async def get_file_agent_by_id(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> Optional[PydanticFileAgent]: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 96786919..dbab4f29 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Optional +from typing import List, Optional, Union from sqlalchemy import and_, exists, select @@ -234,37 +234,56 @@ class SourceManager: @enforce_types @trace_method - async def list_attached_agents(self, source_id: str, actor: PydanticUser) -> List[PydanticAgentState]: + async def list_attached_agents( + self, source_id: str, actor: PydanticUser, ids_only: bool = False + ) -> Union[List[PydanticAgentState], List[str]]: """ Lists all agents that have the specified source attached. Args: source_id: ID of the source to find attached agents for actor: User performing the action + ids_only: If True, return only agent IDs instead of full agent states Returns: - List[PydanticAgentState]: List of agents that have this source attached + List[PydanticAgentState] | List[str]: List of agents or agent IDs that have this source attached """ async with db_registry.async_session() as session: # Verify source exists and user has permission to access it await self._validate_source_exists_async(session, source_id, actor) - # Use junction table query instead of relationship to avoid performance issues - query = ( - select(AgentModel) - .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id) - .where( - SourcesAgents.source_id == source_id, - AgentModel.organization_id == actor.organization_id, - AgentModel.is_deleted == False, + if ids_only: + # Query only agent IDs for performance + query = ( + select(AgentModel.id) + .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id) + .where( + SourcesAgents.source_id == source_id, + AgentModel.organization_id == actor.organization_id, + AgentModel.is_deleted == False, + ) + .order_by(AgentModel.created_at.desc(), AgentModel.id) ) - .order_by(AgentModel.created_at.desc(), AgentModel.id) - ) - result = await session.execute(query) - agents_orm = result.scalars().all() + result = await session.execute(query) + return list(result.scalars().all()) + else: + # Use junction table query instead of relationship to avoid performance issues + query = ( + select(AgentModel) + .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id) + .where( + SourcesAgents.source_id == source_id, + AgentModel.organization_id == actor.organization_id, + AgentModel.is_deleted == False, + ) + .order_by(AgentModel.created_at.desc(), AgentModel.id) + ) - return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm]) + result = await session.execute(query) + agents_orm = result.scalars().all() + + return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm]) @enforce_types @trace_method diff --git a/tests/test_managers.py b/tests/test_managers.py index 390e2ce9..acc9d6d2 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -8268,6 +8268,95 @@ async def test_detach_file(server, file_attachment, default_user): assert res is None +@pytest.mark.asyncio +async def test_detach_file_bulk( + server, + default_user, + sarah_agent, + charles_agent, + default_source, +): + """Test bulk deletion of multiple agent-file associations.""" + # Create multiple files + files = [] + for i in range(3): + file_metadata = PydanticFileMetadata( + file_name=f"test_file_{i}.txt", + source_id=default_source.id, + organization_id=default_user.organization_id, + ) + file = await server.file_manager.create_file(file_metadata, actor=default_user) + files.append(file) + + # Attach all files to both agents + for file in files: + await server.file_agent_manager.attach_file( + agent_id=sarah_agent.id, + file_id=file.id, + file_name=file.file_name, + source_id=file.source_id, + actor=default_user, + max_files_open=sarah_agent.max_files_open, + ) + await server.file_agent_manager.attach_file( + agent_id=charles_agent.id, + file_id=file.id, + file_name=file.file_name, + source_id=file.source_id, + actor=default_user, + max_files_open=charles_agent.max_files_open, + ) + + # Verify all files are attached to both agents + sarah_files = await server.file_agent_manager.list_files_for_agent( + sarah_agent.id, per_file_view_window_char_limit=sarah_agent.per_file_view_window_char_limit, actor=default_user + ) + charles_files = await server.file_agent_manager.list_files_for_agent( + charles_agent.id, per_file_view_window_char_limit=charles_agent.per_file_view_window_char_limit, actor=default_user + ) + assert len(sarah_files) == 3 + assert len(charles_files) == 3 + + # Test 1: Bulk delete specific files from specific agents + agent_file_pairs = [ + (sarah_agent.id, files[0].id), # Remove file 0 from sarah + (sarah_agent.id, files[1].id), # Remove file 1 from sarah + (charles_agent.id, files[1].id), # Remove file 1 from charles + ] + + deleted_count = await server.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=default_user) + assert deleted_count == 3 + + # Verify the correct files were deleted + sarah_files = await server.file_agent_manager.list_files_for_agent( + sarah_agent.id, per_file_view_window_char_limit=sarah_agent.per_file_view_window_char_limit, actor=default_user + ) + charles_files = await server.file_agent_manager.list_files_for_agent( + charles_agent.id, per_file_view_window_char_limit=charles_agent.per_file_view_window_char_limit, actor=default_user + ) + + # Sarah should only have file 2 left + assert len(sarah_files) == 1 + assert sarah_files[0].file_id == files[2].id + + # Charles should have files 0 and 2 left + assert len(charles_files) == 2 + charles_file_ids = {f.file_id for f in charles_files} + assert charles_file_ids == {files[0].id, files[2].id} + + # Test 2: Empty list should return 0 and not fail + deleted_count = await server.file_agent_manager.detach_file_bulk(agent_file_pairs=[], actor=default_user) + assert deleted_count == 0 + + # Test 3: Attempting to delete already deleted associations should return 0 + agent_file_pairs = [ + (sarah_agent.id, files[0].id), # Already deleted + (sarah_agent.id, files[1].id), # Already deleted + ] + deleted_count = await server.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=default_user) + assert deleted_count == 0 + + @pytest.mark.asyncio async def test_org_scoping( server,