feat: Bulk remove from context window on source detach (#3495)
This commit is contained in:
@@ -1470,17 +1470,23 @@ class SyncServer(Server):
|
||||
Remove the document from the context window of all agents
|
||||
attached to the given source.
|
||||
"""
|
||||
# TODO: We probably do NOT need to get the entire agent state, we can just get the IDs
|
||||
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
|
||||
# Use the optimized ids_only parameter
|
||||
agent_ids = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor, ids_only=True)
|
||||
|
||||
# Return early
|
||||
if not agent_states:
|
||||
# Return early if no agents
|
||||
if not agent_ids:
|
||||
return
|
||||
|
||||
logger.info(f"Removing file from context window for source: {source_id}")
|
||||
logger.info(f"Attached agents: {[a.id for a in agent_states]}")
|
||||
logger.info(f"Attached agents: {agent_ids}")
|
||||
|
||||
await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for agent_state in agent_states))
|
||||
# Create agent-file pairs for bulk deletion
|
||||
agent_file_pairs = [(agent_id, file_id) for agent_id in agent_ids]
|
||||
|
||||
# Bulk delete in a single query
|
||||
deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor)
|
||||
|
||||
logger.info(f"Removed file {file_id} from {deleted_count} agent context windows")
|
||||
|
||||
async def remove_files_from_context_window(self, agent_state: AgentState, file_ids: List[str], actor: User) -> None:
|
||||
"""
|
||||
@@ -1490,7 +1496,13 @@ class SyncServer(Server):
|
||||
logger.info(f"Removing files from context window for agent_state: {agent_state.id}")
|
||||
logger.info(f"Files to remove: {file_ids}")
|
||||
|
||||
await asyncio.gather(*(self._remove_file_from_agent(agent_state.id, file_id, actor) for file_id in file_ids))
|
||||
# Create agent-file pairs for bulk deletion
|
||||
agent_file_pairs = [(agent_state.id, file_id) for file_id in file_ids]
|
||||
|
||||
# Bulk delete in a single query
|
||||
deleted_count = await self.file_agent_manager.detach_file_bulk(agent_file_pairs=agent_file_pairs, actor=actor)
|
||||
|
||||
logger.info(f"Removed {deleted_count} files from agent {agent_state.id}")
|
||||
|
||||
async def create_document_sleeptime_agent_async(
|
||||
self, main_agent: AgentState, source: Source, actor: User, clear_history: bool = False
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import and_, func, select, update
|
||||
from sqlalchemy import and_, delete, func, or_, select, update
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
@@ -162,6 +162,36 @@ class FileAgentManager:
|
||||
assoc = await self._get_association_by_file_id(session, agent_id, file_id, actor)
|
||||
await assoc.hard_delete_async(session, actor=actor)
|
||||
|
||||
@enforce_types
@trace_method
async def detach_file_bulk(self, *, agent_file_pairs: List, actor: PydanticUser) -> int:
    """
    Remove many agent-file associations with a single DELETE statement.

    Args:
        agent_file_pairs: List of (agent_id, file_id) tuples naming the associations to remove
        actor: User performing the action; only rows whose organization_id equals
            actor.organization_id are matched

    Returns:
        The rowcount reported for the DELETE; 0 when agent_file_pairs is empty
        (no session is opened in that case)
    """
    if not agent_file_pairs:
        return 0

    # One "agent_id AND file_id" clause per requested pair
    pair_clauses = [
        and_(FileAgentModel.agent_id == agent_id, FileAgentModel.file_id == file_id)
        for agent_id, file_id in agent_file_pairs
    ]

    # Match any requested pair, but never outside the actor's organization
    same_org = FileAgentModel.organization_id == actor.organization_id
    delete_stmt = delete(FileAgentModel).where(and_(or_(*pair_clauses), same_org))

    async with db_registry.async_session() as session:
        outcome = await session.execute(delete_stmt)
        await session.commit()

        return outcome.rowcount
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_file_agent_by_id(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> Optional[PydanticFileAgent]:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sqlalchemy import and_, exists, select
|
||||
|
||||
@@ -234,37 +234,56 @@ class SourceManager:
|
||||
|
||||
@enforce_types
@trace_method
async def list_attached_agents(
    self, source_id: str, actor: PydanticUser, ids_only: bool = False
) -> Union[List[PydanticAgentState], List[str]]:
    """
    Lists all agents that have the specified source attached.

    Agents are looked up through the SourcesAgents junction table, limited to
    non-deleted agents in the actor's organization, and ordered by creation
    time (newest first) with the agent ID as a tie-breaker.

    Args:
        source_id: ID of the source to find attached agents for
        actor: User performing the action
        ids_only: If True, return only agent IDs instead of full agent states

    Returns:
        List[PydanticAgentState] | List[str]: List of agents or agent IDs that have this source attached
    """
    async with db_registry.async_session() as session:
        # Verify source exists and user has permission to access it
        await self._validate_source_exists_async(session, source_id, actor)

        # Selecting only the ID column avoids loading full agent rows when
        # the caller just needs identifiers; everything else in the query is shared.
        selected = AgentModel.id if ids_only else AgentModel

        # Use junction table query instead of relationship to avoid performance issues
        query = (
            select(selected)
            .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
            .where(
                SourcesAgents.source_id == source_id,
                AgentModel.organization_id == actor.organization_id,
                AgentModel.is_deleted == False,
            )
            .order_by(AgentModel.created_at.desc(), AgentModel.id)
        )

        rows = (await session.execute(query)).scalars().all()

        if ids_only:
            return list(rows)

        return await asyncio.gather(*[agent.to_pydantic_async() for agent in rows])
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -8268,6 +8268,95 @@ async def test_detach_file(server, file_attachment, default_user):
|
||||
assert res is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_detach_file_bulk(
    server,
    default_user,
    sarah_agent,
    charles_agent,
    default_source,
):
    """Bulk detach removes exactly the requested agent-file pairs and reports how many rows it deleted."""
    manager = server.file_agent_manager

    async def list_for(agent):
        # Current file associations for one agent, as returned by the manager
        return await manager.list_files_for_agent(
            agent.id, per_file_view_window_char_limit=agent.per_file_view_window_char_limit, actor=default_user
        )

    async def attach(agent, file):
        await manager.attach_file(
            agent_id=agent.id,
            file_id=file.id,
            file_name=file.file_name,
            source_id=file.source_id,
            actor=default_user,
            max_files_open=agent.max_files_open,
        )

    # Three files under the default source
    files = []
    for i in range(3):
        metadata = PydanticFileMetadata(
            file_name=f"test_file_{i}.txt",
            source_id=default_source.id,
            organization_id=default_user.organization_id,
        )
        files.append(await server.file_manager.create_file(metadata, actor=default_user))

    # Every file goes to sarah first, then charles
    for file in files:
        for agent in (sarah_agent, charles_agent):
            await attach(agent, file)

    # Sanity check: both agents start with all three files
    assert len(await list_for(sarah_agent)) == 3
    assert len(await list_for(charles_agent)) == 3

    # Case 1: remove a specific subset of (agent, file) pairs
    to_remove = [
        (sarah_agent.id, files[0].id),  # file 0 from sarah
        (sarah_agent.id, files[1].id),  # file 1 from sarah
        (charles_agent.id, files[1].id),  # file 1 from charles
    ]
    removed = await manager.detach_file_bulk(agent_file_pairs=to_remove, actor=default_user)
    assert removed == 3

    sarah_remaining = await list_for(sarah_agent)
    charles_remaining = await list_for(charles_agent)

    # Only file 2 is left on sarah
    assert len(sarah_remaining) == 1
    assert sarah_remaining[0].file_id == files[2].id

    # Files 0 and 2 are left on charles
    assert len(charles_remaining) == 2
    assert {entry.file_id for entry in charles_remaining} == {files[0].id, files[2].id}

    # Case 2: an empty pair list is a no-op that returns 0
    removed = await manager.detach_file_bulk(agent_file_pairs=[], actor=default_user)
    assert removed == 0

    # Case 3: pairs that were already detached delete nothing
    already_gone = [
        (sarah_agent.id, files[0].id),
        (sarah_agent.id, files[1].id),
    ]
    removed = await manager.detach_file_bulk(agent_file_pairs=already_gone, actor=default_user)
    assert removed == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_org_scoping(
|
||||
server,
|
||||
|
||||
Reference in New Issue
Block a user