feat: Rename file object and create performant files <-> agents association (#2588)
This commit is contained in:
54
alembic/versions/0b496eae90de_add_file_agent_table.py
Normal file
54
alembic/versions/0b496eae90de_add_file_agent_table.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Add file agent table
|
||||
|
||||
Revision ID: 0b496eae90de
|
||||
Revises: 341068089f14
|
||||
Create Date: 2025-06-02 15:14:33.730687
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# Revision identifiers, used by Alembic to order migrations.
# `revision` is this migration's own ID; `down_revision` is the migration it
# revises (matches the "Revises:" line in the module docstring).
revision: str = "0b496eae90de"
down_revision: Union[str, None] = "341068089f14"
# No branch label and no cross-branch dependency are declared for this revision.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``files_agents`` table and its (file_id, agent_id) index."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Column definitions, in the order they are emitted into the table.
    columns = [
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("file_id", sa.String(), nullable=False),
        sa.Column("agent_id", sa.String(), nullable=False),
        sa.Column("is_open", sa.Boolean(), nullable=False),
        sa.Column("visible_content", sa.Text(), nullable=True),
        sa.Column("last_accessed_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
        sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
        sa.Column("_created_by_id", sa.String(), nullable=True),
        sa.Column("_last_updated_by_id", sa.String(), nullable=True),
        sa.Column("organization_id", sa.String(), nullable=False),
    ]

    # Deleting an agent or a file cascades to its association rows; the
    # organization foreign key declares no ON DELETE action.
    # NOTE(review): the primary key spans (id, file_id, agent_id) and the index
    # below is non-unique, so nothing here enforces a single row per
    # (file_id, agent_id) pair - confirm that is intended.
    constraints = [
        sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["file_id"], ["files.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"]),
        sa.PrimaryKeyConstraint("id", "file_id", "agent_id"),
    ]

    op.create_table("files_agents", *columns, *constraints)
    op.create_index(
        "ix_files_agents_file_id_agent_id",
        "files_agents",
        ["file_id", "agent_id"],
        unique=False,
    )
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``files_agents`` index and then the table itself."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = "files_agents"
    # The index is removed first, then the table it belongs to.
    op.drop_index("ix_files_agents_file_id_agent_id", table_name=table_name)
    op.drop_table(table_name)
    # ### end Alembic commands ###
|
||||
@@ -5,6 +5,7 @@ from letta.orm.block import Block
|
||||
from letta.orm.block_history import BlockHistory
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.files_agents import FileAgent
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.groups_agents import GroupsAgents
|
||||
from letta.orm.groups_blocks import GroupsBlocks
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
||||
"""Represents metadata for an uploaded file."""
|
||||
"""Represents an uploaded file."""
|
||||
|
||||
__tablename__ = "files"
|
||||
__pydantic_model__ = PydanticFileMetadata
|
||||
|
||||
42
letta/orm/files_agents.py
Normal file
42
letta/orm/files_agents.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.file import FileAgent as PydanticFileAgent
|
||||
|
||||
if TYPE_CHECKING:
    # Nothing is imported for type checking at present; the guard is a no-op.
    pass
|
||||
|
||||
|
||||
class FileAgent(SqlalchemyBase, OrganizationMixin):
    """
    Association row linking one file to one agent.

    Besides the link itself, each row stores whether the agent currently has
    the file open, the portion of the file the agent is focused on, and a
    timestamp of the agent's last access.
    """

    __tablename__ = "files_agents"
    __pydantic_model__ = PydanticFileAgent
    __table_args__ = (Index("ix_files_agents_file_id_agent_id", "file_id", "agent_id"),)

    # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
    # TODO: Some still rely on the Pydantic object to do this
    # `id`, `file_id` and `agent_id` are all flagged primary_key, so the table's key is composite.
    id: Mapped[str] = mapped_column(
        String,
        primary_key=True,
        default=lambda: f"file_agent-{uuid.uuid4()}",
    )
    file_id: Mapped[str] = mapped_column(
        String,
        ForeignKey("files.id", ondelete="CASCADE"),
        primary_key=True,
        doc="ID of the file.",
    )
    agent_id: Mapped[str] = mapped_column(
        String,
        ForeignKey("agents.id", ondelete="CASCADE"),
        primary_key=True,
        doc="ID of the agent.",
    )

    is_open: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=True,
        doc="True if the agent currently has the file open.",
    )
    visible_content: Mapped[Optional[str]] = mapped_column(
        Text,
        nullable=True,
        doc="Portion of the file the agent is focused on.",
    )
    # Filled in by the database on insert and refreshed via `onupdate` on every UPDATE of the row.
    last_accessed_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
        doc="UTC timestamp when this agent last accessed the file.",
    )
|
||||
@@ -29,3 +29,50 @@ class FileMetadata(FileMetadataBase):
|
||||
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.")
|
||||
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the file.")
|
||||
is_deleted: bool = Field(False, description="Whether this file is deleted or not.")
|
||||
|
||||
|
||||
class FileAgentBase(LettaBase):
    """Base class for the FileMetadata-⇄-Agent association schemas"""

    # ID prefix declared for this schema family; the ORM model's default id
    # factory produces IDs starting with the same "file_agent" prefix.
    __id_prefix__ = "file_agent"
|
||||
|
||||
|
||||
class FileAgent(FileAgentBase):
    """
    A single FileMetadata ⇄ Agent association row.

    Captures:
    • whether the agent currently has the file “open”
    • the excerpt (grepped section) in the context window
    • the last time the agent accessed the file
    """

    # --- identity / ownership -------------------------------------------------
    id: str = Field(..., description="The internal ID")
    organization_id: Optional[str] = Field(
        default=None,
        description="Org ID this association belongs to (inherited from both agent and file).",
    )
    agent_id: str = Field(..., description="Unique identifier of the agent.")
    file_id: str = Field(..., description="Unique identifier of the file.")

    # --- per-agent view of the file --------------------------------------------
    is_open: bool = Field(default=True, description="True if the agent currently has the file open.")
    visible_content: Optional[str] = Field(
        default=None,
        description="Portion of the file the agent is focused on (may be large).",
    )
    last_accessed_at: Optional[datetime] = Field(
        default_factory=datetime.utcnow,
        description="UTC timestamp of the agent’s most recent access to this file.",
    )

    # --- bookkeeping -------------------------------------------------------------
    # The timestamp defaults come from datetime.utcnow, i.e. naive datetimes.
    created_at: Optional[datetime] = Field(
        default_factory=datetime.utcnow,
        description="Row creation timestamp (UTC).",
    )
    updated_at: Optional[datetime] = Field(
        default_factory=datetime.utcnow,
        description="Row last-update timestamp (UTC).",
    )
    is_deleted: bool = Field(default=False, description="Soft-delete flag.")
|
||||
|
||||
@@ -79,6 +79,7 @@ from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.utils import sse_async_generator
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.files_agents_manager import FileAgentManager
|
||||
from letta.services.group_manager import GroupManager
|
||||
from letta.services.helpers.tool_execution_helper import prepare_local_sandbox
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
@@ -217,6 +218,7 @@ class SyncServer(Server):
|
||||
self.group_manager = GroupManager()
|
||||
self.batch_manager = LLMBatchManager()
|
||||
self.telemetry_manager = TelemetryManager()
|
||||
self.file_agent_manager = FileAgentManager()
|
||||
|
||||
# A reusable httpx client
|
||||
timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0)
|
||||
|
||||
184
letta/services/files_agents_manager.py
Normal file
184
letta/services/files_agents_manager.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_, func, select, update
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.files_agents import FileAgent as FileAgentModel
|
||||
from letta.schemas.file import FileAgent as PydanticFileAgent
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class FileAgentManager:
    """Service helpers for creating, reading, updating and deleting `files_agents` rows.

    Every query is scoped to the acting user's organization.
    """

    @enforce_types
    @trace_method
    async def attach_file(
        self,
        *,
        agent_id: str,
        file_id: str,
        actor: PydanticUser,
        is_open: bool = True,
        visible_content: Optional[str] = None,
    ) -> PydanticFileAgent:
        """
        Attach *file_id* to *agent_id*, reusing an existing association if one is found.

        When a matching row exists, `is_open` is updated if it differs,
        `visible_content` is updated only when a non-None, different value is
        supplied, and `last_accessed_at` is always refreshed. When no row
        exists, a new association is created with the given values.

        Returns:
            The created or updated association as a Pydantic object.
        """
        async with db_registry.async_session() as session:
            lookup = select(FileAgentModel).where(
                and_(
                    FileAgentModel.agent_id == agent_id,
                    FileAgentModel.file_id == file_id,
                    FileAgentModel.organization_id == actor.organization_id,
                )
            )
            current = await session.scalar(lookup)
            accessed_at = datetime.now(timezone.utc)

            if not current:
                # No association yet: insert a fresh row.
                created = FileAgentModel(
                    agent_id=agent_id,
                    file_id=file_id,
                    organization_id=actor.organization_id,
                    is_open=is_open,
                    visible_content=visible_content,
                    last_accessed_at=accessed_at,
                )
                await created.create_async(session, actor=actor)
                return created.to_pydantic()

            # Existing association: write only the attributes whose value differs.
            if current.is_open != is_open:
                current.is_open = is_open

            content_changed = visible_content is not None and current.visible_content != visible_content
            if content_changed:
                current.visible_content = visible_content

            current.last_accessed_at = accessed_at

            await current.update_async(session, actor=actor)
            return current.to_pydantic()

    @enforce_types
    @trace_method
    async def update_file_agent(
        self,
        *,
        agent_id: str,
        file_id: str,
        actor: PydanticUser,
        is_open: Optional[bool] = None,
        visible_content: Optional[str] = None,
    ) -> PydanticFileAgent:
        """
        Partially update an existing association.

        Arguments left as None are not written; `last_accessed_at` is always
        set to the current UTC time.

        Raises:
            NoResultFound: if no association exists for this agent/file pair
                in the actor's organization.
        """
        async with db_registry.async_session() as session:
            row = await self._get_association(session, agent_id, file_id, actor)

            if is_open is not None:
                row.is_open = is_open
            if visible_content is not None:
                row.visible_content = visible_content

            # touch timestamp
            row.last_accessed_at = datetime.now(timezone.utc)

            await row.update_async(session, actor=actor)
            return row.to_pydantic()

    @enforce_types
    @trace_method
    async def detach_file(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> None:
        """
        Hard-delete the association between *agent_id* and *file_id*.

        Raises:
            NoResultFound: if the association does not exist in the actor's organization.
        """
        async with db_registry.async_session() as session:
            row = await self._get_association(session, agent_id, file_id, actor)
            await row.hard_delete_async(session, actor=actor)

    @enforce_types
    @trace_method
    async def get_file_agent(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> Optional[PydanticFileAgent]:
        """Return the association for the agent/file pair, or None when it is not found."""
        async with db_registry.async_session() as session:
            try:
                row = await self._get_association(session, agent_id, file_id, actor)
                return row.to_pydantic()
            except NoResultFound:
                return None

    @enforce_types
    @trace_method
    async def list_files_for_agent(
        self,
        agent_id: str,
        actor: PydanticUser,
        is_open_only: bool = False,
    ) -> List[PydanticFileAgent]:
        """List the associations of *agent_id*; with `is_open_only`, only rows whose `is_open` is true."""
        async with db_registry.async_session() as session:
            filters = [
                FileAgentModel.agent_id == agent_id,
                FileAgentModel.organization_id == actor.organization_id,
            ]
            if is_open_only:
                filters.append(FileAgentModel.is_open.is_(True))

            result = await session.execute(select(FileAgentModel).where(and_(*filters)))
            return [row.to_pydantic() for row in result.scalars().all()]

    @enforce_types
    @trace_method
    async def list_agents_for_file(
        self,
        file_id: str,
        actor: PydanticUser,
        is_open_only: bool = False,
    ) -> List[PydanticFileAgent]:
        """List the associations of *file_id*; with `is_open_only`, only rows whose `is_open` is true."""
        async with db_registry.async_session() as session:
            filters = [
                FileAgentModel.file_id == file_id,
                FileAgentModel.organization_id == actor.organization_id,
            ]
            if is_open_only:
                filters.append(FileAgentModel.is_open.is_(True))

            result = await session.execute(select(FileAgentModel).where(and_(*filters)))
            return [row.to_pydantic() for row in result.scalars().all()]

    @enforce_types
    @trace_method
    async def mark_access(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> None:
        """
        Set `last_accessed_at` to the database's `now()` with a single UPDATE.

        The row is not loaded first. The affected-row count is not checked, so
        a missing association results in no change and no error.
        """
        async with db_registry.async_session() as session:
            criteria = (
                FileAgentModel.agent_id == agent_id,
                FileAgentModel.file_id == file_id,
                FileAgentModel.organization_id == actor.organization_id,
            )
            touch = update(FileAgentModel).where(*criteria).values(last_accessed_at=func.now())
            await session.execute(touch)
            await session.commit()

    async def _get_association(self, session, agent_id: str, file_id: str, actor: PydanticUser) -> FileAgentModel:
        """Fetch the ORM row for the agent/file pair within the actor's organization.

        Raises:
            NoResultFound: if the query yields no row.
        """
        scoped = and_(
            FileAgentModel.agent_id == agent_id,
            FileAgentModel.file_id == file_id,
            FileAgentModel.organization_id == actor.organization_id,
        )
        found = await session.scalar(select(FileAgentModel).where(scoped))
        if found:
            return found
        raise NoResultFound(f"FileAgent(agent_id={agent_id}, file_id={file_id}) not found in org {actor.organization_id}")
|
||||
@@ -640,6 +640,27 @@ def event_loop(request):
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
async def file_attachment(server, default_user, sarah_agent, default_file):
    """Yield an association of default_file to sarah_agent with visible_content "initial"."""
    manager = server.file_agent_manager
    attachment = await manager.attach_file(
        agent_id=sarah_agent.id,
        file_id=default_file.id,
        actor=default_user,
        visible_content="initial",
    )
    yield attachment
|
||||
|
||||
|
||||
@pytest.fixture
async def another_file(server, default_source, default_user, default_organization):
    """Create and return a second file ("another_file") under the default source."""
    metadata = PydanticFileMetadata(
        file_name="another_file",
        organization_id=default_organization.id,
        source_id=default_source.id,
    )
    created = await server.source_manager.create_file(metadata, actor=default_user)
    return created
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -5754,3 +5775,141 @@ async def test_create_mcp_server(server, default_user, event_loop):
|
||||
assert tool.name == tool_name
|
||||
assert f"mcp:{created_server.server_name}" in tool.tags, f"Expected tag {f'mcp:{created_server.server_name}'}, got {tool.tags}"
|
||||
print("TAGS", tool.tags)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# FileAgent Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_creates_association(server, default_user, sarah_agent, default_file):
    """Attaching a file returns an open association carrying the supplied content."""
    manager = server.file_agent_manager
    created = await manager.attach_file(
        agent_id=sarah_agent.id,
        file_id=default_file.id,
        actor=default_user,
        visible_content="hello",
    )

    # The association points at the requested agent/file pair...
    assert created.agent_id == sarah_agent.id
    assert created.file_id == default_file.id
    # ...is open by default and keeps the content that was passed in.
    assert created.is_open is True
    assert created.visible_content == "hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_is_idempotent(server, default_user, sarah_agent, default_file):
    """Attaching the same pair twice reuses the row and applies the new values."""
    manager = server.file_agent_manager
    pair = {"agent_id": sarah_agent.id, "file_id": default_file.id, "actor": default_user}

    first = await manager.attach_file(**pair, visible_content="first")

    # second attach with different params
    second = await manager.attach_file(**pair, is_open=False, visible_content="second")

    assert first.id == second.id
    assert second.is_open is False
    assert second.visible_content == "second"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_file_agent(server, file_attachment, default_user):
    """update_file_agent applies the supplied `is_open` and `visible_content` values."""
    manager = server.file_agent_manager
    patched = await manager.update_file_agent(
        agent_id=file_attachment.agent_id,
        file_id=file_attachment.file_id,
        actor=default_user,
        is_open=False,
        visible_content="updated",
    )

    assert patched.is_open is False
    assert patched.visible_content == "updated"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mark_access(server, file_attachment, default_user):
    """mark_access moves `last_accessed_at` forward on an existing association."""
    old_ts = file_attachment.last_accessed_at

    # Pause so the new timestamp is strictly later than the old one. Use
    # asyncio.sleep in both cases: a blocking time.sleep inside a coroutine
    # stalls the whole event loop for the duration of the delay.
    await asyncio.sleep(CREATE_DELAY_SQLITE if USING_SQLITE else 0.01)

    await server.file_agent_manager.mark_access(
        agent_id=file_attachment.agent_id,
        file_id=file_attachment.file_id,
        actor=default_user,
    )
    refreshed = await server.file_agent_manager.get_file_agent(
        agent_id=file_attachment.agent_id,
        file_id=file_attachment.file_id,
        actor=default_user,
    )

    # get_file_agent returns None for a missing row; fail with a clear
    # assertion instead of an AttributeError on the comparison below.
    assert refreshed is not None
    assert refreshed.last_accessed_at > old_ts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_files_and_agents(
    server,
    default_user,
    sarah_agent,
    charles_agent,
    default_file,
    another_file,
):
    """Listing by agent and by file returns the expected associations, honoring `is_open_only`."""
    manager = server.file_agent_manager

    # default_file ↔ charles (open)
    await manager.attach_file(agent_id=charles_agent.id, file_id=default_file.id, actor=default_user)
    # default_file ↔ sarah (open)
    await manager.attach_file(agent_id=sarah_agent.id, file_id=default_file.id, actor=default_user)
    # another_file ↔ sarah (closed)
    await manager.attach_file(agent_id=sarah_agent.id, file_id=another_file.id, actor=default_user, is_open=False)

    # Sarah sees both of her files when no filter is applied.
    sarah_all = await manager.list_files_for_agent(sarah_agent.id, actor=default_user)
    sarah_all_ids = {entry.file_id for entry in sarah_all}
    assert sarah_all_ids == {default_file.id, another_file.id}

    # Only the open one remains with is_open_only=True.
    sarah_open = await manager.list_files_for_agent(sarah_agent.id, actor=default_user, is_open_only=True)
    sarah_open_ids = {entry.file_id for entry in sarah_open}
    assert sarah_open_ids == {default_file.id}

    # default_file is attached to both agents.
    default_file_links = await manager.list_agents_for_file(default_file.id, actor=default_user)
    linked_agent_ids = {entry.agent_id for entry in default_file_links}
    assert linked_agent_ids == {sarah_agent.id, charles_agent.id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_detach_file(server, file_attachment, default_user):
    """After detach_file, looking the association up again yields None."""
    manager = server.file_agent_manager
    pair = {
        "agent_id": file_attachment.agent_id,
        "file_id": file_attachment.file_id,
        "actor": default_user,
    }

    await manager.detach_file(**pair)
    lookup = await manager.get_file_agent(**pair)

    assert lookup is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_org_scoping(
    server,
    default_user,
    other_user_different_org,
    sarah_agent,
    default_file,
):
    """An association created by one user is not listed for an actor from a different org."""
    manager = server.file_agent_manager

    # attach as default_user
    await manager.attach_file(
        agent_id=sarah_agent.id,
        file_id=default_file.id,
        actor=default_user,
    )

    # other org should see nothing
    visible_to_other_org = await manager.list_files_for_agent(sarah_agent.id, actor=other_user_different_org)
    assert visible_to_other_org == []
|
||||
|
||||
Reference in New Issue
Block a user