From 10f6f1d247c038eeeb466164fc383c5f235183cf Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 1 Aug 2025 23:34:49 -0700 Subject: [PATCH] feat: Implement archival sharing (#3689) --- ...4e860718e0d_add_archival_memory_sharing.py | 387 ++++++++++++++++++ letta/agents/voice_sleeptime_agent.py | 1 - letta/functions/function_sets/base.py | 1 - letta/orm/__init__.py | 4 +- letta/orm/agent.py | 8 + letta/orm/archive.py | 87 ++++ letta/orm/archives_agents.py | 27 ++ letta/orm/mixins.py | 8 + letta/orm/organization.py | 15 +- letta/orm/passage.py | 22 +- letta/schemas/archive.py | 44 ++ letta/schemas/passage.py | 6 +- letta/server/server.py | 37 +- letta/services/agent_manager.py | 49 ++- letta/services/archive_manager.py | 269 ++++++++++++ .../services/helpers/agent_manager_helper.py | 78 ++-- letta/services/passage_manager.py | 152 ++++--- .../tool_executor/core_tool_executor.py | 1 - tests/data/list_tools.json | 2 +- tests/test_managers.py | 137 +++++-- 20 files changed, 1140 insertions(+), 195 deletions(-) create mode 100644 alembic/versions/74e860718e0d_add_archival_memory_sharing.py create mode 100644 letta/orm/archive.py create mode 100644 letta/orm/archives_agents.py create mode 100644 letta/schemas/archive.py create mode 100644 letta/services/archive_manager.py diff --git a/alembic/versions/74e860718e0d_add_archival_memory_sharing.py b/alembic/versions/74e860718e0d_add_archival_memory_sharing.py new file mode 100644 index 00000000..7d38c113 --- /dev/null +++ b/alembic/versions/74e860718e0d_add_archival_memory_sharing.py @@ -0,0 +1,387 @@ +"""add archival memory sharing + +Revision ID: 74e860718e0d +Revises: 4c6c9ef0387d +Create Date: 2025-07-30 16:15:49.424711 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# Import custom columns if needed +try: + from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn +except ImportError: + # For environments where these aren't available + EmbeddingConfigColumn = sa.JSON + CommonVector = sa.BLOB + +# revision identifiers, used by Alembic. 
+revision: str = "74e860718e0d" +down_revision: Union[str, None] = "15b577c62f3f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # get database connection to check DB type + bind = op.get_bind() + is_sqlite = bind.dialect.name == "sqlite" + + # create new tables with appropriate defaults + if is_sqlite: + op.create_table( + "archives", + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("0"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name", "organization_id", name="unique_archive_name_per_org"), + ) + else: + op.create_table( + "archives", + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("metadata_", sa.JSON(), nullable=True), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name", "organization_id", name="unique_archive_name_per_org"), + ) + + op.create_index("ix_archives_created_at", "archives", ["created_at", "id"], unique=False) + op.create_index("ix_archives_organization_id", "archives", ["organization_id"], unique=False) + + if is_sqlite: + op.create_table( + "archives_agents", + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("archive_id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("datetime('now')"), nullable=False), + sa.Column("is_owner", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["archive_id"], ["archives.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("agent_id", "archive_id"), + # TODO: Remove this constraint when we support multiple archives per agent + sa.UniqueConstraint("agent_id", name="unique_agent_archive"), + ) + else: + op.create_table( + "archives_agents", + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("archive_id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("is_owner", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["archive_id"], ["archives.id"], ondelete="CASCADE"), + 
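+            # composite primary key below: a link row is identified by (agent_id, archive_id),
+            # while the CASCADE foreign keys above remove link rows when either side is deleted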
sa.PrimaryKeyConstraint("agent_id", "archive_id"), + # TODO: Remove this constraint when we support multiple archives per agent + sa.UniqueConstraint("agent_id", name="unique_agent_archive"), + ) + + if is_sqlite: + # For SQLite + # create temporary table to preserve existing agent_passages data + op.execute( + """ + CREATE TEMPORARY TABLE temp_agent_passages AS + SELECT * FROM agent_passages WHERE is_deleted = 0; + """ + ) + + # create default archives and migrate data + # First, create archives for each agent that has passages + op.execute( + """ + INSERT INTO archives (id, name, description, organization_id, created_at, updated_at, is_deleted) + SELECT DISTINCT + 'archive-' || lower(hex(randomblob(16))), + COALESCE(a.name, 'Agent ' || a.id) || '''s Archive', + 'Default archive created during migration', + a.organization_id, + datetime('now'), + datetime('now'), + 0 + FROM temp_agent_passages ap + JOIN agents a ON ap.agent_id = a.id; + """ + ) + + # create archives_agents relationships + op.execute( + """ + INSERT INTO archives_agents (agent_id, archive_id, is_owner, created_at) + SELECT + a.id as agent_id, + ar.id as archive_id, + 1 as is_owner, + datetime('now') as created_at + FROM agents a + JOIN archives ar ON ar.organization_id = a.organization_id + AND ar.name = COALESCE(a.name, 'Agent ' || a.id) || '''s Archive' + WHERE EXISTS ( + SELECT 1 FROM temp_agent_passages ap WHERE ap.agent_id = a.id + ); + """ + ) + + # drop the old agent_passages table + op.drop_index("ix_agent_passages_org_agent", table_name="agent_passages") + op.drop_table("agent_passages") + + # create the new archival_passages table with the new schema + op.create_table( + "archival_passages", + sa.Column("text", sa.String(), nullable=False), + sa.Column("embedding_config", EmbeddingConfigColumn, nullable=False), + sa.Column("metadata_", sa.JSON(), nullable=False), + sa.Column("embedding", CommonVector, nullable=True), # SQLite uses CommonVector for embeddings + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("0"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("archive_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint(["archive_id"], ["archives.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + + # migrate data from temp table to archival_passages with archive_id + op.execute( + """ + INSERT INTO archival_passages ( + id, text, embedding_config, metadata_, embedding, + created_at, updated_at, is_deleted, + _created_by_id, _last_updated_by_id, + organization_id, archive_id + ) + SELECT + ap.id, ap.text, ap.embedding_config, ap.metadata_, ap.embedding, + ap.created_at, ap.updated_at, ap.is_deleted, + ap._created_by_id, ap._last_updated_by_id, + ap.organization_id, ar.id as archive_id + FROM temp_agent_passages ap + JOIN agents a ON ap.agent_id = a.id + JOIN archives ar ON ar.organization_id = a.organization_id + AND ar.name = COALESCE(a.name, 'Agent ' || a.id) || '''s Archive'; + """ + ) + + # drop temporary table + op.execute("DROP TABLE temp_agent_passages;") + + # create indexes + op.create_index("ix_archival_passages_archive_id", 
"archival_passages", ["archive_id"]) + op.create_index("ix_archival_passages_org_archive", "archival_passages", ["organization_id", "archive_id"]) + op.create_index("archival_passages_created_at_id_idx", "archival_passages", ["created_at", "id"]) + + else: + # PostgreSQL + # add archive_id to agent_passages + op.add_column("agent_passages", sa.Column("archive_id", sa.String(), nullable=True)) + + # create default archives and migrate data + op.execute( + """ + -- Create a unique archive for each agent that has passages + WITH agent_archives AS ( + INSERT INTO archives (id, name, description, organization_id, created_at) + SELECT DISTINCT + 'archive-' || gen_random_uuid(), + COALESCE(a.name, 'Agent ' || a.id) || '''s Archive', + 'Default archive created during migration', + a.organization_id, + NOW() + FROM agent_passages ap + JOIN agents a ON ap.agent_id = a.id + WHERE ap.is_deleted = FALSE + RETURNING id as archive_id, + organization_id, + SUBSTRING(name FROM 1 FOR LENGTH(name) - LENGTH('''s Archive')) as agent_name + ) + -- Create archives_agents relationships + INSERT INTO archives_agents (agent_id, archive_id, is_owner, created_at) + SELECT + a.id as agent_id, + aa.archive_id, + TRUE, + NOW() + FROM agent_archives aa + JOIN agents a ON a.organization_id = aa.organization_id + AND (a.name = aa.agent_name OR ('Agent ' || a.id) = aa.agent_name); + """ + ) + + # update agent_passages with archive_id + op.execute( + """ + UPDATE agent_passages ap + SET archive_id = ar.id + FROM agents a + JOIN archives ar ON ar.organization_id = a.organization_id + AND ar.name = COALESCE(a.name, 'Agent ' || a.id) || '''s Archive' + WHERE ap.agent_id = a.id; + """ + ) + + # schema changes + op.alter_column("agent_passages", "archive_id", nullable=False) + op.create_foreign_key("agent_passages_archive_id_fkey", "agent_passages", "archives", ["archive_id"], ["id"], ondelete="CASCADE") + + # drop old indexes and constraints + op.drop_index("ix_agent_passages_org_agent", table_name="agent_passages") + op.drop_index("agent_passages_org_idx", table_name="agent_passages") + op.drop_index("agent_passages_created_at_id_idx", table_name="agent_passages") + op.drop_constraint("agent_passages_agent_id_fkey", "agent_passages", type_="foreignkey") + op.drop_column("agent_passages", "agent_id") + + # rename table and create new indexes + op.rename_table("agent_passages", "archival_passages") + op.create_index("ix_archival_passages_archive_id", "archival_passages", ["archive_id"]) + op.create_index("ix_archival_passages_org_archive", "archival_passages", ["organization_id", "archive_id"]) + op.create_index("archival_passages_org_idx", "archival_passages", ["organization_id"]) + op.create_index("archival_passages_created_at_id_idx", "archival_passages", ["created_at", "id"]) + + +def downgrade() -> None: + # Get database connection to check DB type + bind = op.get_bind() + is_sqlite = bind.dialect.name == "sqlite" + + if is_sqlite: + # For SQLite, we need to migrate data back carefully + # create temporary table to preserve existing archival_passages data + op.execute( + """ + CREATE TEMPORARY TABLE temp_archival_passages AS + SELECT * FROM archival_passages WHERE is_deleted = 0; + """ + ) + + # drop the archival_passages table and indexes + op.drop_index("ix_archival_passages_org_archive", table_name="archival_passages") + op.drop_index("ix_archival_passages_archive_id", table_name="archival_passages") + op.drop_index("archival_passages_created_at_id_idx", table_name="archival_passages") + op.drop_table("archival_passages") 
+ + # recreate agent_passages with old schema + op.create_table( + "agent_passages", + sa.Column("text", sa.String(), nullable=False), + sa.Column("embedding_config", EmbeddingConfigColumn, nullable=False), + sa.Column("metadata_", sa.JSON(), nullable=False), + sa.Column("embedding", CommonVector, nullable=True), # SQLite uses CommonVector for embeddings + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("0"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + + # restore data from archival_passages back to agent_passages + # use the owner relationship from archives_agents to determine agent_id + op.execute( + """ + INSERT INTO agent_passages ( + id, text, embedding_config, metadata_, embedding, + created_at, updated_at, is_deleted, + _created_by_id, _last_updated_by_id, + organization_id, agent_id + ) + SELECT + ap.id, ap.text, ap.embedding_config, ap.metadata_, ap.embedding, + ap.created_at, ap.updated_at, ap.is_deleted, + ap._created_by_id, ap._last_updated_by_id, + ap.organization_id, aa.agent_id + FROM temp_archival_passages ap + JOIN archives_agents aa ON ap.archive_id = aa.archive_id AND aa.is_owner = 1; + """ + ) + + # drop temporary table + op.execute("DROP TABLE temp_archival_passages;") + + # create original indexes + op.create_index("ix_agent_passages_org_agent", "agent_passages", ["organization_id", "agent_id"]) + op.create_index("agent_passages_org_idx", "agent_passages", ["organization_id"]) + op.create_index("agent_passages_created_at_id_idx", "agent_passages", ["created_at", "id"]) + else: + # PostgreSQL: + # rename table back + op.drop_index("ix_archival_passages_org_archive", table_name="archival_passages") + op.drop_index("ix_archival_passages_archive_id", table_name="archival_passages") + op.drop_index("archival_passages_org_idx", table_name="archival_passages") + op.drop_index("archival_passages_created_at_id_idx", table_name="archival_passages") + op.rename_table("archival_passages", "agent_passages") + + # add agent_id column back + op.add_column("agent_passages", sa.Column("agent_id", sa.String(), nullable=True)) + + # restore agent_id from archives_agents (use the owner relationship) + op.execute( + """ + UPDATE agent_passages ap + SET agent_id = aa.agent_id + FROM archives_agents aa + WHERE ap.archive_id = aa.archive_id AND aa.is_owner = TRUE; + """ + ) + + # schema changes + op.alter_column("agent_passages", "agent_id", nullable=False) + op.create_foreign_key("agent_passages_agent_id_fkey", "agent_passages", "agents", ["agent_id"], ["id"], ondelete="CASCADE") + + # drop archive_id column and constraint + op.drop_constraint("agent_passages_archive_id_fkey", "agent_passages", type_="foreignkey") + op.drop_column("agent_passages", "archive_id") + + # restore original indexes + op.create_index("ix_agent_passages_org_agent", "agent_passages", ["organization_id", "agent_id"]) + op.create_index("agent_passages_org_idx", "agent_passages", ["organization_id"]) + 
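+        # created_at alone is not unique, so (created_at, id) below gives the stable
+        # total order that cursor pagination relies on (e.g. WHERE (created_at, id) > (:ts, :id))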
op.create_index("agent_passages_created_at_id_idx", "agent_passages", ["created_at", "id"]) + + # drop new tables (same for both) + op.drop_table("archives_agents") + op.drop_index("ix_archives_organization_id", table_name="archives") + op.drop_index("ix_archives_created_at", table_name="archives") + op.drop_table("archives") diff --git a/letta/agents/voice_sleeptime_agent.py b/letta/agents/voice_sleeptime_agent.py index e9d013f9..359c49cd 100644 --- a/letta/agents/voice_sleeptime_agent.py +++ b/letta/agents/voice_sleeptime_agent.py @@ -166,7 +166,6 @@ class VoiceSleeptimeAgent(LettaAgent): memory = serialize_message_history(messages, context) self.agent_manager.passage_manager.insert_passage( agent_state=agent_state, - agent_id=agent_state.id, text=memory, actor=self.actor, ) diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 3ecbf42e..de61c411 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -75,7 +75,6 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: """ self.passage_manager.insert_passage( agent_state=self.agent_state, - agent_id=self.agent_state.id, text=content, actor=self.user, ) diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 3b65941e..b84923a3 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -1,5 +1,7 @@ from letta.orm.agent import Agent from letta.orm.agents_tags import AgentsTags +from letta.orm.archive import Archive +from letta.orm.archives_agents import ArchivesAgents from letta.orm.base import Base from letta.orm.block import Block from letta.orm.block_history import BlockHistory @@ -19,7 +21,7 @@ from letta.orm.llm_batch_job import LLMBatchJob from letta.orm.mcp_server import MCPServer from letta.orm.message import Message from letta.orm.organization import Organization -from letta.orm.passage import AgentPassage, BasePassage, SourcePassage +from letta.orm.passage import ArchivalPassage, BasePassage, SourcePassage from letta.orm.prompt import Prompt from letta.orm.provider import Provider from letta.orm.provider_trace import ProviderTrace diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 2b8b5f1f..81d5efbe 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -24,6 +24,7 @@ from letta.utils import calculate_file_defaults_based_on_context_window if TYPE_CHECKING: from letta.orm.agents_tags import AgentsTags + from letta.orm.archives_agents import ArchivesAgents from letta.orm.files_agents import FileAgent from letta.orm.identity import Identity from letta.orm.organization import Organization @@ -156,6 +157,13 @@ class Agent(SqlalchemyBase, OrganizationMixin, ProjectMixin, AsyncAttrs): cascade="all, delete-orphan", lazy="selectin", ) + archives_agents: Mapped[List["ArchivesAgents"]] = relationship( + "ArchivesAgents", + back_populates="agent", + cascade="all, delete-orphan", + lazy="noload", + doc="Archives accessible by this agent.", + ) def _get_per_file_view_window_char_limit(self) -> int: """Get the per_file_view_window_char_limit, calculating defaults if None.""" diff --git a/letta/orm/archive.py b/letta/orm/archive.py new file mode 100644 index 00000000..e8d89f63 --- /dev/null +++ b/letta/orm/archive.py @@ -0,0 +1,87 @@ +import uuid +from datetime import datetime, timezone +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy import JSON, Index, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin 
+from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.archive import Archive as PydanticArchive +from letta.settings import DatabaseChoice, settings + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm import Session + + from letta.orm.archives_agents import ArchivesAgents + from letta.orm.organization import Organization + from letta.schemas.user import User + + +class Archive(SqlalchemyBase, OrganizationMixin): + """An archive represents a collection of archival passages that can be shared between agents""" + + __tablename__ = "archives" + __pydantic_model__ = PydanticArchive + + __table_args__ = ( + UniqueConstraint("name", "organization_id", name="unique_archive_name_per_org"), + Index("ix_archives_created_at", "created_at", "id"), + Index("ix_archives_organization_id", "organization_id"), + ) + + # archive generates its own id + # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase + # TODO: Some still rely on the Pydantic object to do this + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"archive-{uuid.uuid4()}") + + # archive-specific fields + name: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the archive") + description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="A description of the archive") + metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Additional metadata for the archive") + + # relationships + archives_agents: Mapped[List["ArchivesAgents"]] = relationship( + "ArchivesAgents", + back_populates="archive", + cascade="all, delete-orphan", # this will delete junction entries when archive is deleted + lazy="noload", + ) + + organization: Mapped["Organization"] = relationship("Organization", back_populates="archives", lazy="selectin") + + def create( + self, + db_session: "Session", + actor: Optional["User"] = None, + no_commit: bool = False, + ) -> "Archive": + """Override create to handle SQLite timestamp issues""" + # For SQLite, explicitly set timestamps as server_default may not work + if settings.database_engine == DatabaseChoice.SQLITE: + now = datetime.now(timezone.utc) + if not self.created_at: + self.created_at = now + if not self.updated_at: + self.updated_at = now + + return super().create(db_session, actor=actor, no_commit=no_commit) + + async def create_async( + self, + db_session: "AsyncSession", + actor: Optional["User"] = None, + no_commit: bool = False, + no_refresh: bool = False, + ) -> "Archive": + """Override create_async to handle SQLite timestamp issues""" + # For SQLite, explicitly set timestamps as server_default may not work + if settings.database_engine == DatabaseChoice.SQLITE: + now = datetime.now(timezone.utc) + if not self.created_at: + self.created_at = now + if not self.updated_at: + self.updated_at = now + + return await super().create_async(db_session, actor=actor, no_commit=no_commit, no_refresh=no_refresh) diff --git a/letta/orm/archives_agents.py b/letta/orm/archives_agents.py new file mode 100644 index 00000000..06c63a5e --- /dev/null +++ b/letta/orm/archives_agents.py @@ -0,0 +1,27 @@ +from datetime import datetime + +from sqlalchemy import Boolean, DateTime, ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.base import Base + + +class ArchivesAgents(Base): + """Many-to-many relationship between agents and archives""" + + __tablename__ = "archives_agents" + + 
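+    # association-object pattern: unlike a plain secondary table, this mapped class
+    # carries extra state on the link itself (created_at and is_owner below)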
# TODO: Remove this unique constraint when we support multiple archives per agent + # For now, each agent can only have one archive + __table_args__ = (UniqueConstraint("agent_id", name="unique_agent_archive"),) + + agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True) + archive_id: Mapped[str] = mapped_column(String, ForeignKey("archives.id", ondelete="CASCADE"), primary_key=True) + + # track when the relationship was created and if agent is owner + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default="now()") + is_owner: Mapped[bool] = mapped_column(Boolean, default=False, doc="Whether this agent created/owns the archive") + + # relationships + agent: Mapped["Agent"] = relationship("Agent", back_populates="archives_agents") + archive: Mapped["Archive"] = relationship("Archive", back_populates="archives_agents") diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index 0ce7042c..13848f17 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -70,3 +70,11 @@ class ProjectMixin(Base): __abstract__ = True project_id: Mapped[str] = mapped_column(String, nullable=True, doc="The associated project id.") + + +class ArchiveMixin(Base): + """Mixin for models that belong to an archive.""" + + __abstract__ = True + + archive_id: Mapped[str] = mapped_column(String, ForeignKey("archives.id", ondelete="CASCADE")) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 4f78ec89..6e63df5a 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, List from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -8,13 +8,14 @@ from letta.schemas.organization import Organization as PydanticOrganization if TYPE_CHECKING: from letta.orm import Source from letta.orm.agent import Agent + from letta.orm.archive import Archive from letta.orm.block import Block from letta.orm.group import Group from letta.orm.identity import Identity from letta.orm.llm_batch_items import LLMBatchItem from letta.orm.llm_batch_job import LLMBatchJob from letta.orm.message import Message - from letta.orm.passage import AgentPassage, SourcePassage + from letta.orm.passage import ArchivalPassage, SourcePassage from letta.orm.provider import Provider from letta.orm.sandbox_config import AgentEnvironmentVariable, SandboxConfig, SandboxEnvironmentVariable from letta.orm.tool import Tool @@ -52,7 +53,10 @@ class Organization(SqlalchemyBase): source_passages: Mapped[List["SourcePassage"]] = relationship( "SourcePassage", back_populates="organization", cascade="all, delete-orphan" ) - agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="organization", cascade="all, delete-orphan") + archival_passages: Mapped[List["ArchivalPassage"]] = relationship( + "ArchivalPassage", back_populates="organization", cascade="all, delete-orphan" + ) + archives: Mapped[List["Archive"]] = relationship("Archive", back_populates="organization", cascade="all, delete-orphan") providers: Mapped[List["Provider"]] = relationship("Provider", back_populates="organization", cascade="all, delete-orphan") identities: Mapped[List["Identity"]] = relationship("Identity", back_populates="organization", cascade="all, delete-orphan") groups: Mapped[List["Group"]] = relationship("Group", back_populates="organization", cascade="all, delete-orphan") @@ -60,8 +64,3 @@ class Organization(SqlalchemyBase): 
llm_batch_items: Mapped[List["LLMBatchItem"]] = relationship( "LLMBatchItem", back_populates="organization", cascade="all, delete-orphan" ) - - @property - def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: - """Convenience property to get all passages""" - return self.source_passages + self.agent_passages diff --git a/letta/orm/passage.py b/letta/orm/passage.py index 1a6c48a2..9507ffc0 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn -from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin +from letta.orm.mixins import ArchiveMixin, FileMixin, OrganizationMixin, SourceMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.passage import Passage as PydanticPassage from letta.settings import DatabaseChoice, settings @@ -70,26 +70,28 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin): ) -class AgentPassage(BasePassage, AgentMixin): - """Passages created by agents as archival memories""" +class ArchivalPassage(BasePassage, ArchiveMixin): + """Passages stored in archives as archival memories""" - __tablename__ = "agent_passages" + __tablename__ = "archival_passages" @declared_attr def organization(cls) -> Mapped["Organization"]: - return relationship("Organization", back_populates="agent_passages", lazy="selectin") + return relationship("Organization", back_populates="archival_passages", lazy="selectin") @declared_attr def __table_args__(cls): if settings.database_engine is DatabaseChoice.POSTGRES: return ( - Index("agent_passages_org_idx", "organization_id"), - Index("ix_agent_passages_org_agent", "organization_id", "agent_id"), - Index("agent_passages_created_at_id_idx", "created_at", "id"), + Index("archival_passages_org_idx", "organization_id"), + Index("ix_archival_passages_org_archive", "organization_id", "archive_id"), + Index("archival_passages_created_at_id_idx", "created_at", "id"), + Index("ix_archival_passages_archive_id", "archive_id"), {"extend_existing": True}, ) return ( - Index("ix_agent_passages_org_agent", "organization_id", "agent_id"), - Index("agent_passages_created_at_id_idx", "created_at", "id"), + Index("ix_archival_passages_org_archive", "organization_id", "archive_id"), + Index("archival_passages_created_at_id_idx", "created_at", "id"), + Index("ix_archival_passages_archive_id", "archive_id"), {"extend_existing": True}, ) diff --git a/letta/schemas/archive.py b/letta/schemas/archive.py new file mode 100644 index 00000000..965708bb --- /dev/null +++ b/letta/schemas/archive.py @@ -0,0 +1,44 @@ +from datetime import datetime +from typing import Dict, Optional + +from pydantic import Field + +from letta.schemas.letta_base import OrmMetadataBase + + +class ArchiveBase(OrmMetadataBase): + __id_prefix__ = "archive" + + name: str = Field(..., description="The name of the archive") + description: Optional[str] = Field(None, description="A description of the archive") + organization_id: str = Field(..., description="The organization this archive belongs to") + metadata: Optional[Dict] = Field(default_factory=dict, validation_alias="metadata_", description="Additional metadata") + + +class Archive(ArchiveBase): + """ + Representation of an archive - a collection of archival passages that can be shared between agents. 
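+
+    Agents gain access to an archive through rows in the archives_agents link
+    table; the agent that created the archive is flagged there as its owner.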
+ + Parameters: + id (str): The unique identifier of the archive. + name (str): The name of the archive. + description (str): A description of the archive. + organization_id (str): The organization this archive belongs to. + created_at (datetime): The creation date of the archive. + metadata (dict): Additional metadata for the archive. + """ + + id: str = ArchiveBase.generate_id_field() + created_at: datetime = Field(..., description="The creation date of the archive") + + +class ArchiveCreate(ArchiveBase): + """Create a new archive""" + + +class ArchiveUpdate(ArchiveBase): + """Update an existing archive""" + + name: Optional[str] = Field(None, description="The name of the archive") + description: Optional[str] = Field(None, description="A description of the archive") + metadata: Optional[Dict] = Field(None, validation_alias="metadata_", description="Additional metadata") diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py index da87dd0f..57ab3f3c 100644 --- a/letta/schemas/passage.py +++ b/letta/schemas/passage.py @@ -16,7 +16,7 @@ class PassageBase(OrmMetadataBase): # associated user/agent organization_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.") - agent_id: Optional[str] = Field(None, description="The unique identifier of the agent associated with the passage.") + archive_id: Optional[str] = Field(None, description="The unique identifier of the archive containing this passage.") # origin data source source_id: Optional[str] = Field(None, description="The data source of the passage.") @@ -36,8 +36,8 @@ class Passage(PassageBase): embedding (List[float]): The embedding of the passage. embedding_config (EmbeddingConfig): The embedding configuration used by the passage. created_at (datetime): The creation date of the passage. - user_id (str): The unique identifier of the user associated with the passage. - agent_id (str): The unique identifier of the agent associated with the passage. + organization_id (str): The unique identifier of the organization associated with the passage. + archive_id (str): The unique identifier of the archive containing this passage. source_id (str): The data source of the passage. file_id (str): The unique identifier of the file associated with the passage. 
""" diff --git a/letta/server/server.py b/letta/server/server.py index 83bcd222..23158e44 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -80,6 +80,7 @@ from letta.server.rest_api.interface import StreamingServerInterface from letta.server.rest_api.utils import sse_async_generator from letta.services.agent_manager import AgentManager from letta.services.agent_serialization_manager import AgentSerializationManager +from letta.services.archive_manager import ArchiveManager from letta.services.block_manager import BlockManager from letta.services.file_manager import FileManager from letta.services.files_agents_manager import FileAgentManager @@ -215,6 +216,7 @@ class SyncServer(Server): self.message_manager = MessageManager() self.job_manager = JobManager() self.agent_manager = AgentManager() + self.archive_manager = ArchiveManager() self.provider_manager = ProviderManager() self.step_manager = StepManager() self.identity_manager = IdentityManager() @@ -1146,29 +1148,12 @@ class SyncServer(Server): ) return records - def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]: - # Get the agent object (loaded in memory) - agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - # Insert into archival memory - # TODO: @mindy look at moving this to agent_manager to avoid above extra call - passages = self.passage_manager.insert_passage(agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor) - - # rebuild agent system prompt - force since no archival change - self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True) - - return passages - async def insert_archival_memory_async(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]: # Get the agent object (loaded in memory) agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) - # Insert into archival memory - # TODO: @mindy look at moving this to agent_manager to avoid above extra call - passages = await self.passage_manager.insert_passage_async( - agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor - ) - # rebuild agent system prompt - force since no archival change - await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True) + # Insert passages into the archive + passages = await self.passage_manager.insert_passage_async(agent_state=agent_state, text=memory_contents, actor=actor) return passages @@ -1177,17 +1162,6 @@ class SyncServer(Server): passages = self.passage_manager.update_passage_by_id(passage_id=memory_id, passage=passage, actor=actor) return passages - def delete_archival_memory(self, memory_id: str, actor: User): - # TODO check if it exists first, and throw error if not - # TODO: need to also rebuild the prompt here - passage = self.passage_manager.get_passage_by_id(passage_id=memory_id, actor=actor) - - # delete the passage - self.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor) - - # rebuild system prompt and force - self.agent_manager.rebuild_system_prompt(agent_id=passage.agent_id, actor=actor, force=True) - async def delete_archival_memory_async(self, memory_id: str, actor: User): # TODO check if it exists first, and throw error if not # TODO: need to also rebuild the prompt here @@ -1196,9 +1170,6 @@ class SyncServer(Server): # delete the passage await self.passage_manager.delete_passage_by_id_async(passage_id=memory_id, actor=actor) - # rebuild 
system prompt and force - await self.agent_manager.rebuild_system_prompt_async(agent_id=passage.agent_id, actor=actor, force=True) - def get_agent_recall( self, user_id: str, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d84d6a29..fdb8171c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -26,7 +26,7 @@ from letta.helpers.datetime_helpers import get_utc_time from letta.llm_api.llm_client import LLMClient from letta.log import get_logger from letta.orm import Agent as AgentModel -from letta.orm import AgentPassage, AgentsTags +from letta.orm import AgentsTags, ArchivalPassage from letta.orm import Block as BlockModel from letta.orm import BlocksAgents from letta.orm import Group as GroupModel @@ -1296,6 +1296,19 @@ class AgentManager: agent = AgentModel.read(db_session=session, name=agent_name, actor=actor) return agent.to_pydantic() + @enforce_types + @trace_method + async def get_agent_archive_ids_async(self, agent_id: str, actor: PydanticUser) -> List[str]: + """Get all archive IDs associated with an agent.""" + from letta.orm import ArchivesAgents + + async with db_registry.async_session() as session: + # Direct query to archives_agents table for performance + query = select(ArchivesAgents.archive_id).where(ArchivesAgents.agent_id == agent_id) + result = await session.execute(query) + archive_ids = [row[0] for row in result.fetchall()] + return archive_ids + @enforce_types @trace_method def delete_agent(self, agent_id: str, actor: PydanticUser) -> None: @@ -2342,21 +2355,24 @@ class AgentManager: main_query = main_query.limit(limit) # Execute query - results = list(session.execute(main_query)) + result = session.execute(main_query) passages = [] - for row in results: + for row in result: data = dict(row._mapping) - if data["agent_id"] is not None: - # This is an AgentPassage - remove source fields + if data.get("archive_id", None): + # This is an ArchivalPassage - remove source fields data.pop("source_id", None) data.pop("file_id", None) data.pop("file_name", None) - passage = AgentPassage(**data) - else: - # This is a SourcePassage - remove agent field - data.pop("agent_id", None) + passage = ArchivalPassage(**data) + elif data.get("source_id", None): + # This is a SourcePassage - remove archive field + data.pop("archive_id", None) + data.pop("agent_id", None) # For backward compatibility passage = SourcePassage(**data) + else: + raise ValueError(f"Passage data is malformed, is neither ArchivalPassage nor SourcePassage {data}") passages.append(passage) return [p.to_pydantic() for p in passages] @@ -2408,16 +2424,19 @@ class AgentManager: passages = [] for row in result: data = dict(row._mapping) - if data["agent_id"] is not None: - # This is an AgentPassage - remove source fields + if data.get("archive_id", None): + # This is an ArchivalPassage - remove source fields data.pop("source_id", None) data.pop("file_id", None) data.pop("file_name", None) - passage = AgentPassage(**data) - else: - # This is a SourcePassage - remove agent field - data.pop("agent_id", None) + passage = ArchivalPassage(**data) + elif data.get("source_id", None): + # This is a SourcePassage - remove archive field + data.pop("archive_id", None) + data.pop("agent_id", None) # For backward compatibility passage = SourcePassage(**data) + else: + raise ValueError(f"Passage data is malformed, is neither ArchivalPassage nor SourcePassage {data}") passages.append(passage) return [p.to_pydantic() for p in passages] diff --git 
a/letta/services/archive_manager.py b/letta/services/archive_manager.py new file mode 100644 index 00000000..86a9e546 --- /dev/null +++ b/letta/services/archive_manager.py @@ -0,0 +1,269 @@ +from typing import List, Optional + +from sqlalchemy import select + +from letta.log import get_logger +from letta.orm import ArchivalPassage +from letta.orm import Archive as ArchiveModel +from letta.orm import ArchivesAgents +from letta.schemas.archive import Archive as PydanticArchive +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.utils import enforce_types + +logger = get_logger(__name__) + + +class ArchiveManager: + """Manager class to handle business logic related to Archives.""" + + @enforce_types + def create_archive( + self, + name: str, + description: Optional[str] = None, + actor: PydanticUser = None, + ) -> PydanticArchive: + """Create a new archive.""" + try: + with db_registry.session() as session: + archive = ArchiveModel( + name=name, + description=description, + organization_id=actor.organization_id, + ) + archive.create(session, actor=actor) + return archive.to_pydantic() + except Exception as e: + logger.exception(f"Failed to create archive {name}. error={e}") + raise + + @enforce_types + async def create_archive_async( + self, + name: str, + description: Optional[str] = None, + actor: PydanticUser = None, + ) -> PydanticArchive: + """Create a new archive.""" + try: + async with db_registry.async_session() as session: + archive = ArchiveModel( + name=name, + description=description, + organization_id=actor.organization_id, + ) + await archive.create_async(session, actor=actor) + return archive.to_pydantic() + except Exception as e: + logger.exception(f"Failed to create archive {name}. 
error={e}") + raise + + @enforce_types + async def get_archive_by_id_async( + self, + archive_id: str, + actor: PydanticUser, + ) -> PydanticArchive: + """Get an archive by ID.""" + async with db_registry.async_session() as session: + archive = await ArchiveModel.read_async( + db_session=session, + identifier=archive_id, + actor=actor, + ) + return archive.to_pydantic() + + @enforce_types + def attach_agent_to_archive( + self, + agent_id: str, + archive_id: str, + is_owner: bool, + actor: PydanticUser, + ) -> None: + """Attach an agent to an archive.""" + with db_registry.session() as session: + # Check if already attached + existing = session.query(ArchivesAgents).filter_by(agent_id=agent_id, archive_id=archive_id).first() + + if existing: + # Update ownership if needed + if existing.is_owner != is_owner: + existing.is_owner = is_owner + session.commit() + return + + # Create new relationship + archives_agents = ArchivesAgents( + agent_id=agent_id, + archive_id=archive_id, + is_owner=is_owner, + ) + session.add(archives_agents) + session.commit() + + @enforce_types + async def attach_agent_to_archive_async( + self, + agent_id: str, + archive_id: str, + is_owner: bool = False, + actor: PydanticUser = None, + ) -> None: + """Attach an agent to an archive.""" + async with db_registry.async_session() as session: + # Check if relationship already exists + existing = await session.execute( + select(ArchivesAgents).where( + ArchivesAgents.agent_id == agent_id, + ArchivesAgents.archive_id == archive_id, + ) + ) + existing_record = existing.scalar_one_or_none() + + if existing_record: + # Update ownership if needed + if existing_record.is_owner != is_owner: + existing_record.is_owner = is_owner + await session.commit() + return + + # Create the relationship + archives_agents = ArchivesAgents( + agent_id=agent_id, + archive_id=archive_id, + is_owner=is_owner, + ) + session.add(archives_agents) + await session.commit() + + @enforce_types + async def get_or_create_default_archive_for_agent_async( + self, + agent_id: str, + agent_name: Optional[str] = None, + actor: PydanticUser = None, + ) -> PydanticArchive: + """Get the agent's default archive, creating one if it doesn't exist.""" + # First check if agent has any archives + from letta.services.agent_manager import AgentManager + + agent_manager = AgentManager() + + archive_ids = await agent_manager.get_agent_archive_ids_async( + agent_id=agent_id, + actor=actor, + ) + + if archive_ids: + # TODO: Remove this check once we support multiple archives per agent + if len(archive_ids) > 1: + raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported") + # Get the archive + archive = await self.get_archive_by_id_async( + archive_id=archive_ids[0], + actor=actor, + ) + return archive + + # Create a default archive for this agent + archive_name = f"{agent_name or f'Agent {agent_id}'}'s Archive" + archive = await self.create_archive_async( + name=archive_name, + description="Default archive created automatically", + actor=actor, + ) + + # Attach the agent to the archive as owner + await self.attach_agent_to_archive_async( + agent_id=agent_id, + archive_id=archive.id, + is_owner=True, + actor=actor, + ) + + return archive + + @enforce_types + def get_or_create_default_archive_for_agent( + self, + agent_id: str, + agent_name: Optional[str] = None, + actor: PydanticUser = None, + ) -> PydanticArchive: + """Get the agent's default archive, creating one if it doesn't exist.""" + with db_registry.session() as session: + # First check if 
agent has any archives + query = select(ArchivesAgents.archive_id).where(ArchivesAgents.agent_id == agent_id) + result = session.execute(query) + archive_ids = [row[0] for row in result.fetchall()] + + if archive_ids: + # TODO: Remove this check once we support multiple archives per agent + if len(archive_ids) > 1: + raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported") + # Get the archive + archive = ArchiveModel.read(db_session=session, identifier=archive_ids[0], actor=actor) + return archive.to_pydantic() + + # Create a default archive for this agent + archive_name = f"{agent_name or f'Agent {agent_id}'}'s Archive" + + # Create the archive + archive_model = ArchiveModel( + name=archive_name, + description="Default archive created automatically", + organization_id=actor.organization_id, + ) + archive_model.create(session, actor=actor) + + # Attach the agent to the archive as owner + self.attach_agent_to_archive( + agent_id=agent_id, + archive_id=archive_model.id, + is_owner=True, + actor=actor, + ) + + return archive_model.to_pydantic() + + @enforce_types + async def get_agents_for_archive_async( + self, + archive_id: str, + actor: PydanticUser, + ) -> List[str]: + """Get all agent IDs that have access to an archive.""" + async with db_registry.async_session() as session: + result = await session.execute(select(ArchivesAgents.agent_id).where(ArchivesAgents.archive_id == archive_id)) + return [row[0] for row in result.fetchall()] + + @enforce_types + async def get_agent_from_passage_async( + self, + passage_id: str, + actor: PydanticUser, + ) -> Optional[str]: + """Get the agent ID that owns a passage (through its archive). + + Returns the first agent found (for backwards compatibility). + Returns None if no agent found. 
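+
+        Note that ownership is not required here: any agent attached to the
+        passage's archive qualifies, and the first match is returned.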
+ """ + async with db_registry.async_session() as session: + # First get the passage to find its archive_id + passage = await ArchivalPassage.read_async( + db_session=session, + identifier=passage_id, + actor=actor, + ) + + # Then find agents connected to that archive + result = await session.execute(select(ArchivesAgents.agent_id).where(ArchivesAgents.archive_id == passage.archive_id)) + agent_ids = [row[0] for row in result.fetchall()] + + if not agent_ids: + return None + + # For now, return the first agent (backwards compatibility) + return agent_ids[0] diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 78d0ba69..04793022 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -25,9 +25,10 @@ from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import format_datetime, get_local_time, get_local_time_fast from letta.orm.agent import Agent as AgentModel from letta.orm.agents_tags import AgentsTags +from letta.orm.archives_agents import ArchivesAgents from letta.orm.errors import NoResultFound from letta.orm.identity import Identity -from letta.orm.passage import AgentPassage, SourcePassage +from letta.orm.passage import ArchivalPassage, SourcePassage from letta.orm.sources_agents import SourcesAgents from letta.orm.sqlite_functions import adapt_array from letta.otel.tracing import trace_method @@ -918,7 +919,7 @@ def build_passage_query( SourcePassage.organization_id, SourcePassage.file_id, SourcePassage.source_id, - literal(None).label("agent_id"), + literal(None).label("archive_id"), ) .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id) .where(SourcesAgents.agent_id == agent_id) @@ -940,7 +941,7 @@ def build_passage_query( SourcePassage.organization_id, SourcePassage.file_id, SourcePassage.source_id, - literal(None).label("agent_id"), + literal(None).label("archive_id"), ).where(SourcePassage.organization_id == actor.organization_id) if source_id: @@ -954,23 +955,24 @@ def build_passage_query( agent_passages = ( select( literal(None).label("file_name"), - AgentPassage.id, - AgentPassage.text, - AgentPassage.embedding_config, - AgentPassage.metadata_, - AgentPassage.embedding, - AgentPassage.created_at, - AgentPassage.updated_at, - AgentPassage.is_deleted, - AgentPassage._created_by_id, - AgentPassage._last_updated_by_id, - AgentPassage.organization_id, + ArchivalPassage.id, + ArchivalPassage.text, + ArchivalPassage.embedding_config, + ArchivalPassage.metadata_, + ArchivalPassage.embedding, + ArchivalPassage.created_at, + ArchivalPassage.updated_at, + ArchivalPassage.is_deleted, + ArchivalPassage._created_by_id, + ArchivalPassage._last_updated_by_id, + ArchivalPassage.organization_id, literal(None).label("file_id"), literal(None).label("source_id"), - AgentPassage.agent_id, + ArchivalPassage.archive_id, ) - .where(AgentPassage.agent_id == agent_id) - .where(AgentPassage.organization_id == actor.organization_id) + .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id) + .where(ArchivesAgents.agent_id == agent_id) + .where(ArchivalPassage.organization_id == actor.organization_id) ) # Combine queries @@ -1201,56 +1203,60 @@ def build_agent_passage_query( embedded_text = np.array(embedded_text) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() - # Base query for agent passages - query = select(AgentPassage).where(AgentPassage.agent_id == 
agent_id, AgentPassage.organization_id == actor.organization_id) + # Base query for agent passages - join through archives_agents + query = ( + select(ArchivalPassage) + .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id) + .where(ArchivesAgents.agent_id == agent_id, ArchivalPassage.organization_id == actor.organization_id) + ) # Apply filters if start_date: - query = query.where(AgentPassage.created_at >= start_date) + query = query.where(ArchivalPassage.created_at >= start_date) if end_date: - query = query.where(AgentPassage.created_at <= end_date) + query = query.where(ArchivalPassage.created_at <= end_date) # Handle text search or vector search if embedded_text: if settings.database_engine is DatabaseChoice.POSTGRES: # PostgreSQL with pgvector - query = query.order_by(AgentPassage.embedding.cosine_distance(embedded_text).asc()) + query = query.order_by(ArchivalPassage.embedding.cosine_distance(embedded_text).asc()) else: # SQLite with custom vector type query_embedding_binary = adapt_array(embedded_text) query = query.order_by( - func.cosine_distance(AgentPassage.embedding, query_embedding_binary).asc(), - AgentPassage.created_at.asc() if ascending else AgentPassage.created_at.desc(), - AgentPassage.id.asc(), + func.cosine_distance(ArchivalPassage.embedding, query_embedding_binary).asc(), + ArchivalPassage.created_at.asc() if ascending else ArchivalPassage.created_at.desc(), + ArchivalPassage.id.asc(), ) else: if query_text: - query = query.where(func.lower(AgentPassage.text).contains(func.lower(query_text))) + query = query.where(func.lower(ArchivalPassage.text).contains(func.lower(query_text))) # Handle pagination if before or after: if before: # Get the reference record - before_subq = select(AgentPassage.created_at, AgentPassage.id).where(AgentPassage.id == before).subquery() + before_subq = select(ArchivalPassage.created_at, ArchivalPassage.id).where(ArchivalPassage.id == before).subquery() query = query.where( or_( - AgentPassage.created_at < before_subq.c.created_at, + ArchivalPassage.created_at < before_subq.c.created_at, and_( - AgentPassage.created_at == before_subq.c.created_at, - AgentPassage.id < before_subq.c.id, + ArchivalPassage.created_at == before_subq.c.created_at, + ArchivalPassage.id < before_subq.c.id, ), ) ) if after: # Get the reference record - after_subq = select(AgentPassage.created_at, AgentPassage.id).where(AgentPassage.id == after).subquery() + after_subq = select(ArchivalPassage.created_at, ArchivalPassage.id).where(ArchivalPassage.id == after).subquery() query = query.where( or_( - AgentPassage.created_at > after_subq.c.created_at, + ArchivalPassage.created_at > after_subq.c.created_at, and_( - AgentPassage.created_at == after_subq.c.created_at, - AgentPassage.id > after_subq.c.id, + ArchivalPassage.created_at == after_subq.c.created_at, + ArchivalPassage.id > after_subq.c.id, ), ) ) @@ -1258,9 +1264,9 @@ def build_agent_passage_query( # Apply ordering if not already ordered by similarity if not embed_query: if ascending: - query = query.order_by(AgentPassage.created_at.asc(), AgentPassage.id.asc()) + query = query.order_by(ArchivalPassage.created_at.asc(), ArchivalPassage.id.asc()) else: - query = query.order_by(AgentPassage.created_at.desc(), AgentPassage.id.asc()) + query = query.order_by(ArchivalPassage.created_at.desc(), ArchivalPassage.id.asc()) return query diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 67580776..895ecef3 100644 --- a/letta/services/passage_manager.py +++ 
b/letta/services/passage_manager.py @@ -9,14 +9,16 @@ from sqlalchemy import select from letta.constants import MAX_EMBEDDING_DIM from letta.embeddings import embedding_model, parse_and_chunk_text from letta.helpers.decorators import async_redis_cache +from letta.orm import ArchivesAgents from letta.orm.errors import NoResultFound -from letta.orm.passage import AgentPassage, SourcePassage +from letta.orm.passage import ArchivalPassage, SourcePassage from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.services.archive_manager import ArchiveManager from letta.utils import enforce_types @@ -42,6 +44,9 @@ async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> li class PassageManager: """Manager class to handle business logic related to Passages.""" + def __init__(self): + self.archive_manager = ArchiveManager() + # AGENT PASSAGE METHODS @enforce_types @trace_method @@ -49,7 +54,7 @@ class PassageManager: """Fetch an agent passage by ID.""" with db_registry.session() as session: try: - passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor) + passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor) return passage.to_pydantic() except NoResultFound: raise NoResultFound(f"Agent passage with id {passage_id} not found in database.") @@ -60,7 +65,7 @@ class PassageManager: """Fetch an agent passage by ID.""" async with db_registry.async_session() as session: try: - passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor) + passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor) return passage.to_pydantic() except NoResultFound: raise NoResultFound(f"Agent passage with id {passage_id} not found in database.") @@ -109,7 +114,7 @@ class PassageManager: except NoResultFound: # Try archival passages try: - passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor) + passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor) return passage.to_pydantic() except NoResultFound: raise NoResultFound(f"Passage with id {passage_id} not found in database.") @@ -134,7 +139,7 @@ class PassageManager: except NoResultFound: # Try archival passages try: - passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor) + passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor) return passage.to_pydantic() except NoResultFound: raise NoResultFound(f"Passage with id {passage_id} not found in database.") @@ -143,8 +148,8 @@ class PassageManager: @trace_method def create_agent_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: """Create a new agent passage.""" - if not pydantic_passage.agent_id: - raise ValueError("Agent passage must have agent_id") + if not pydantic_passage.archive_id: + raise ValueError("Agent passage must have archive_id") if pydantic_passage.source_id: raise ValueError("Agent passage cannot have source_id") @@ -159,8 +164,8 @@ class PassageManager: "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } - agent_fields = {"agent_id": data["agent_id"]} - passage 
= AgentPassage(**common_fields, **agent_fields) + agent_fields = {"archive_id": data["archive_id"]} + passage = ArchivalPassage(**common_fields, **agent_fields) with db_registry.session() as session: passage.create(session, actor=actor) @@ -170,8 +175,8 @@ class PassageManager: @trace_method async def create_agent_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: """Create a new agent passage.""" - if not pydantic_passage.agent_id: - raise ValueError("Agent passage must have agent_id") + if not pydantic_passage.archive_id: + raise ValueError("Agent passage must have archive_id") if pydantic_passage.source_id: raise ValueError("Agent passage cannot have source_id") @@ -186,8 +191,8 @@ class PassageManager: "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } - agent_fields = {"agent_id": data["agent_id"]} - passage = AgentPassage(**common_fields, **agent_fields) + agent_fields = {"archive_id": data["archive_id"]} + passage = ArchivalPassage(**common_fields, **agent_fields) async with db_registry.async_session() as session: passage = await passage.create_async(session, actor=actor) @@ -201,8 +206,8 @@ class PassageManager: """Create a new source passage.""" if not pydantic_passage.source_id: raise ValueError("Source passage must have source_id") - if pydantic_passage.agent_id: - raise ValueError("Source passage cannot have agent_id") + if pydantic_passage.archive_id: + raise ValueError("Source passage cannot have archive_id") data = pydantic_passage.model_dump(to_orm=True) common_fields = { @@ -234,8 +239,8 @@ class PassageManager: """Create a new source passage.""" if not pydantic_passage.source_id: raise ValueError("Source passage must have source_id") - if pydantic_passage.agent_id: - raise ValueError("Source passage cannot have agent_id") + if pydantic_passage.archive_id: + raise ValueError("Source passage cannot have archive_id") data = pydantic_passage.model_dump(to_orm=True) common_fields = { @@ -308,21 +313,21 @@ class PassageManager: "created_at": data.get("created_at", datetime.now(timezone.utc)), } - if "agent_id" in data and data["agent_id"]: - assert not data.get("source_id"), "Passage cannot have both agent_id and source_id" + if "archive_id" in data and data["archive_id"]: + assert not data.get("source_id"), "Passage cannot have both archive_id and source_id" agent_fields = { - "agent_id": data["agent_id"], + "archive_id": data["archive_id"], } - passage = AgentPassage(**common_fields, **agent_fields) + passage = ArchivalPassage(**common_fields, **agent_fields) elif "source_id" in data and data["source_id"]: - assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id" + assert not data.get("archive_id"), "Passage cannot have both archive_id and source_id" source_fields = { "source_id": data["source_id"], "file_id": data.get("file_id"), } passage = SourcePassage(**common_fields, **source_fields) else: - raise ValueError("Passage must have either agent_id or source_id") + raise ValueError("Passage must have either archive_id or source_id") return passage @@ -334,14 +339,14 @@ class PassageManager: @enforce_types @trace_method - async def create_many_agent_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: - """Create multiple agent passages.""" - agent_passages = [] + async def create_many_archival_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: + 
"""Create multiple archival passages.""" + archival_passages = [] for p in passages: - if not p.agent_id: - raise ValueError("Agent passage must have agent_id") + if not p.archive_id: + raise ValueError("Archival passage must have archive_id") if p.source_id: - raise ValueError("Agent passage cannot have source_id") + raise ValueError("Archival passage cannot have source_id") data = p.model_dump(to_orm=True) common_fields = { @@ -354,12 +359,12 @@ class PassageManager: "is_deleted": data.get("is_deleted", False), "created_at": data.get("created_at", datetime.now(timezone.utc)), } - agent_fields = {"agent_id": data["agent_id"]} - agent_passages.append(AgentPassage(**common_fields, **agent_fields)) + archival_fields = {"archive_id": data["archive_id"]} + archival_passages.append(ArchivalPassage(**common_fields, **archival_fields)) async with db_registry.async_session() as session: - agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor) - return [p.to_pydantic() for p in agent_created] + archival_created = await ArchivalPassage.batch_create_async(items=archival_passages, db_session=session, actor=actor) + return [p.to_pydantic() for p in archival_created] @enforce_types @trace_method @@ -379,8 +384,8 @@ class PassageManager: for p in passages: if not p.source_id: raise ValueError("Source passage must have source_id") - if p.agent_id: - raise ValueError("Source passage cannot have agent_id") + if p.archive_id: + raise ValueError("Source passage cannot have archive_id") data = p.model_dump(to_orm=True) common_fields = { @@ -436,7 +441,7 @@ class PassageManager: for p in passages: model = self._preprocess_passage_for_creation(p) - if isinstance(model, AgentPassage): + if isinstance(model, ArchivalPassage): agent_passages.append(model) elif isinstance(model, SourcePassage): source_passages.append(model) @@ -445,7 +450,7 @@ class PassageManager: results = [] if agent_passages: - agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor) + agent_created = await ArchivalPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor) results.extend(agent_created) if source_passages: source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor) @@ -458,7 +463,6 @@ class PassageManager: def insert_passage( self, agent_state: AgentState, - agent_id: str, text: str, actor: PydanticUser, ) -> List[PydanticPassage]: @@ -494,10 +498,15 @@ class PassageManager: raise TypeError( f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}" ) + # Get or create the default archive for the agent + archive = self.archive_manager.get_or_create_default_archive_for_agent( + agent_id=agent_state.id, agent_name=agent_state.name, actor=actor + ) + passage = self.create_agent_passage( PydanticPassage( organization_id=actor.organization_id, - agent_id=agent_id, + archive_id=archive.id, text=text, embedding=embedding, embedding_config=agent_state.embedding_config, @@ -516,12 +525,18 @@ class PassageManager: async def insert_passage_async( self, agent_state: AgentState, - agent_id: str, text: str, actor: PydanticUser, image_ids: Optional[List[str]] = None, ) -> List[PydanticPassage]: """Insert passage(s) into archival memory""" + # Get or create default archive for the agent + archive = await self.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=agent_state.id, + 
agent_name=agent_state.name, + actor=actor, + ) + archive_id = archive.id embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size text_chunks = list(parse_and_chunk_text(text, embedding_chunk_size)) @@ -535,7 +550,7 @@ class PassageManager: passages = [ PydanticPassage( organization_id=actor.organization_id, - agent_id=agent_id, + archive_id=archive_id, text=chunk_text, embedding=embedding, embedding_config=agent_state.embedding_config, @@ -543,7 +558,7 @@ class PassageManager: for chunk_text, embedding in zip(text_chunks, embeddings) ] - passages = await self.create_many_agent_passages_async(passages=passages, actor=actor) + passages = await self.create_many_archival_passages_async(passages=passages, actor=actor) return passages @@ -595,7 +610,7 @@ class PassageManager: with db_registry.session() as session: try: - curr_passage = AgentPassage.read( + curr_passage = ArchivalPassage.read( db_session=session, identifier=passage_id, actor=actor, @@ -623,7 +638,7 @@ class PassageManager: async with db_registry.async_session() as session: try: - curr_passage = await AgentPassage.read_async( + curr_passage = await ArchivalPassage.read_async( db_session=session, identifier=passage_id, actor=actor, @@ -705,7 +720,7 @@ class PassageManager: with db_registry.session() as session: try: - passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor) + passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor) passage.hard_delete(session, actor=actor) return True except NoResultFound: @@ -720,7 +735,7 @@ class PassageManager: async with db_registry.async_session() as session: try: - passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor) + passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor) await passage.hard_delete_async(session, actor=actor) return True except NoResultFound: @@ -783,7 +798,7 @@ class PassageManager: except NoResultFound: # Try agent passages try: - curr_passage = AgentPassage.read( + curr_passage = ArchivalPassage.read( db_session=session, identifier=passage_id, actor=actor, @@ -824,7 +839,7 @@ class PassageManager: except NoResultFound: # Try archival passages try: - passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor) + passage = ArchivalPassage.read(db_session=session, identifier=passage_id, actor=actor) passage.hard_delete(session, actor=actor) return True except NoResultFound: @@ -854,7 +869,7 @@ class PassageManager: except NoResultFound: # Try archival passages try: - passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor) + passage = await ArchivalPassage.read_async(db_session=session, identifier=passage_id, actor=actor) await passage.hard_delete_async(session, actor=actor) return True except NoResultFound: @@ -883,7 +898,7 @@ class PassageManager: ) -> bool: """Delete multiple agent passages.""" async with db_registry.async_session() as session: - await AgentPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor) + await ArchivalPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor) return True @enforce_types @@ -947,7 +962,21 @@ class PassageManager: agent_id: The agent ID of the messages """ with db_registry.session() as session: - return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id) + if agent_id: + # Count passages through the archives 
relationship + return ( + session.query(ArchivalPassage) + .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id) + .filter( + ArchivesAgents.agent_id == agent_id, + ArchivalPassage.organization_id == actor.organization_id, + ArchivalPassage.is_deleted == False, + ) + .count() + ) + else: + # Count all archival passages in the organization + return ArchivalPassage.size(db_session=session, actor=actor) # DEPRECATED - Use agent_passage_size() instead since this only counted agent passages anyway @enforce_types @@ -961,8 +990,7 @@ class PassageManager: import warnings warnings.warn("size is deprecated. Use agent_passage_size() instead.", DeprecationWarning, stacklevel=2) - with db_registry.session() as session: - return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id) + return self.agent_passage_size(actor=actor, agent_id=agent_id) @enforce_types @trace_method @@ -977,7 +1005,23 @@ class PassageManager: agent_id: The agent ID of the messages """ async with db_registry.async_session() as session: - return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id) + if agent_id: + # Count passages through the archives relationship + from sqlalchemy import func, select + + result = await session.execute( + select(func.count(ArchivalPassage.id)) + .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id) + .where( + ArchivesAgents.agent_id == agent_id, + ArchivalPassage.organization_id == actor.organization_id, + ArchivalPassage.is_deleted == False, + ) + ) + return result.scalar() or 0 + else: + # Count all archival passages in the organization + return await ArchivalPassage.size_async(db_session=session, actor=actor) @enforce_types @trace_method diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index bff2545d..0ed4f6f4 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -176,7 +176,6 @@ class LettaCoreToolExecutor(ToolExecutor): """ await PassageManager().insert_passage_async( agent_state=agent_state, - agent_id=agent_state.id, text=content, actor=actor, ) diff --git a/tests/data/list_tools.json b/tests/data/list_tools.json index 9f696f64..919fcd0a 100644 --- a/tests/data/list_tools.json +++ b/tests/data/list_tools.json @@ -9,7 +9,7 @@ "tags": [ "letta_core" ], - "source_code": "def archival_memory_insert(self: \"Agent\", content: str) -> Optional[str]:\n \"\"\"\n Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.\n\n Args:\n content (str): Content to write to the memory. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n \"\"\"\n self.passage_manager.insert_passage(\n agent_state=self.agent_state,\n agent_id=self.agent_state.id,\n text=content,\n actor=self.user,\n )\n return None\n", + "source_code": "def archival_memory_insert(self: \"Agent\", content: str) -> Optional[str]:\n \"\"\"\n Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.\n\n Args:\n content (str): Content to write to the memory. 
All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n \"\"\"\n self.passage_manager.insert_passage(\n agent_state=self.agent_state,\n text=content,\n actor=self.user,\n )\n return None\n", "json_schema": { "name": "archival_memory_insert", "description": "Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.", diff --git a/tests/test_managers.py b/tests/test_managers.py index 892f56e7..3e177e27 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -294,10 +294,15 @@ async def default_run(server: SyncServer, default_user): @pytest.fixture def agent_passage_fixture(server: SyncServer, default_user, sarah_agent): """Fixture to create an agent passage.""" + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + passage = server.passage_manager.create_agent_passage( PydanticPassage( text="Hello, I am an agent passage", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -330,13 +335,18 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau @pytest.fixture def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent, default_source): """Helper function to create test passages for all tests.""" + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + # Create agent passages passages = [] for i in range(5): passage = server.passage_manager.create_agent_passage( PydanticPassage( text=f"Agent passage {i}", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -540,7 +550,14 @@ def server(): @pytest.fixture @pytest.mark.asyncio -async def agent_passages_setup(server, default_source, default_file, default_user, sarah_agent, event_loop): +async def default_archive(server, default_user): + archive = await server.archive_manager.create_archive_async("test", actor=default_user) + yield archive + + +@pytest.fixture +@pytest.mark.asyncio +async def agent_passages_setup(server, default_archive, default_source, default_file, default_user, sarah_agent, event_loop): """Setup fixture for agent passages tests""" agent_id = sarah_agent.id actor = default_user @@ -564,13 +581,18 @@ async def agent_passages_setup(server, default_source, default_file, default_use ) source_passages.append(passage) + # attach archive + await server.archive_manager.attach_agent_to_archive_async( + agent_id=agent_id, archive_id=default_archive.id, is_owner=True, actor=default_user + ) + # Create some agent passages agent_passages = [] for i in range(2): passage = await server.passage_manager.create_agent_passage_async( PydanticPassage( organization_id=actor.organization_id, - agent_id=agent_id, + archive_id=default_archive.id, text=f"Agent passage {i}", embedding=[0.1], # Default OpenAI embedding size embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -2734,6 +2756,11 @@ async def test_agent_list_passages_vector_search( """Test vector search functionality of agent passages""" embed_model = mock_embed_model + # Get or create default 
archive for the agent + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + # Create passages with known embeddings passages = [] @@ -2753,7 +2780,7 @@ async def test_agent_list_passages_vector_search( passage = PydanticPassage( text=text, organization_id=default_user.organization_id, - agent_id=sarah_agent.id, + archive_id=archive.id, embedding_config=DEFAULT_EMBEDDING_CONFIG, embedding=embedding, ) @@ -2818,7 +2845,7 @@ async def test_list_source_passages_only(server: SyncServer, default_user, defau # Verify we get only source passages (3 from agent_passages_setup) assert len(source_passages) == 3 assert all(p.source_id == default_source.id for p in source_passages) - assert all(p.agent_id is None for p in source_passages) + assert all(p.archive_id is None for p in source_passages) # ====================================================================================================================== @@ -2923,12 +2950,12 @@ async def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, assert agent_passage_fixture is not None assert agent_passage_fixture.text == "Hello, I am an agent passage" - # Try to create an invalid passage (with both agent_id and source_id) + # Try to create an invalid passage (with both archive_id and source_id) with pytest.raises(AssertionError): await server.passage_manager.create_passage_async( PydanticPassage( text="Invalid passage", - agent_id="123", + archive_id="123", source_id="456", organization_id=default_user.organization_id, embedding=[0.1] * 1024, @@ -2970,10 +2997,15 @@ async def test_passage_cascade_deletion( def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_agent): """Test creating an agent passage using the new agent-specific method.""" + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + passage = server.passage_manager.create_agent_passage( PydanticPassage( text="Test agent passage via specific method", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -2984,7 +3016,7 @@ def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_a assert passage.id is not None assert passage.text == "Test agent passage via specific method" - assert passage.agent_id == sarah_agent.id + assert passage.archive_id == archive.id assert passage.source_id is None @@ -3007,13 +3039,13 @@ def test_create_source_passage_specific(server: SyncServer, default_user, defaul assert passage.id is not None assert passage.text == "Test source passage via specific method" assert passage.source_id == default_source.id - assert passage.agent_id is None + assert passage.archive_id is None def test_create_agent_passage_validation(server: SyncServer, default_user, default_source, sarah_agent): """Test that agent passage creation validates inputs correctly.""" - # Should fail if agent_id is missing - with pytest.raises(ValueError, match="Agent passage must have agent_id"): + # Should fail if archive_id is missing + with pytest.raises(ValueError, match="Agent passage must have archive_id"): server.passage_manager.create_agent_passage( PydanticPassage( text="Invalid agent passage", @@ -3024,12 +3056,17 @@ def test_create_agent_passage_validation(server: 
SyncServer, default_user, defau actor=default_user, ) + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + # Should fail if source_id is present with pytest.raises(ValueError, match="Agent passage cannot have source_id"): server.passage_manager.create_agent_passage( PydanticPassage( text="Invalid agent passage", - agent_id=sarah_agent.id, + archive_id=archive.id, source_id=default_source.id, organization_id=default_user.organization_id, embedding=[0.1], @@ -3054,13 +3091,18 @@ def test_create_source_passage_validation(server: SyncServer, default_user, defa actor=default_user, ) - # Should fail if agent_id is present - with pytest.raises(ValueError, match="Source passage cannot have agent_id"): + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + + # Should fail if archive_id is present + with pytest.raises(ValueError, match="Source passage cannot have archive_id"): server.passage_manager.create_source_passage( PydanticPassage( text="Invalid source passage", source_id=default_source.id, - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -3072,11 +3114,16 @@ def test_create_source_passage_validation(server: SyncServer, default_user, defa def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sarah_agent): """Test retrieving an agent passage using the new agent-specific method.""" + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + # Create an agent passage passage = server.passage_manager.create_agent_passage( PydanticPassage( text="Agent passage for retrieval test", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -3089,7 +3136,7 @@ def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sara assert retrieved is not None assert retrieved.id == passage.id assert retrieved.text == passage.text - assert retrieved.agent_id == sarah_agent.id + assert retrieved.archive_id == archive.id def test_get_source_passage_by_id_specific(server: SyncServer, default_user, default_file, default_source): @@ -3119,10 +3166,15 @@ def test_get_source_passage_by_id_specific(server: SyncServer, default_user, def def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sarah_agent, default_file, default_source): """Test that trying to get the wrong passage type with specific methods fails.""" # Create an agent passage + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + agent_passage = server.passage_manager.create_agent_passage( PydanticPassage( text="Agent passage", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -3155,11 +3207,16 @@ def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sarah_ag def 
test_update_agent_passage_specific(server: SyncServer, default_user, sarah_agent): """Test updating an agent passage using the new agent-specific method.""" + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + # Create an agent passage passage = server.passage_manager.create_agent_passage( PydanticPassage( text="Original agent passage text", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -3172,7 +3229,7 @@ def test_update_agent_passage_specific(server: SyncServer, default_user, sarah_a passage.id, PydanticPassage( text="Updated agent passage text", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.2], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -3222,11 +3279,16 @@ def test_update_source_passage_specific(server: SyncServer, default_user, defaul def test_delete_agent_passage_specific(server: SyncServer, default_user, sarah_agent): """Test deleting an agent passage using the new agent-specific method.""" + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + # Create an agent passage passage = server.passage_manager.create_agent_passage( PydanticPassage( text="Agent passage to delete", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -3279,10 +3341,15 @@ def test_delete_source_passage_specific(server: SyncServer, default_user, defaul @pytest.mark.asyncio async def test_create_many_agent_passages_async(server: SyncServer, default_user, sarah_agent, event_loop): """Test creating multiple agent passages using the new batch method.""" + # Get or create default archive for the agent + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + passages = [ PydanticPassage( text=f"Batch agent passage {i}", - agent_id=sarah_agent.id, + archive_id=archive.id, # Now archive is a PydanticArchive object organization_id=default_user.organization_id, embedding=[0.1 * i], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -3290,12 +3357,12 @@ async def test_create_many_agent_passages_async(server: SyncServer, default_user for i in range(3) ] - created_passages = await server.passage_manager.create_many_agent_passages_async(passages, actor=default_user) + created_passages = await server.passage_manager.create_many_archival_passages_async(passages, actor=default_user) assert len(created_passages) == 3 for i, passage in enumerate(created_passages): assert passage.text == f"Batch agent passage {i}" - assert passage.agent_id == sarah_agent.id + assert passage.archive_id == archive.id assert passage.source_id is None @@ -3322,19 +3389,24 @@ async def test_create_many_source_passages_async(server: SyncServer, default_use for i, passage in enumerate(created_passages): assert passage.text == f"Batch source passage {i}" assert passage.source_id == default_source.id - assert passage.agent_id is None + assert passage.archive_id is None def test_agent_passage_size(server: SyncServer, default_user, sarah_agent): """Test counting agent 
passages using the new agent-specific size method.""" initial_size = server.passage_manager.agent_passage_size(actor=default_user, agent_id=sarah_agent.id) + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + # Create some agent passages for i in range(3): server.passage_manager.create_agent_passage( PydanticPassage( text=f"Agent passage {i} for size test", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -3350,6 +3422,11 @@ def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sara """Test that deprecated methods show deprecation warnings.""" import warnings + # Get or create default archive for the agent + archive = server.archive_manager.get_or_create_default_archive_for_agent( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -3357,7 +3434,7 @@ def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sara passage = server.passage_manager.create_passage( PydanticPassage( text="Test deprecated method", - agent_id=sarah_agent.id, + archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, @@ -8019,7 +8096,6 @@ async def test_create_and_get_batch_item( ) assert item.llm_batch_id == batch.id - assert item.agent_id == sarah_agent.id assert item.step_state == dummy_step_state fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) @@ -8662,7 +8738,6 @@ async def test_attach_creates_association(server, default_user, sarah_agent, def max_files_open=sarah_agent.max_files_open, ) - assert assoc.agent_id == sarah_agent.id assert assoc.file_id == default_file.id assert assoc.is_open is True assert assoc.visible_content == "hello"
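
Usage notes (sketches only, not part of the diff). The patch repoints archival passage creation from agent_id to archive_id: PassageManager.create_agent_passage now requires the pydantic Passage to carry archive_id, and the owning archive is resolved lazily through ArchiveManager.get_or_create_default_archive_for_agent. A minimal sketch of the new write path, assuming an agent_state and actor like the ones the tests use, with a placeholder embedding vector:

    from letta.schemas.passage import Passage as PydanticPassage
    from letta.services.archive_manager import ArchiveManager
    from letta.services.passage_manager import PassageManager

    def write_archival_passage(agent_state, actor):
        """Sketch: create one passage in the agent's default archive."""
        archive_manager = ArchiveManager()
        passage_manager = PassageManager()

        # Resolved (or created on first use) per agent; replaces passing agent_id.
        archive = archive_manager.get_or_create_default_archive_for_agent(
            agent_id=agent_state.id, agent_name=agent_state.name, actor=actor
        )

        return passage_manager.create_agent_passage(
            PydanticPassage(
                text="Hello, archival memory",
                archive_id=archive.id,  # replaces the old agent_id field
                organization_id=actor.organization_id,
                embedding=[0.1],  # placeholder vector, as in the tests
                embedding_config=agent_state.embedding_config,
            ),
            actor=actor,
        )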
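
Callers of insert_passage / insert_passage_async no longer pass agent_id (see the LettaCoreToolExecutor and list_tools.json changes above): the manager looks up the agent's default archive itself, chunks the text by embedding_chunk_size, embeds each chunk, and batch-creates the rows. A sketch mirroring the updated executor call:

    from letta.services.passage_manager import PassageManager

    async def archival_memory_insert(agent_state, content: str, actor):
        # agent_id is gone from the signature; the default archive for
        # agent_state.id is resolved (or created) inside the manager.
        return await PassageManager().insert_passage_async(
            agent_state=agent_state,
            text=content,
            actor=actor,
        )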
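
create_many_agent_passages_async is renamed to create_many_archival_passages_async; it validates archive_id on every item up front (source_id must be unset) and then issues a single ArchivalPassage.batch_create_async round trip. A sketch of batching pre-embedded chunks into one archive, where the chunk texts, vectors, and embedding_config are assumed inputs:

    from typing import List

    from letta.schemas.passage import Passage as PydanticPassage
    from letta.services.passage_manager import PassageManager

    async def batch_write(archive_id: str, chunks: List[str], vectors: List[List[float]], embedding_config, actor):
        passages = [
            PydanticPassage(
                text=chunk,
                archive_id=archive_id,  # required on every item
                organization_id=actor.organization_id,
                embedding=vector,
                embedding_config=embedding_config,
            )
            for chunk, vector in zip(chunks, vectors)
        ]
        # One batched insert instead of N individual creates.
        return await PassageManager().create_many_archival_passages_async(passages=passages, actor=actor)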
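
The reason passages hang off archives at all is sharing: an archive row can be attached to multiple agents through the new archives_agents table. The agent_passages_setup fixture above exercises this with create_archive_async and attach_agent_to_archive_async; below is a sketch of two agents over one archive. Attaching the second agent with is_owner=False is an assumption here, since only is_owner=True appears in the tests:

    from letta.services.archive_manager import ArchiveManager

    async def share_archive(agent_a, agent_b, actor):
        """Sketch: one archive visible to two agents."""
        archive_manager = ArchiveManager()
        archive = await archive_manager.create_archive_async("shared-notes", actor=actor)

        # The writer owns the archive, as in the test fixture.
        await archive_manager.attach_agent_to_archive_async(
            agent_id=agent_a.id, archive_id=archive.id, is_owner=True, actor=actor
        )
        # Assumed reader-style attachment; is_owner=False is not exercised in the tests.
        await archive_manager.attach_agent_to_archive_async(
            agent_id=agent_b.id, archive_id=archive.id, is_owner=False, actor=actor
        )
        return archive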
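
One behavioral consequence of the join-based agent_passage_size above: a passage in a shared archive now counts toward every attached agent, rather than belonging to exactly one agent_id. Restating the manager's async branch as a standalone query makes the join explicit (this duplicates the patch's logic purely for illustration):

    from sqlalchemy import func, select

    from letta.orm import ArchivesAgents
    from letta.orm.passage import ArchivalPassage
    from letta.server.db import db_registry

    async def count_archival_passages(agent_id: str, actor) -> int:
        async with db_registry.async_session() as session:
            result = await session.execute(
                select(func.count(ArchivalPassage.id))
                # agent -> archives_agents -> passages; shared archives count for all members
                .join(ArchivesAgents, ArchivalPassage.archive_id == ArchivesAgents.archive_id)
                .where(
                    ArchivesAgents.agent_id == agent_id,
                    ArchivalPassage.organization_id == actor.organization_id,
                    ArchivalPassage.is_deleted == False,  # noqa: E712
                )
            )
            return result.scalar() or 0

With the sharing sketch above, this returns the same total for agent_a and agent_b once passages land in the shared archive.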