fix: Fix constraints and also implement bulk attach (#3107)

This commit is contained in:
Matthew Zhou
2025-06-30 14:27:57 -07:00
committed by GitHub
parent 3aee153040
commit 5dccccec21
6 changed files with 495 additions and 45 deletions

View File

@@ -4,15 +4,19 @@ from typing import List, Optional
from sqlalchemy import and_, func, select, update
from letta.constants import MAX_FILES_OPEN
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.files_agents import FileAgent as FileAgentModel
from letta.otel.tracing import trace_method
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.file import FileAgent as PydanticFileAgent
from letta.schemas.file import FileMetadata
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
logger = get_logger(__name__)
class FileAgentManager:
"""High-level helpers for CRUD / listing on the `files_agents` join table."""
@@ -423,6 +427,117 @@ class FileAgentManager:
return closed_file_names, file_was_already_open
@enforce_types
@trace_method
async def attach_files_bulk(
    self,
    *,
    agent_id: str,
    files_metadata: list[FileMetadata],
    visible_content_map: Optional[dict[str, str]] = None,
    actor: PydanticUser,
) -> list[str]:
    """Atomically attach many files to an agent, enforcing the MAX_FILES_OPEN
    LRU cap with a single commit.

    Newly attached files take priority for open slots; previously open files
    are evicted oldest-first (ordered by ``last_accessed_at``) when the cap
    is exceeded.

    Args:
        agent_id: ID of the agent the files are attached to.
        files_metadata: Files to attach, in caller-priority order.
        visible_content_map: Optional mapping of file name -> visible content
            for files that end up open; files left closed get ``None`` content.
        actor: User performing the operation; scopes all queries to the
            actor's organization.

    Returns:
        Names of files that ended up closed by this operation — both
        previously-open files evicted by the LRU cap and requested files that
        could not be opened because the cap was already filled by the request.
    """
    if not files_metadata:
        return []

    # TODO: This is not strictly necessary, as the file_metadata should never be duped
    # TODO: But we have this as a protection, check logs for details
    # dedupe while preserving caller order
    seen: set[str] = set()
    ordered_unique: list[FileMetadata] = []
    for m in files_metadata:
        if m.file_name not in seen:
            ordered_unique.append(m)
            seen.add(m.file_name)

    if (dup_cnt := len(files_metadata) - len(ordered_unique)) > 0:
        logger.warning(
            "attach_files_bulk: removed %d duplicate file(s) for agent %s",
            dup_cnt,
            agent_id,
        )

    now = datetime.now(timezone.utc)
    vc_for = visible_content_map or {}

    async with db_registry.async_session() as session:
        # fetch existing assoc rows for requested names
        existing_q = select(FileAgentModel).where(
            FileAgentModel.agent_id == agent_id,
            FileAgentModel.organization_id == actor.organization_id,
            FileAgentModel.file_name.in_(seen),
        )
        existing_rows = (await session.execute(existing_q)).scalars().all()
        existing_by_name = {r.file_name: r for r in existing_rows}

        # snapshot current OPEN rows (oldest first)
        open_q = (
            select(FileAgentModel)
            .where(
                FileAgentModel.agent_id == agent_id,
                FileAgentModel.organization_id == actor.organization_id,
                FileAgentModel.is_open.is_(True),
            )
            .order_by(FileAgentModel.last_accessed_at.asc())
        )
        currently_open = (await session.execute(open_q)).scalars().all()

        new_names = [m.file_name for m in ordered_unique]
        new_names_set = set(new_names)
        # open files not part of this request, still oldest-first (LRU order)
        still_open_names = [r.file_name for r in currently_open if r.file_name not in new_names_set]

        # decide final open set: new files win; any remaining slots go to the
        # most-recently-accessed of the already-open files (tail of LRU list)
        if len(new_names) >= MAX_FILES_OPEN:
            final_open = new_names[:MAX_FILES_OPEN]
        else:
            room_for_old = MAX_FILES_OPEN - len(new_names)
            final_open = new_names + still_open_names[-room_for_old:]
        final_open_set = set(final_open)

        closed_file_names = [r.file_name for r in currently_open if r.file_name not in final_open_set]
        # Add new files that won't be opened due to MAX_FILES_OPEN limit
        if len(new_names) >= MAX_FILES_OPEN:
            closed_file_names.extend(new_names[MAX_FILES_OPEN:])

        # set membership keeps the eviction scan O(n) instead of O(n^2)
        # (previously tested `in` against the closed-names *list*)
        closed_names_set = set(closed_file_names)
        evicted_ids = [r.file_id for r in currently_open if r.file_name in closed_names_set]

        # upsert requested files
        for meta in ordered_unique:
            is_now_open = meta.file_name in final_open_set
            # open files get visible content (default empty); closed get None
            vc = vc_for.get(meta.file_name, "") if is_now_open else None

            if row := existing_by_name.get(meta.file_name):
                row.is_open = is_now_open
                row.visible_content = vc
                row.last_accessed_at = now
                session.add(row)  # already present, but safe
            else:
                session.add(
                    FileAgentModel(
                        agent_id=agent_id,
                        file_id=meta.id,
                        file_name=meta.file_name,
                        organization_id=actor.organization_id,
                        is_open=is_now_open,
                        visible_content=vc,
                        last_accessed_at=now,
                    )
                )

        # bulk-close evicted rows
        if evicted_ids:
            await session.execute(
                update(FileAgentModel)
                .where(
                    FileAgentModel.agent_id == agent_id,
                    FileAgentModel.organization_id == actor.organization_id,
                    FileAgentModel.file_id.in_(evicted_ids),
                )
                .values(is_open=False, visible_content=None)
            )

        await session.commit()

    return closed_file_names
async def _get_association_by_file_id(self, session, agent_id: str, file_id: str, actor: PydanticUser) -> FileAgentModel:
q = select(FileAgentModel).where(
and_(