feat: Bulk remove from context window on source detach (#3495)

This commit is contained in:
Matthew Zhou
2025-07-22 17:22:51 -07:00
committed by GitHub
parent 58081e3cea
commit fcb03be382
4 changed files with 174 additions and 24 deletions

View File

@@ -1470,17 +1470,23 @@ class SyncServer(Server):
Remove the document from the context window of all agents
attached to the given source.
"""
# TODO: We probably do NOT need to get the entire agent state, we can just get the IDs
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
# Use the optimized ids_only parameter
agent_ids = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor, ids_only=True)
# Return early
if not agent_states:
# Return early if no agents
if not agent_ids:
return
logger.info(f"Removing file from context window for source: {source_id}")
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
logger.info(f"Attached agents: {agent_ids}")
await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for agent_state in agent_states))
# Create agent-file pairs for bulk deletion
agent_file_pairs = [(agent_id, file_id) for agent_id in agent_ids]
# Bulk delete in a single query
deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor)
logger.info(f"Removed file {file_id} from {deleted_count} agent context windows")
async def remove_files_from_context_window(self, agent_state: AgentState, file_ids: List[str], actor: User) -> None:
"""
@@ -1490,7 +1496,13 @@ class SyncServer(Server):
logger.info(f"Removing files from context window for agent_state: {agent_state.id}")
logger.info(f"Files to remove: {file_ids}")
await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for file_id in file_ids))
# Create agent-file pairs for bulk deletion
agent_file_pairs = [(agent_state.id, file_id) for file_id in file_ids]
# Bulk delete in a single query
deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor)
logger.info(f"Removed {deleted_count} files from agent {agent_state.id}")
async def create_document_sleeptime_agent_async(
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False

View File

@@ -1,7 +1,7 @@
from datetime import datetime, timezone
from typing import List, Optional, Union
from sqlalchemy import and_, func, select, update
from sqlalchemy import and_, delete, func, or_, select, update
from letta.log import get_logger
from letta.orm.errors import NoResultFound
@@ -162,6 +162,36 @@ class FileAgentManager:
assoc = await self._get_association_by_file_id(session, agent_id, file_id, actor)
await assoc.hard_delete_async(session, actor=actor)
@enforce_types
@trace_method
async def detach_file_bulk(self, *, agent_file_pairs: List, actor: PydanticUser) -> int: # List of (agent_id, file_id) tuples
"""
Bulk delete multiple agent-file associations in a single query.
Args:
agent_file_pairs: List of (agent_id, file_id) tuples to delete
actor: User performing the action
Returns:
Number of rows deleted
"""
if not agent_file_pairs:
return 0
async with db_registry.async_session() as session:
# Build compound OR conditions for each agent-file pair
conditions = []
for agent_id, file_id in agent_file_pairs:
conditions.append(and_(FileAgentModel.agent_id == agent_id, FileAgentModel.file_id == file_id))
# Create delete statement with all conditions
stmt = delete(FileAgentModel).where(and_(or_(*conditions), FileAgentModel.organization_id == actor.organization_id))
result = await session.execute(stmt)
await session.commit()
return result.rowcount
@enforce_types
@trace_method
async def get_file_agent_by_id(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> Optional[PydanticFileAgent]:

View File

@@ -1,5 +1,5 @@
import asyncio
from typing import List, Optional
from typing import List, Optional, Union
from sqlalchemy import and_, exists, select
@@ -234,37 +234,56 @@ class SourceManager:
@enforce_types
@trace_method
async def list_attached_agents(self, source_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
async def list_attached_agents(
self, source_id: str, actor: PydanticUser, ids_only: bool = False
) -> Union[List[PydanticAgentState], List[str]]:
"""
Lists all agents that have the specified source attached.
Args:
source_id: ID of the source to find attached agents for
actor: User performing the action
ids_only: If True, return only agent IDs instead of full agent states
Returns:
List[PydanticAgentState]: List of agents that have this source attached
List[PydanticAgentState] | List[str]: List of agents or agent IDs that have this source attached
"""
async with db_registry.async_session() as session:
# Verify source exists and user has permission to access it
await self._validate_source_exists_async(session, source_id, actor)
# Use junction table query instead of relationship to avoid performance issues
query = (
select(AgentModel)
.join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
.where(
SourcesAgents.source_id == source_id,
AgentModel.organization_id == actor.organization_id,
AgentModel.is_deleted == False,
if ids_only:
# Query only agent IDs for performance
query = (
select(AgentModel.id)
.join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
.where(
SourcesAgents.source_id == source_id,
AgentModel.organization_id == actor.organization_id,
AgentModel.is_deleted == False,
)
.order_by(AgentModel.created_at.desc(), AgentModel.id)
)
.order_by(AgentModel.created_at.desc(), AgentModel.id)
)
result = await session.execute(query)
agents_orm = result.scalars().all()
result = await session.execute(query)
return list(result.scalars().all())
else:
# Use junction table query instead of relationship to avoid performance issues
query = (
select(AgentModel)
.join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
.where(
SourcesAgents.source_id == source_id,
AgentModel.organization_id == actor.organization_id,
AgentModel.is_deleted == False,
)
.order_by(AgentModel.created_at.desc(), AgentModel.id)
)
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
result = await session.execute(query)
agents_orm = result.scalars().all()
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
@enforce_types
@trace_method

View File

@@ -8268,6 +8268,95 @@ async def test_detach_file(server, file_attachment, default_user):
assert res is None
@pytest.mark.asyncio
async def test_detach_file_bulk(
server,
default_user,
sarah_agent,
charles_agent,
default_source,
):
"""Test bulk deletion of multiple agent-file associations."""
# Create multiple files
files = []
for i in range(3):
file_metadata = PydanticFileMetadata(
file_name=f"test_file_{i}.txt",
source_id=default_source.id,
organization_id=default_user.organization_id,
)
file = await server.file_manager.create_file(file_metadata, actor=default_user)
files.append(file)
# Attach all files to both agents
for file in files:
await server.file_agent_manager.attach_file(
agent_id=sarah_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
max_files_open=sarah_agent.max_files_open,
)
await server.file_agent_manager.attach_file(
agent_id=charles_agent.id,
file_id=file.id,
file_name=file.file_name,
source_id=file.source_id,
actor=default_user,
max_files_open=charles_agent.max_files_open,
)
# Verify all files are attached to both agents
sarah_files = await server.file_agent_manager.list_files_for_agent(
sarah_agent.id, per_file_view_window_char_limit=sarah_agent.per_file_view_window_char_limit, actor=default_user
)
charles_files = await server.file_agent_manager.list_files_for_agent(
charles_agent.id, per_file_view_window_char_limit=charles_agent.per_file_view_window_char_limit, actor=default_user
)
assert len(sarah_files) == 3
assert len(charles_files) == 3
# Test 1: Bulk delete specific files from specific agents
agent_file_pairs = [
(sarah_agent.id, files[0].id), # Remove file 0 from sarah
(sarah_agent.id, files[1].id), # Remove file 1 from sarah
(charles_agent.id, files[1].id), # Remove file 1 from charles
]
deleted_count = await server.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=default_user)
assert deleted_count == 3
# Verify the correct files were deleted
sarah_files = await server.file_agent_manager.list_files_for_agent(
sarah_agent.id, per_file_view_window_char_limit=sarah_agent.per_file_view_window_char_limit, actor=default_user
)
charles_files = await server.file_agent_manager.list_files_for_agent(
charles_agent.id, per_file_view_window_char_limit=charles_agent.per_file_view_window_char_limit, actor=default_user
)
# Sarah should only have file 2 left
assert len(sarah_files) == 1
assert sarah_files[0].file_id == files[2].id
# Charles should have files 0 and 2 left
assert len(charles_files) == 2
charles_file_ids = {f.file_id for f in charles_files}
assert charles_file_ids == {files[0].id, files[2].id}
# Test 2: Empty list should return 0 and not fail
deleted_count = await server.file_agent_manager.detach_file_bulk(agent_file_pairs=[], actor=default_user)
assert deleted_count == 0
# Test 3: Attempting to delete already deleted associations should return 0
agent_file_pairs = [
(sarah_agent.id, files[0].id), # Already deleted
(sarah_agent.id, files[1].id), # Already deleted
]
deleted_count = await server.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=default_user)
assert deleted_count == 0
@pytest.mark.asyncio
async def test_org_scoping(
server,