diff --git a/alembic/versions/0b496eae90de_add_file_agent_table.py b/alembic/versions/0b496eae90de_add_file_agent_table.py new file mode 100644 index 00000000..91b10781 --- /dev/null +++ b/alembic/versions/0b496eae90de_add_file_agent_table.py @@ -0,0 +1,54 @@ +"""Add file agent table + +Revision ID: 0b496eae90de +Revises: 341068089f14 +Create Date: 2025-06-02 15:14:33.730687 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0b496eae90de" +down_revision: Union[str, None] = "341068089f14" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "files_agents", + sa.Column("id", sa.String(), nullable=False), + sa.Column("file_id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("is_open", sa.Boolean(), nullable=False), + sa.Column("visible_content", sa.Text(), nullable=True), + sa.Column("last_accessed_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["file_id"], ["files.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id", "file_id", "agent_id"), + ) + op.create_index("ix_files_agents_file_id_agent_id", "files_agents", ["file_id", "agent_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_files_agents_file_id_agent_id", table_name="files_agents") + op.drop_table("files_agents") + # ### end Alembic commands ### diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 8f7961bd..7b6076ef 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -5,6 +5,7 @@ from letta.orm.block import Block from letta.orm.block_history import BlockHistory from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata +from letta.orm.files_agents import FileAgent from letta.orm.group import Group from letta.orm.groups_agents import GroupsAgents from letta.orm.groups_blocks import GroupsBlocks diff --git a/letta/orm/file.py b/letta/orm/file.py index 88342700..b27ec7e1 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin): - """Represents metadata for an uploaded file.""" + """Represents an uploaded file.""" __tablename__ = "files" __pydantic_model__ = PydanticFileMetadata diff --git a/letta/orm/files_agents.py b/letta/orm/files_agents.py new file mode 100644 index 00000000..847b6c7e --- /dev/null +++ b/letta/orm/files_agents.py @@ -0,0 +1,42 @@ +import uuid +from datetime import datetime +from typing import TYPE_CHECKING, Optional + +from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.file import FileAgent as PydanticFileAgent + +if TYPE_CHECKING: + pass + + +class FileAgent(SqlalchemyBase, OrganizationMixin): + """ + Join table between File and Agent. + + Tracks whether a file is currently “open” for the agent and + the specific excerpt (grepped section) the agent is looking at. + """ + + __tablename__ = "files_agents" + __table_args__ = (Index("ix_files_agents_file_id_agent_id", "file_id", "agent_id"),) + __pydantic_model__ = PydanticFileAgent + + # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase + # TODO: Some still rely on the Pydantic object to do this + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"file_agent-{uuid.uuid4()}") + file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id", ondelete="CASCADE"), primary_key=True, doc="ID of the file.") + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True, doc="ID of the agent.") + + is_open: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True, doc="True if the agent currently has the file open.") + visible_content: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Portion of the file the agent is focused on.") + last_accessed_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + doc="UTC timestamp when this agent last accessed the file.", + ) diff --git a/letta/schemas/file.py b/letta/schemas/file.py index b43eb64c..f537485d 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -29,3 +29,50 @@ class FileMetadata(FileMetadataBase): created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.") updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the file.") is_deleted: bool = Field(False, description="Whether this file is deleted or not.") + + +class FileAgentBase(LettaBase): + """Base class for the FileMetadata-⇄-Agent association schemas""" + + __id_prefix__ = "file_agent" + + +class FileAgent(FileAgentBase): + """ + A single FileMetadata ⇄ Agent association row. + + Captures: + • whether the agent currently has the file “open” + • the excerpt (grepped section) in the context window + • the last time the agent accessed the file + """ + + id: str = Field( + ..., + description="The internal ID", + ) + organization_id: Optional[str] = Field( + None, + description="Org ID this association belongs to (inherited from both agent and file).", + ) + agent_id: str = Field(..., description="Unique identifier of the agent.") + file_id: str = Field(..., description="Unique identifier of the file.") + is_open: bool = Field(True, description="True if the agent currently has the file open.") + visible_content: Optional[str] = Field( + None, + description="Portion of the file the agent is focused on (may be large).", + ) + last_accessed_at: Optional[datetime] = Field( + default_factory=datetime.utcnow, + description="UTC timestamp of the agent’s most recent access to this file.", + ) + + created_at: Optional[datetime] = Field( + default_factory=datetime.utcnow, + description="Row creation timestamp (UTC).", + ) + updated_at: Optional[datetime] = Field( + default_factory=datetime.utcnow, + description="Row last-update timestamp (UTC).", + ) + is_deleted: bool = Field(False, description="Soft-delete flag.") diff --git a/letta/server/server.py b/letta/server/server.py index 2f07d2ac..aca7d043 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -79,6 +79,7 @@ from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.utils import sse_async_generator from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager +from letta.services.files_agents_manager import FileAgentManager from letta.services.group_manager import GroupManager from letta.services.helpers.tool_execution_helper import prepare_local_sandbox from letta.services.identity_manager import IdentityManager @@ -217,6 +218,7 @@ class SyncServer(Server): self.group_manager = GroupManager() self.batch_manager = LLMBatchManager() self.telemetry_manager = TelemetryManager() + self.file_agent_manager = FileAgentManager() # A resusable httpx client timeout = httpx.Timeout(connect=10.0, read=20.0, write=10.0, pool=10.0) diff --git a/letta/services/files_agents_manager.py b/letta/services/files_agents_manager.py new file mode 100644 index 00000000..155e8322 --- /dev/null +++ b/letta/services/files_agents_manager.py @@ -0,0 +1,184 @@ +from datetime import datetime, timezone +from typing import List, Optional + +from sqlalchemy import and_, func, select, update + +from letta.orm.errors import NoResultFound +from letta.orm.files_agents import FileAgent as FileAgentModel +from letta.schemas.file import FileAgent as PydanticFileAgent +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.tracing import trace_method +from letta.utils import enforce_types + + +class FileAgentManager: + """High-level helpers for CRUD / listing on the `files_agents` join table.""" + + @enforce_types + @trace_method + async def attach_file( + self, + *, + agent_id: str, + file_id: str, + actor: PydanticUser, + is_open: bool = True, + visible_content: Optional[str] = None, + ) -> PydanticFileAgent: + """ + Idempotently attach *file_id* to *agent_id*. + + • If the row already exists → update `is_open`, `visible_content` + and always refresh `last_accessed_at`. + • Otherwise create a brand-new association. + """ + async with db_registry.async_session() as session: + query = select(FileAgentModel).where( + and_( + FileAgentModel.agent_id == agent_id, + FileAgentModel.file_id == file_id, + FileAgentModel.organization_id == actor.organization_id, + ) + ) + existing = await session.scalar(query) + + now_ts = datetime.now(timezone.utc) + + if existing: + # update only the fields that actually changed + if existing.is_open != is_open: + existing.is_open = is_open + + if visible_content is not None and existing.visible_content != visible_content: + existing.visible_content = visible_content + + existing.last_accessed_at = now_ts + + await existing.update_async(session, actor=actor) + return existing.to_pydantic() + + assoc = FileAgentModel( + agent_id=agent_id, + file_id=file_id, + organization_id=actor.organization_id, + is_open=is_open, + visible_content=visible_content, + last_accessed_at=now_ts, + ) + await assoc.create_async(session, actor=actor) + return assoc.to_pydantic() + + @enforce_types + @trace_method + async def update_file_agent( + self, + *, + agent_id: str, + file_id: str, + actor: PydanticUser, + is_open: Optional[bool] = None, + visible_content: Optional[str] = None, + ) -> PydanticFileAgent: + """Patch an existing association row.""" + async with db_registry.async_session() as session: + assoc = await self._get_association(session, agent_id, file_id, actor) + + if is_open is not None: + assoc.is_open = is_open + if visible_content is not None: + assoc.visible_content = visible_content + + # touch timestamp + assoc.last_accessed_at = datetime.now(timezone.utc) + + await assoc.update_async(session, actor=actor) + return assoc.to_pydantic() + + @enforce_types + @trace_method + async def detach_file(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> None: + """Hard-delete the association.""" + async with db_registry.async_session() as session: + assoc = await self._get_association(session, agent_id, file_id, actor) + await assoc.hard_delete_async(session, actor=actor) + + @enforce_types + @trace_method + async def get_file_agent(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> Optional[PydanticFileAgent]: + async with db_registry.async_session() as session: + try: + assoc = await self._get_association(session, agent_id, file_id, actor) + return assoc.to_pydantic() + except NoResultFound: + return None + + @enforce_types + @trace_method + async def list_files_for_agent( + self, + agent_id: str, + actor: PydanticUser, + is_open_only: bool = False, + ) -> List[PydanticFileAgent]: + """Return associations for *agent_id* (filtering by `is_open` if asked).""" + async with db_registry.async_session() as session: + conditions = [ + FileAgentModel.agent_id == agent_id, + FileAgentModel.organization_id == actor.organization_id, + ] + if is_open_only: + conditions.append(FileAgentModel.is_open.is_(True)) + + rows = (await session.execute(select(FileAgentModel).where(and_(*conditions)))).scalars().all() + return [r.to_pydantic() for r in rows] + + @enforce_types + @trace_method + async def list_agents_for_file( + self, + file_id: str, + actor: PydanticUser, + is_open_only: bool = False, + ) -> List[PydanticFileAgent]: + """Return associations for *file_id* (filtering by `is_open` if asked).""" + async with db_registry.async_session() as session: + conditions = [ + FileAgentModel.file_id == file_id, + FileAgentModel.organization_id == actor.organization_id, + ] + if is_open_only: + conditions.append(FileAgentModel.is_open.is_(True)) + + rows = (await session.execute(select(FileAgentModel).where(and_(*conditions)))).scalars().all() + return [r.to_pydantic() for r in rows] + + @enforce_types + @trace_method + async def mark_access(self, *, agent_id: str, file_id: str, actor: PydanticUser) -> None: + """Update only `last_accessed_at = now()` without loading the row.""" + async with db_registry.async_session() as session: + stmt = ( + update(FileAgentModel) + .where( + FileAgentModel.agent_id == agent_id, + FileAgentModel.file_id == file_id, + FileAgentModel.organization_id == actor.organization_id, + ) + .values(last_accessed_at=func.now()) + ) + await session.execute(stmt) + await session.commit() + + async def _get_association(self, session, agent_id: str, file_id: str, actor: PydanticUser) -> FileAgentModel: + q = select(FileAgentModel).where( + and_( + FileAgentModel.agent_id == agent_id, + FileAgentModel.file_id == file_id, + FileAgentModel.organization_id == actor.organization_id, + ) + ) + assoc = await session.scalar(q) + if not assoc: + raise NoResultFound(f"FileAgent(agent_id={agent_id}, file_id={file_id}) not found in org {actor.organization_id}") + return assoc diff --git a/tests/test_managers.py b/tests/test_managers.py index d13c5e4b..010ef49f 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -640,6 +640,27 @@ def event_loop(request): loop.close() +@pytest.fixture +async def file_attachment(server, default_user, sarah_agent, default_file): + assoc = await server.file_agent_manager.attach_file( + agent_id=sarah_agent.id, + file_id=default_file.id, + actor=default_user, + visible_content="initial", + ) + yield assoc + + +@pytest.fixture +async def another_file(server, default_source, default_user, default_organization): + pf = PydanticFileMetadata( + file_name="another_file", + organization_id=default_organization.id, + source_id=default_source.id, + ) + return await server.source_manager.create_file(pf, actor=default_user) + + # ====================================================================================================================== # AgentManager Tests - Basic # ====================================================================================================================== @@ -5754,3 +5775,141 @@ async def test_create_mcp_server(server, default_user, event_loop): assert tool.name == tool_name assert f"mcp:{created_server.server_name}" in tool.tags, f"Expected tag {f'mcp:{created_server.server_name}'}, got {tool.tags}" print("TAGS", tool.tags) + + +# ====================================================================================================================== +# FileAgent Tests +# ====================================================================================================================== + + +@pytest.mark.asyncio +async def test_attach_creates_association(server, default_user, sarah_agent, default_file): + assoc = await server.file_agent_manager.attach_file( + agent_id=sarah_agent.id, + file_id=default_file.id, + actor=default_user, + visible_content="hello", + ) + + assert assoc.agent_id == sarah_agent.id + assert assoc.file_id == default_file.id + assert assoc.is_open is True + assert assoc.visible_content == "hello" + + +@pytest.mark.asyncio +async def test_attach_is_idempotent(server, default_user, sarah_agent, default_file): + a1 = await server.file_agent_manager.attach_file( + agent_id=sarah_agent.id, + file_id=default_file.id, + actor=default_user, + visible_content="first", + ) + + # second attach with different params + a2 = await server.file_agent_manager.attach_file( + agent_id=sarah_agent.id, + file_id=default_file.id, + actor=default_user, + is_open=False, + visible_content="second", + ) + + assert a1.id == a2.id + assert a2.is_open is False + assert a2.visible_content == "second" + + +@pytest.mark.asyncio +async def test_update_file_agent(server, file_attachment, default_user): + updated = await server.file_agent_manager.update_file_agent( + agent_id=file_attachment.agent_id, + file_id=file_attachment.file_id, + actor=default_user, + is_open=False, + visible_content="updated", + ) + assert updated.is_open is False + assert updated.visible_content == "updated" + + +@pytest.mark.asyncio +async def test_mark_access(server, file_attachment, default_user): + old_ts = file_attachment.last_accessed_at + if USING_SQLITE: + time.sleep(CREATE_DELAY_SQLITE) + else: + await asyncio.sleep(0.01) + + await server.file_agent_manager.mark_access( + agent_id=file_attachment.agent_id, + file_id=file_attachment.file_id, + actor=default_user, + ) + refreshed = await server.file_agent_manager.get_file_agent( + agent_id=file_attachment.agent_id, + file_id=file_attachment.file_id, + actor=default_user, + ) + assert refreshed.last_accessed_at > old_ts + + +@pytest.mark.asyncio +async def test_list_files_and_agents( + server, + default_user, + sarah_agent, + charles_agent, + default_file, + another_file, +): + # default_file ↔ charles (open) + await server.file_agent_manager.attach_file(agent_id=charles_agent.id, file_id=default_file.id, actor=default_user) + # default_file ↔ sarah (open) + await server.file_agent_manager.attach_file(agent_id=sarah_agent.id, file_id=default_file.id, actor=default_user) + # another_file ↔ sarah (closed) + await server.file_agent_manager.attach_file(agent_id=sarah_agent.id, file_id=another_file.id, actor=default_user, is_open=False) + + files_for_sarah = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user) + assert {f.file_id for f in files_for_sarah} == {default_file.id, another_file.id} + + open_only = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user, is_open_only=True) + assert {f.file_id for f in open_only} == {default_file.id} + + agents_for_default = await server.file_agent_manager.list_agents_for_file(default_file.id, actor=default_user) + assert {a.agent_id for a in agents_for_default} == {sarah_agent.id, charles_agent.id} + + +@pytest.mark.asyncio +async def test_detach_file(server, file_attachment, default_user): + await server.file_agent_manager.detach_file( + agent_id=file_attachment.agent_id, + file_id=file_attachment.file_id, + actor=default_user, + ) + res = await server.file_agent_manager.get_file_agent( + agent_id=file_attachment.agent_id, + file_id=file_attachment.file_id, + actor=default_user, + ) + assert res is None + + +@pytest.mark.asyncio +async def test_org_scoping( + server, + default_user, + other_user_different_org, + sarah_agent, + default_file, +): + # attach as default_user + await server.file_agent_manager.attach_file( + agent_id=sarah_agent.id, + file_id=default_file.id, + actor=default_user, + ) + + # other org should see nothing + files = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=other_user_different_org) + assert files == []