From e2d916148e277fa3d2e3bc44c97b2cc00574a91a Mon Sep 17 00:00:00 2001 From: mlong93 <35275280+mlong93@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:24:20 -0800 Subject: [PATCH] feat: separate Passages tables (#2245) Co-authored-by: Mindy Long --- ...54dec07619c4_divide_passage_table_into_.py | 105 +++ letta/agent.py | 63 +- letta/functions/function_sets/base.py | 4 +- letta/functions/schema_generator.py | 6 +- letta/orm/__init__.py | 2 +- letta/orm/agent.py | 20 +- letta/orm/file.py | 5 +- letta/orm/mixins.py | 17 +- letta/orm/organization.py | 22 +- letta/orm/passage.py | 84 ++- letta/orm/source.py | 4 + letta/orm/sqlalchemy_base.py | 4 +- letta/schemas/passage.py | 2 +- letta/server/server.py | 41 +- letta/services/agent_manager.py | 276 +++++++- letta/services/passage_manager.py | 176 +++-- tests/test_client_legacy.py | 1 - tests/test_managers.py | 637 ++++++++++-------- tests/test_server.py | 103 ++- 19 files changed, 1026 insertions(+), 546 deletions(-) create mode 100644 alembic/versions/54dec07619c4_divide_passage_table_into_.py diff --git a/alembic/versions/54dec07619c4_divide_passage_table_into_.py b/alembic/versions/54dec07619c4_divide_passage_table_into_.py new file mode 100644 index 00000000..afe9d418 --- /dev/null +++ b/alembic/versions/54dec07619c4_divide_passage_table_into_.py @@ -0,0 +1,105 @@ +"""divide passage table into SourcePassages and AgentPassages + +Revision ID: 54dec07619c4 +Revises: 4e88e702f85e +Create Date: 2024-12-14 17:23:08.772554 + +""" +from typing import Sequence, Union + +from alembic import op +from pgvector.sqlalchemy import Vector +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from letta.orm.custom_columns import EmbeddingConfigColumn + +# revision identifiers, used by Alembic. +revision: str = '54dec07619c4' +down_revision: Union[str, None] = '4e88e702f85e' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + 'agent_passages', + sa.Column('id', sa.String(), nullable=False), + sa.Column('text', sa.String(), nullable=False), + sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False), + sa.Column('metadata_', sa.JSON(), nullable=False), + sa.Column('embedding', Vector(dim=4096), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), + sa.Column('_created_by_id', sa.String(), nullable=True), + sa.Column('_last_updated_by_id', sa.String(), nullable=True), + sa.Column('organization_id', sa.String(), nullable=False), + sa.Column('agent_id', sa.String(), nullable=False), + sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False) + op.create_table( + 'source_passages', + sa.Column('id', sa.String(), nullable=False), + sa.Column('text', sa.String(), nullable=False), + sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False), + sa.Column('metadata_', sa.JSON(), nullable=False), + sa.Column('embedding', Vector(dim=4096), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False), + sa.Column('_created_by_id', sa.String(), nullable=True), + sa.Column('_last_updated_by_id', sa.String(), nullable=True), + sa.Column('organization_id', sa.String(), nullable=False), + sa.Column('file_id', sa.String(), nullable=True), + sa.Column('source_id', sa.String(), nullable=False), + sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ), + sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False) + op.drop_table('passages') + op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey') + op.create_foreign_key(None, 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE') + op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey') + op.create_foreign_key(None, 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
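
Note that the auto-generated upgrade() drops the legacy passages table outright, so any pre-existing rows are lost on upgrade. A deployment that needs to keep its data could backfill the two new tables inside upgrade(), just before the op.drop_table('passages') call. A minimal sketch, assuming the legacy column layout shown in downgrade() below and a JSON-backed embedding_config column; rows carrying neither agent_id nor source_id (if any exist) would still be dropped:

    op.execute(
        """
        INSERT INTO agent_passages
            (id, text, embedding_config, metadata_, embedding, created_at, updated_at,
             is_deleted, _created_by_id, _last_updated_by_id, organization_id, agent_id)
        SELECT id, text, embedding_config, metadata_, embedding, created_at, updated_at,
               is_deleted, _created_by_id, _last_updated_by_id, organization_id, agent_id
        FROM passages
        WHERE agent_id IS NOT NULL
        """
    )
    op.execute(
        """
        INSERT INTO source_passages
            (id, text, embedding_config, metadata_, embedding, created_at, updated_at,
             is_deleted, _created_by_id, _last_updated_by_id, organization_id, source_id, file_id)
        SELECT id, text, embedding_config, metadata_, embedding, created_at, updated_at,
               is_deleted, _created_by_id, _last_updated_by_id, organization_id, source_id, file_id
        FROM passages
        WHERE source_id IS NOT NULL
        """
    )
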
### + op.drop_constraint(None, 'messages', type_='foreignkey') + op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id']) + op.drop_constraint(None, 'files', type_='foreignkey') + op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id']) + op.create_table( + 'passages', + sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True), + sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False), + sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False), + sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False), + sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True), + sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False), + sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True), + sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'), + sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'), + sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'), + sa.PrimaryKeyConstraint('id', name='passages_pkey') + ) + op.drop_index('source_passages_org_idx', table_name='source_passages') + op.drop_table('source_passages') + op.drop_index('agent_passages_org_idx', table_name='agent_passages') + op.drop_table('agent_passages') + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 71bcd71a..341b25fd 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -41,7 +41,6 @@ from letta.schemas.openai.chat_completion_response import ( Message as ChatCompletionMessage, ) from letta.schemas.openai.chat_completion_response import UsageStatistics -from letta.schemas.passage import Passage from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics @@ -82,7 +81,7 @@ def compile_memory_metadata_block( actor: PydanticUser, agent_id: str, memory_edit_timestamp: datetime.datetime, - passage_manager: Optional[PassageManager] = None, + agent_manager: Optional[AgentManager] = None, message_manager: Optional[MessageManager] = None, ) -> str: # Put the timestamp in the local timezone (mimicking get_local_time()) @@ -93,7 +92,7 @@ def compile_memory_metadata_block( [ f"### Memory [last modified: {timestamp_str}]", f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)", - f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)", + 
f"{agent_manager.passage_size(actor=actor, agent_id=agent_id) if agent_manager else 0} total memories you created are stored in archival memory (use functions to access them)", "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):", ] ) @@ -106,7 +105,7 @@ def compile_system_message( in_context_memory: Memory, in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory? actor: PydanticUser, - passage_manager: Optional[PassageManager] = None, + agent_manager: Optional[AgentManager] = None, message_manager: Optional[MessageManager] = None, user_defined_variables: Optional[dict] = None, append_icm_if_missing: bool = True, @@ -135,7 +134,7 @@ def compile_system_message( actor=actor, agent_id=agent_id, memory_edit_timestamp=in_context_memory_last_edit, - passage_manager=passage_manager, + agent_manager=agent_manager, message_manager=message_manager, ) full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile() @@ -172,7 +171,7 @@ def initialize_message_sequence( agent_id: str, memory: Memory, actor: PydanticUser, - passage_manager: Optional[PassageManager] = None, + agent_manager: Optional[AgentManager] = None, message_manager: Optional[MessageManager] = None, memory_edit_timestamp: Optional[datetime.datetime] = None, include_initial_boot_message: bool = True, @@ -181,7 +180,7 @@ def initialize_message_sequence( memory_edit_timestamp = get_local_time() # full_system_message = construct_system_with_memory( - # system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory + # system, memory, memory_edit_timestamp, agent_manager=agent_manager, recall_memory=recall_memory # ) full_system_message = compile_system_message( agent_id=agent_id, @@ -189,7 +188,7 @@ def initialize_message_sequence( in_context_memory=memory, in_context_memory_last_edit=memory_edit_timestamp, actor=actor, - passage_manager=passage_manager, + agent_manager=agent_manager, message_manager=message_manager, user_defined_variables=None, append_icm_if_missing=True, @@ -291,8 +290,9 @@ class Agent(BaseAgent): self.interface = interface # Create the persistence manager object based on the AgentState info - self.passage_manager = PassageManager() self.message_manager = MessageManager() + self.passage_manager = PassageManager() + self.agent_manager = AgentManager() # State needed for heartbeat pausing self.pause_heartbeats_start = None @@ -322,7 +322,7 @@ class Agent(BaseAgent): agent_id=self.agent_state.id, memory=self.agent_state.memory, actor=self.user, - passage_manager=None, + agent_manager=None, message_manager=None, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True, @@ -347,7 +347,7 @@ class Agent(BaseAgent): memory=self.agent_state.memory, agent_id=self.agent_state.id, actor=self.user, - passage_manager=None, + agent_manager=None, message_manager=None, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True, @@ -1297,7 +1297,7 @@ class Agent(BaseAgent): in_context_memory=self.agent_state.memory, in_context_memory_last_edit=memory_edit_timestamp, actor=self.user, - passage_manager=self.passage_manager, + agent_manager=self.agent_manager, message_manager=self.message_manager, user_defined_variables=None, append_icm_if_missing=True, @@ -1368,33 +1368,24 @@ class Agent(BaseAgent): source_id: str, source_manager: SourceManager, agent_manager: AgentManager, - page_size: Optional[int] = None, ): - """Attach data with name `source_name` to the agent from 
source_connector.""" - # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory - passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size) - - for passage in passages: - assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}" - passage.agent_id = self.agent_state.id - self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user) - - agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size) - passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id) - assert all([p.agent_id == self.agent_state.id for p in agents_passages]) - assert len(agents_passages) == passage_size # sanity check - assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}" - - # attach to agent + """Attach a source to the agent using the SourcesAgents ORM relationship. + + Args: + user: User performing the action + source_id: ID of the source to attach + source_manager: SourceManager instance to verify source exists + agent_manager: AgentManager instance to manage agent-source relationship + """ + # Verify source exists and user has permission to access it source = source_manager.get_source_by_id(source_id=source_id, actor=user) - assert source is not None, f"Source {source_id} not found in metadata store" + assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})" - # NOTE: need this redundant line here because we haven't migrated agent to ORM yet - # TODO: delete @matt and remove + # Use the agent_manager to create the relationship agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user) printd( - f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. 
Agent now has {passage_size} embeddings in archival memory.", + f"Attached data source {source.name} to agent {self.agent_state.name}.", ) def update_message(self, message_id: str, request: MessageUpdate) -> Message: @@ -1550,13 +1541,13 @@ class Agent(BaseAgent): num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0 ) - passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id) + agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id) message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id) external_memory_summary = compile_memory_metadata_block( actor=self.user, agent_id=self.agent_state.id, memory_edit_timestamp=get_utc_time(), # dummy timestamp - passage_manager=self.passage_manager, + agent_manager=self.agent_manager, message_manager=self.message_manager, ) num_tokens_external_memory_summary = count_tokens(external_memory_summary) @@ -1582,7 +1573,7 @@ class Agent(BaseAgent): return ContextWindowOverview( # context window breakdown (in messages) num_messages=len(self._messages), - num_archival_memory=passage_manager_size, + num_archival_memory=agent_manager_passage_size, num_recall_memory=message_manager_size, num_tokens_external_memory_summary=num_tokens_external_memory_summary, # top-level information diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index cdcad3ac..e35739dd 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -3,6 +3,7 @@ from typing import Optional from letta.agent import Agent from letta.constants import MAX_PAUSE_HEARTBEATS +from letta.services.agent_manager import AgentManager # import math # from letta.utils import json_dumps @@ -200,8 +201,9 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s try: # Get results using passage manager - all_results = self.passage_manager.list_passages( + all_results = self.agent_manager.list_passages( actor=self.user, + agent_id=self.agent_state.id, query_text=query, limit=count + start, # Request enough results to handle offset embedding_config=self.agent_state.embedding_config, diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index e36efc07..170bea30 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -312,11 +312,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[ for param in sig.parameters.values(): # Exclude 'self' parameter # TODO: eventually remove this (only applies to BASE_TOOLS) - if param.name == "self": - continue - - # exclude 'agent_state' parameter - if param.name == "agent_state": + if param.name in ["self", "agent_state"]: # Add agent_manager to excluded continue # Assert that the parameter has a type annotation diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index ed8f2460..8a0f0c77 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata from letta.orm.job import Job from letta.orm.message import Message from letta.orm.organization import Organization -from letta.orm.passage import Passage +from letta.orm.passage import BasePassage, AgentPassage, SourcePassage from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.sources_agents import 
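
Because list_passages() takes a limit but no offset, the rewritten archival_memory_search over-fetches and slices to emulate paging. A sketch of the arithmetic, using the tool's start/count/query variables and assuming the call enables the vector path with embed_query=True (the hunk above is truncated before that argument):

    all_results = self.agent_manager.list_passages(
        actor=self.user,
        agent_id=self.agent_state.id,
        query_text=query,
        limit=count + start,  # request enough rows to cover the offset
        embedding_config=self.agent_state.embedding_config,
        embed_query=True,
    )
    page = all_results[start : start + count]
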
SourcesAgents

diff --git a/letta/orm/agent.py b/letta/orm/agent.py
index 0f086e27..c4645c3e 100644
--- a/letta/orm/agent.py
+++ b/letta/orm/agent.py
@@ -82,7 +82,25 @@ class Agent(SqlalchemyBase, OrganizationMixin):
         lazy="selectin",
         doc="Tags associated with the agent.",
     )
-    # passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="agent", lazy="selectin")
+    source_passages: Mapped[List["SourcePassage"]] = relationship(
+        "SourcePassage",
+        secondary="sources_agents",  # The join table for Agent -> Source
+        primaryjoin="Agent.id == sources_agents.c.agent_id",
+        secondaryjoin="and_(SourcePassage.source_id == sources_agents.c.source_id)",
+        lazy="selectin",
+        order_by="SourcePassage.created_at.desc()",
+        viewonly=True,  # Ensures SQLAlchemy doesn't attempt to manage this relationship
+        doc="All passages derived from sources associated with this agent.",
+    )
+    agent_passages: Mapped[List["AgentPassage"]] = relationship(
+        "AgentPassage",
+        back_populates="agent",
+        lazy="selectin",
+        order_by="AgentPassage.created_at.desc()",
+        cascade="all, delete-orphan",
+        viewonly=True,  # Ensures SQLAlchemy doesn't attempt to manage this relationship
+        doc="All passages created by this agent.",
+    )
 
     def to_pydantic(self) -> PydanticAgentState:
         """converts to the basic pydantic model counterpart"""
diff --git a/letta/orm/file.py b/letta/orm/file.py
index 6f711163..45470c6c 100644
--- a/letta/orm/file.py
+++ b/letta/orm/file.py
@@ -9,7 +9,8 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
 
 if TYPE_CHECKING:
     from letta.orm.organization import Organization
-
+    from letta.orm.source import Source
+    from letta.orm.passage import SourcePassage
 
 class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
     """Represents metadata for an uploaded file."""
@@ -27,4 +28,4 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
     # relationships
     organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
     source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
-    passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
+    source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py
index 60c319d9..328772d7 100644
--- a/letta/orm/mixins.py
+++ b/letta/orm/mixins.py
@@ -31,30 +31,19 @@ class UserMixin(Base):
     user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
 
 
-class FileMixin(Base):
-    """Mixin for models that belong to a file."""
-
-    __abstract__ = True
-
-    file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
-
 class AgentMixin(Base):
     """Mixin for models that belong to an agent."""
 
     __abstract__ = True
 
-    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"))
+    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
 
 
 class FileMixin(Base):
     """Mixin for models that belong to a file."""
 
     __abstract__ = True
 
-    file_id: Mapped[Optional[str]] = mapped_column(
-        String,
-        ForeignKey("files.id", ondelete="CASCADE"),
-        nullable=True
-    )
+    file_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("files.id", ondelete="CASCADE"))
 
 
 class SourceMixin(Base):
@@ -62,7 +51,7 @@
     """Mixin for models that belong to a source."""
 
     __abstract__ = True
 
-    source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"))
+ source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), nullable=False) class SandboxConfigMixin(Base): diff --git a/letta/orm/organization.py b/letta/orm/organization.py index bed2b00f..9a71a09b 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Union from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -35,6 +35,22 @@ class Organization(SqlalchemyBase): ) # relationships - messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan") agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") - passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan") + messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan") + source_passages: Mapped[List["SourcePassage"]] = relationship( + "SourcePassage", + back_populates="organization", + cascade="all, delete-orphan" + ) + agent_passages: Mapped[List["AgentPassage"]] = relationship( + "AgentPassage", + back_populates="organization", + cascade="all, delete-orphan" + ) + + @property + def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]: + """Convenience property to get all passages""" + return self.source_passages + self.agent_passages + + diff --git a/letta/orm/passage.py b/letta/orm/passage.py index a53e1d24..d3887841 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -1,39 +1,35 @@ -from datetime import datetime -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING +from sqlalchemy import Column, JSON, Index +from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr -from sqlalchemy import JSON, Column, DateTime, ForeignKey, String -from sqlalchemy.orm import Mapped, mapped_column, relationship +from letta.orm.mixins import FileMixin, OrganizationMixin +from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin +from letta.schemas.passage import Passage as PydanticPassage +from letta.settings import settings from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM -from letta.orm.custom_columns import CommonVector -from letta.orm.mixins import FileMixin, OrganizationMixin -from letta.orm.source import EmbeddingConfigColumn -from letta.orm.sqlalchemy_base import SqlalchemyBase -from letta.schemas.passage import Passage as PydanticPassage -from letta.settings import settings config = LettaConfig() if TYPE_CHECKING: from letta.orm.organization import Organization + from letta.orm.agent import Agent -# TODO: After migration to Passage, will need to manually delete passages where files -# are deleted on web -class Passage(SqlalchemyBase, OrganizationMixin, FileMixin): - """Defines data model for storing Passages""" - - __tablename__ = "passages" - __table_args__ = {"extend_existing": True} +class BasePassage(SqlalchemyBase, OrganizationMixin): + """Base class for all passage types with common fields""" + __abstract__ = True __pydantic_model__ = PydanticPassage id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier") text: Mapped[str] = mapped_column(doc="Passage text content") - 
source_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Source identifier") embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration") metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata") - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow) + + # Vector embedding field based on database type if settings.letta_pg_uri_no_default: from pgvector.sqlalchemy import Vector @@ -41,9 +37,49 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin): else: embedding = Column(CommonVector) - # Foreign keys - agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id"), nullable=True) + @declared_attr + def organization(cls) -> Mapped["Organization"]: + """Relationship to organization""" + return relationship("Organization", back_populates="passages", lazy="selectin") - # Relationships - organization: Mapped["Organization"] = relationship("Organization", back_populates="passages", lazy="selectin") - file: Mapped["FileMetadata"] = relationship("FileMetadata", back_populates="passages", lazy="selectin") + @declared_attr + def __table_args__(cls): + if settings.letta_pg_uri_no_default: + return ( + Index(f'{cls.__tablename__}_org_idx', 'organization_id'), + {"extend_existing": True} + ) + return ({"extend_existing": True},) + + +class SourcePassage(BasePassage, FileMixin, SourceMixin): + """Passages derived from external files/sources""" + __tablename__ = "source_passages" + + @declared_attr + def file(cls) -> Mapped["FileMetadata"]: + """Relationship to file""" + return relationship("FileMetadata", back_populates="source_passages", lazy="selectin") + + @declared_attr + def organization(cls) -> Mapped["Organization"]: + return relationship("Organization", back_populates="source_passages", lazy="selectin") + + @declared_attr + def source(cls) -> Mapped["Source"]: + """Relationship to source""" + return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True) + + +class AgentPassage(BasePassage, AgentMixin): + """Passages created by agents as archival memories""" + __tablename__ = "agent_passages" + + @declared_attr + def organization(cls) -> Mapped["Organization"]: + return relationship("Organization", back_populates="agent_passages", lazy="selectin") + + @declared_attr + def agent(cls) -> Mapped["Agent"]: + """Relationship to agent""" + return relationship("Agent", back_populates="agent_passages", lazy="selectin", passive_deletes=True) diff --git a/letta/orm/source.py b/letta/orm/source.py index c3fbdf65..3ecffda6 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -12,6 +12,9 @@ from letta.schemas.source import Source as PydanticSource if TYPE_CHECKING: from letta.orm.organization import Organization + from letta.orm.file import FileMetadata + from letta.orm.passage import SourcePassage + from letta.orm.agent import Agent class Source(SqlalchemyBase, OrganizationMixin): @@ -28,4 +31,5 @@ class Source(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="sources") files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan") + passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan") agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources") diff --git 
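
With the ORM split, which class to instantiate follows from which foreign key a passage carries. A minimal sketch (the ids and the cfg EmbeddingConfig instance are hypothetical):

    archival = AgentPassage(
        id="passage-123",
        text="user prefers red",
        agent_id="agent-1",
        organization_id="org-1",
        embedding_config=cfg,
        metadata_={},
    )
    doc_chunk = SourcePassage(
        id="passage-456",
        text="chunk from upload.pdf",
        source_id="source-1",
        file_id="file-1",
        organization_id="org-1",
        embedding_config=cfg,
        metadata_={},
    )
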
a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index d13e85b1..48b8c44a 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -3,7 +3,7 @@ from enum import Enum from typing import TYPE_CHECKING, List, Literal, Optional from sqlalchemy import String, desc, func, or_, select -from sqlalchemy.exc import DBAPIError +from sqlalchemy.exc import DBAPIError, IntegrityError from sqlalchemy.orm import Mapped, Session, mapped_column from letta.log import get_logger @@ -242,7 +242,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): session.commit() session.refresh(self) return self - except DBAPIError as e: + except (DBAPIError, IntegrityError) as e: self._handle_dbapi_error(e) def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py index faa520c0..c1ec13be 100644 --- a/letta/schemas/passage.py +++ b/letta/schemas/passage.py @@ -10,7 +10,7 @@ from letta.utils import get_utc_time class PassageBase(OrmMetadataBase): - __id_prefix__ = "passage_legacy" + __id_prefix__ = "passage" is_deleted: bool = Field(False, description="Whether this passage is deleted or not.") diff --git a/letta/server/server.py b/letta/server/server.py index b01bfd34..3eb3207a 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -932,7 +932,7 @@ class SyncServer(Server): def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary: agent = self.load_agent(agent_id=agent_id, actor=actor) - return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user)) + return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id)) def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary: agent = self.load_agent(agent_id=agent_id, actor=actor) @@ -949,18 +949,9 @@ class SyncServer(Server): # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user actor = self.user_manager.get_user_or_default(user_id=user_id) - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) + passages = self.agent_manager.list_passages(agent_id=agent_id, actor=actor) - # iterate over records - records = letta_agent.passage_manager.list_passages( - actor=actor, - agent_id=agent_id, - cursor=cursor, - limit=limit, - ) - - return records + return passages def get_agent_archival_cursor( self, @@ -974,15 +965,13 @@ class SyncServer(Server): # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user actor = self.user_manager.get_user_or_default(user_id=user_id) - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - # iterate over records - records = letta_agent.passage_manager.list_passages( - actor=self.default_user, + records = self.agent_manager.list_passages( + actor=actor, agent_id=agent_id, cursor=cursor, limit=limit, + ascending=not reverse, ) return records @@ -1098,7 +1087,8 @@ class SyncServer(Server): self.source_manager.delete_source(source_id=source_id, actor=actor) # delete data from passage store - self.passage_manager.delete_passages(actor=actor, limit=None, source_id=source_id) + passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None) + self.passage_manager.delete_passages(actor=actor, 
passages=passages_to_be_deleted) # TODO: delete data from agent passage stores (?) @@ -1129,9 +1119,11 @@ class SyncServer(Server): for agent_state in agent_states: agent_id = agent_state.id agent = self.load_agent(agent_id=agent_id, actor=actor) - curr_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id) + + # Attach source to agent + curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id) agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager) - new_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id) + new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id) assert new_passage_size >= curr_passage_size # in case empty files are added return job @@ -1195,14 +1187,9 @@ class SyncServer(Server): source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor) elif source_name: source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor) + source_id = source.id else: raise ValueError(f"Need to provide at least source_id or source_name to find the source.") - source_id = source.id - - # TODO: This should be done with the ORM? - # delete all Passage objects with source_id==source_id from agent's archival memory - agent = self.load_agent(agent_id=agent_id, actor=actor) - agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id) # delete agent-source mapping self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor) @@ -1224,7 +1211,7 @@ class SyncServer(Server): for source in sources: # count number of passages - num_passages = self.passage_manager.size(actor=actor, source_id=source.id) + num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id) # TODO: add when files table implemented ## count number of files diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 52a526f9..d1edb3ea 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1,17 +1,26 @@ from typing import Dict, List, Optional +from datetime import datetime +import numpy as np -from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS +from sqlalchemy import select, union_all, literal, func, Select + +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM +from letta.embeddings import embedding_model from letta.log import get_logger from letta.orm import Agent as AgentModel from letta.orm import Block as BlockModel from letta.orm import Source as SourceModel from letta.orm import Tool as ToolModel +from letta.orm import AgentPassage, SourcePassage +from letta.orm import SourcesAgents from letta.orm.errors import NoResultFound +from letta.orm.sqlite_functions import adapt_array from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.source import Source as PydanticSource from letta.schemas.tool_rule import ToolRule as PydanticToolRule from letta.schemas.user import User as PydanticUser @@ -21,9 +30,9 @@ from letta.services.helpers.agent_manager_helper import ( _process_tags, derive_system_message, ) -from 
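
Source deletion is now two explicit steps: AgentManager does the lookup and PassageManager does the deletes, and delete_passages() receives the rows themselves instead of filter arguments (see the passage_manager.py hunk below):

    passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
    self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)
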
letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager +from letta.settings import settings from letta.utils import enforce_types logger = get_logger(__name__) @@ -229,13 +238,6 @@ class AgentManager: with self.session_maker() as session: # Retrieve the agent agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) - - # TODO: @mindy delete this piece when we have a proper passages/sources implementation - # TODO: This is done very hacky on purpose - # TODO: 1000 limit is also wack - passage_manager = PassageManager() - passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000) - agent_state = agent.to_pydantic() agent.hard_delete(session) return agent_state @@ -407,6 +409,262 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + # ====================================================================================================================== + # Passage Management + # ====================================================================================================================== + def _build_passage_query( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + query_text: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + cursor: Optional[str] = None, + source_id: Optional[str] = None, + embed_query: bool = False, + ascending: bool = True, + embedding_config: Optional[EmbeddingConfig] = None, + agent_only: bool = False, + ) -> Select: + """Helper function to build the base passage query with all filters applied. + + Returns the query before any limit or count operations are applied. 
+ """ + embedded_text = None + if embed_query: + assert embedding_config is not None, "embedding_config must be specified for vector search" + assert query_text is not None, "query_text must be specified for vector search" + embedded_text = embedding_model(embedding_config).get_text_embedding(query_text) + embedded_text = np.array(embedded_text) + embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() + + with self.session_maker() as session: + # Start with base query for source passages + source_passages = None + if not agent_only: # Include source passages + if agent_id is not None: + source_passages = ( + select( + SourcePassage, + literal(None).label('agent_id') + ) + .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id) + .where(SourcesAgents.agent_id == agent_id) + .where(SourcePassage.organization_id == actor.organization_id) + ) + else: + source_passages = ( + select( + SourcePassage, + literal(None).label('agent_id') + ) + .where(SourcePassage.organization_id == actor.organization_id) + ) + + if source_id: + source_passages = source_passages.where(SourcePassage.source_id == source_id) + if file_id: + source_passages = source_passages.where(SourcePassage.file_id == file_id) + + # Add agent passages query + agent_passages = None + if agent_id is not None: + agent_passages = ( + select( + AgentPassage.id, + AgentPassage.text, + AgentPassage.embedding_config, + AgentPassage.metadata_, + AgentPassage.embedding, + AgentPassage.created_at, + AgentPassage.updated_at, + AgentPassage.is_deleted, + AgentPassage._created_by_id, + AgentPassage._last_updated_by_id, + AgentPassage.organization_id, + literal(None).label('file_id'), + literal(None).label('source_id'), + AgentPassage.agent_id + ) + .where(AgentPassage.agent_id == agent_id) + .where(AgentPassage.organization_id == actor.organization_id) + ) + + # Combine queries + if source_passages is not None and agent_passages is not None: + combined_query = union_all(source_passages, agent_passages).cte('combined_passages') + elif agent_passages is not None: + combined_query = agent_passages.cte('combined_passages') + elif source_passages is not None: + combined_query = source_passages.cte('combined_passages') + else: + raise ValueError("No passages found") + + # Build main query from combined CTE + main_query = select(combined_query) + + # Apply filters + if start_date: + main_query = main_query.where(combined_query.c.created_at >= start_date) + if end_date: + main_query = main_query.where(combined_query.c.created_at <= end_date) + if source_id: + main_query = main_query.where(combined_query.c.source_id == source_id) + if file_id: + main_query = main_query.where(combined_query.c.file_id == file_id) + + # Vector search + if embedded_text: + if settings.letta_pg_uri_no_default: + # PostgreSQL with pgvector + main_query = main_query.order_by( + combined_query.c.embedding.cosine_distance(embedded_text).asc() + ) + else: + # SQLite with custom vector type + query_embedding_binary = adapt_array(embedded_text) + if ascending: + main_query = main_query.order_by( + func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(), + combined_query.c.created_at.asc(), + combined_query.c.id.asc() + ) + else: + main_query = main_query.order_by( + func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(), + combined_query.c.created_at.desc(), + combined_query.c.id.asc() + ) + else: + if query_text: + main_query = 
main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text))) + + # Handle cursor-based pagination + if cursor: + cursor_query = select(combined_query.c.created_at).where( + combined_query.c.id == cursor + ).scalar_subquery() + + if ascending: + main_query = main_query.where( + combined_query.c.created_at > cursor_query + ) + else: + main_query = main_query.where( + combined_query.c.created_at < cursor_query + ) + + # Add ordering if not already ordered by similarity + if not embed_query: + if ascending: + main_query = main_query.order_by( + combined_query.c.created_at.asc(), + combined_query.c.id.asc(), + ) + else: + main_query = main_query.order_by( + combined_query.c.created_at.desc(), + combined_query.c.id.asc(), + ) + + return main_query + + @enforce_types + def list_passages( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + limit: Optional[int] = 50, + query_text: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + cursor: Optional[str] = None, + source_id: Optional[str] = None, + embed_query: bool = False, + ascending: bool = True, + embedding_config: Optional[EmbeddingConfig] = None, + agent_only: bool = False + ) -> List[PydanticPassage]: + """Lists all passages attached to an agent.""" + with self.session_maker() as session: + main_query = self._build_passage_query( + actor=actor, + agent_id=agent_id, + file_id=file_id, + query_text=query_text, + start_date=start_date, + end_date=end_date, + cursor=cursor, + source_id=source_id, + embed_query=embed_query, + ascending=ascending, + embedding_config=embedding_config, + agent_only=agent_only, + ) + + # Add limit + if limit: + main_query = main_query.limit(limit) + + # Execute query + results = list(session.execute(main_query)) + + passages = [] + for row in results: + data = dict(row._mapping) + if data['agent_id'] is not None: + # This is an AgentPassage - remove source fields + data.pop('source_id', None) + data.pop('file_id', None) + passage = AgentPassage(**data) + else: + # This is a SourcePassage - remove agent field + data.pop('agent_id', None) + passage = SourcePassage(**data) + passages.append(passage) + + return [p.to_pydantic() for p in passages] + + + @enforce_types + def passage_size( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + query_text: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + cursor: Optional[str] = None, + source_id: Optional[str] = None, + embed_query: bool = False, + ascending: bool = True, + embedding_config: Optional[EmbeddingConfig] = None, + agent_only: bool = False + ) -> int: + """Returns the count of passages matching the given criteria.""" + with self.session_maker() as session: + main_query = self._build_passage_query( + actor=actor, + agent_id=agent_id, + file_id=file_id, + query_text=query_text, + start_date=start_date, + end_date=end_date, + cursor=cursor, + source_id=source_id, + embed_query=embed_query, + ascending=ascending, + embedding_config=embedding_config, + agent_only=agent_only, + ) + + # Convert to count query + count_query = select(func.count()).select_from(main_query.subquery()) + return session.scalar(count_query) or 0 + # ====================================================================================================================== # Tool Management # 
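
passage_size() wraps the same builder in a COUNT, so listing and counting can never disagree about which rows a filter matches. Usage sketch, assuming user and agent_id are in scope:

    mgr = AgentManager()
    total = mgr.passage_size(actor=user, agent_id=agent_id)
    archival_only = mgr.list_passages(actor=user, agent_id=agent_id, agent_only=True, limit=None)
    assert len(archival_only) <= total
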
====================================================================================================================== diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 100a4433..d8554063 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -1,12 +1,13 @@ -from datetime import datetime from typing import List, Optional - +from datetime import datetime import numpy as np +from sqlalchemy import select, union_all, literal + from letta.constants import MAX_EMBEDDING_DIM from letta.embeddings import embedding_model, parse_and_chunk_text from letta.orm.errors import NoResultFound -from letta.orm.passage import Passage as PassageModel +from letta.orm.passage import AgentPassage, SourcePassage from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.passage import Passage as PydanticPassage @@ -14,6 +15,7 @@ from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types + class PassageManager: """Manager class to handle business logic related to Passages.""" @@ -26,14 +28,51 @@ class PassageManager: def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: """Fetch a passage by ID.""" with self.session_maker() as session: - passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor) - return passage.to_pydantic() + # Try source passages first + try: + passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) + return passage.to_pydantic() + except NoResultFound: + # Try archival passages + try: + passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor) + return passage.to_pydantic() + except NoResultFound: + raise NoResultFound(f"Passage with id {passage_id} not found in database.") @enforce_types def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: - """Create a new passage.""" + """Create a new passage in the appropriate table based on whether it has agent_id or source_id.""" + # Common fields for both passage types + data = pydantic_passage.model_dump() + common_fields = { + "id": data.get("id"), + "text": data["text"], + "embedding": data["embedding"], + "embedding_config": data["embedding_config"], + "organization_id": data["organization_id"], + "metadata_": data.get("metadata_", {}), + "is_deleted": data.get("is_deleted", False), + "created_at": data.get("created_at", datetime.utcnow()), + } + + if "agent_id" in data and data["agent_id"]: + assert not data.get("source_id"), "Passage cannot have both agent_id and source_id" + agent_fields = { + "agent_id": data["agent_id"], + } + passage = AgentPassage(**common_fields, **agent_fields) + elif "source_id" in data and data["source_id"]: + assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id" + source_fields = { + "source_id": data["source_id"], + "file_id": data.get("file_id"), + } + passage = SourcePassage(**common_fields, **source_fields) + else: + raise ValueError("Passage must have either agent_id or source_id") + with self.session_maker() as session: - passage = PassageModel(**pydantic_passage.model_dump()) passage.create(session, actor=actor) return passage.to_pydantic() @@ -93,14 +132,23 @@ class PassageManager: raise ValueError("Passage ID must be provided.") with self.session_maker() as session: - # Fetch existing message from database - curr_passage = PassageModel.read( - 
db_session=session, - identifier=passage_id, - actor=actor, - ) - if not curr_passage: - raise ValueError(f"Passage with id {passage_id} does not exist.") + # Try source passages first + try: + curr_passage = SourcePassage.read( + db_session=session, + identifier=passage_id, + actor=actor, + ) + except NoResultFound: + # Try agent passages + try: + curr_passage = AgentPassage.read( + db_session=session, + identifier=passage_id, + actor=actor, + ) + except NoResultFound: + raise ValueError(f"Passage with id {passage_id} does not exist.") # Update the database record with values from the provided record update_data = passage.model_dump(exclude_unset=True, exclude_none=True) @@ -113,104 +161,32 @@ class PassageManager: @enforce_types def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: - """Delete a passage.""" + """Delete a passage from either source or archival passages.""" if not passage_id: raise ValueError("Passage ID must be provided.") with self.session_maker() as session: + # Try source passages first try: - passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor) + passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) passage.hard_delete(session, actor=actor) + return True except NoResultFound: - raise ValueError(f"Passage with id {passage_id} not found.") - - @enforce_types - def list_passages( - self, - actor: PydanticUser, - agent_id: Optional[str] = None, - file_id: Optional[str] = None, - cursor: Optional[str] = None, - limit: Optional[int] = 50, - query_text: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - ascending: bool = True, - source_id: Optional[str] = None, - embed_query: bool = False, - embedding_config: Optional[EmbeddingConfig] = None, - ) -> List[PydanticPassage]: - """List passages with pagination.""" - with self.session_maker() as session: - filters = {"organization_id": actor.organization_id} - if agent_id: - filters["agent_id"] = agent_id - if file_id: - filters["file_id"] = file_id - if source_id: - filters["source_id"] = source_id - - embedded_text = None - if embed_query: - assert embedding_config is not None - - # Embed the text - embedded_text = embedding_model(embedding_config).get_text_embedding(query_text) - - # Pad the embedding with zeros - embedded_text = np.array(embedded_text) - embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() - - results = PassageModel.list( - db_session=session, - cursor=cursor, - start_date=start_date, - end_date=end_date, - limit=limit, - ascending=ascending, - query_text=query_text if not embedded_text else None, - query_embedding=embedded_text, - **filters, - ) - return [p.to_pydantic() for p in results] - - @enforce_types - def size(self, actor: PydanticUser, agent_id: Optional[str] = None, **kwargs) -> int: - """Get the total count of messages with optional filters. 
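
create_passage() now routes on whichever id the Pydantic object carries, asserting that agent_id and source_id are mutually exclusive and raising ValueError when neither is set. Routing sketch (ids hypothetical, cfg an EmbeddingConfig):

    manager = PassageManager()
    manager.create_passage(  # lands in agent_passages
        PydanticPassage(text="note to self", agent_id="agent-1", organization_id="org-1",
                        embedding=[0.1], embedding_config=cfg),
        actor=user,
    )
    manager.create_passage(  # lands in source_passages
        PydanticPassage(text="doc chunk", source_id="source-1", file_id="file-1",
                        organization_id="org-1", embedding=[0.1], embedding_config=cfg),
        actor=user,
    )
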
- - Args: - actor : The user requesting the count - agent_id: The agent ID - """ - with self.session_maker() as session: - return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs) + # Try archival passages + try: + passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor) + passage.hard_delete(session, actor=actor) + return True + except NoResultFound: + raise NoResultFound(f"Passage with id {passage_id} not found.") def delete_passages( self, actor: PydanticUser, - agent_id: Optional[str] = None, - file_id: Optional[str] = None, - start_date: Optional[datetime] = None, - end_date: Optional[datetime] = None, - limit: Optional[int] = 50, - cursor: Optional[str] = None, - query_text: Optional[str] = None, - source_id: Optional[str] = None, + passages: List[PydanticPassage], ) -> bool: - - passages = self.list_passages( - actor=actor, - agent_id=agent_id, - file_id=file_id, - cursor=cursor, - limit=limit, - start_date=start_date, - end_date=end_date, - query_text=query_text, - source_id=source_id, - ) - # TODO: This is very inefficient # TODO: We should have a base `delete_all_matching_filters`-esque function for passage in passages: self.delete_passage_by_id(passage_id=passage.id, actor=actor) + return True diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 3839611b..e0b51255 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -482,7 +482,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): # check agent archival memory size archival_memories = client.get_archival_memory(agent_id=agent.id) - print(archival_memories) assert len(archival_memories) == 0 # load a file into a source (non-blocking job) diff --git a/tests/test_managers.py b/tests/test_managers.py index 3df4e8b5..dc5f15ad 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2,6 +2,8 @@ import os import time from datetime import datetime, timedelta +from httpx._transports import default +from numpy import source import pytest from sqlalchemy import delete from sqlalchemy.exc import IntegrityError @@ -17,7 +19,8 @@ from letta.orm import ( Job, Message, Organization, - Passage, + AgentPassage, + SourcePassage, SandboxConfig, SandboxEnvironmentVariable, Source, @@ -82,7 +85,8 @@ def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: session.execute(delete(Message)) - session.execute(delete(Passage)) + session.execute(delete(AgentPassage)) + session.execute(delete(SourcePassage)) session.execute(delete(Job)) session.execute(delete(ToolsAgents)) # Clear ToolsAgents first session.execute(delete(BlocksAgents)) @@ -189,39 +193,79 @@ def print_tool(server: SyncServer, default_user, default_organization): @pytest.fixture -def hello_world_passage_fixture(server: SyncServer, default_user, default_file, sarah_agent): - """Fixture to create a tool with default settings and clean up after the test.""" - # Set up passage - dummy_embedding = [0.0] * 2 - message = PydanticPassage( - organization_id=default_user.organization_id, - agent_id=sarah_agent.id, - file_id=default_file.id, - text="Hello, world!", - embedding=dummy_embedding, - embedding_config=DEFAULT_EMBEDDING_CONFIG, +def agent_passage_fixture(server: SyncServer, default_user, sarah_agent): + """Fixture to create an agent passage.""" + passage = server.passage_manager.create_passage( + PydanticPassage( + text="Hello, I am an agent 
passage", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + metadata_={"type": "test"} + ), + actor=default_user ) - - msg = server.passage_manager.create_passage(message, actor=default_user) - yield msg + yield passage @pytest.fixture -def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent) -> list[PydanticPassage]: - """Helper function to create test passages for all tests""" - dummy_embedding = [0] * 2 - passages = [ +def source_passage_fixture(server: SyncServer, default_user, default_file, default_source): + """Fixture to create a source passage.""" + passage = server.passage_manager.create_passage( PydanticPassage( - organization_id=default_user.organization_id, - agent_id=sarah_agent.id, + text="Hello, I am a source passage", + source_id=default_source.id, file_id=default_file.id, - text=f"Test passage {i}", - embedding=dummy_embedding, + organization_id=default_user.organization_id, + embedding=[0.1], embedding_config=DEFAULT_EMBEDDING_CONFIG, + metadata_={"type": "test"} + ), + actor=default_user + ) + yield passage + + +@pytest.fixture +def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent, default_source): + """Helper function to create test passages for all tests.""" + # Create agent passages + passages = [] + for i in range(5): + passage = server.passage_manager.create_passage( + PydanticPassage( + text=f"Agent passage {i}", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + metadata_={"type": "test"} + ), + actor=default_user ) - for i in range(4) - ] - server.passage_manager.create_many_passages(passages, actor=default_user) + passages.append(passage) + if USING_SQLITE: + time.sleep(CREATE_DELAY_SQLITE) + + # Create source passages + for i in range(5): + passage = server.passage_manager.create_passage( + PydanticPassage( + text=f"Source passage {i}", + source_id=default_source.id, + file_id=default_file.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + metadata_={"type": "test"} + ), + actor=default_user + ) + passages.append(passage) + if USING_SQLITE: + time.sleep(CREATE_DELAY_SQLITE) + return passages @@ -389,6 +433,49 @@ def server(): return server +@pytest.fixture +def agent_passages_setup(server, default_source, default_user, sarah_agent): + """Setup fixture for agent passages tests""" + agent_id = sarah_agent.id + actor = default_user + + server.agent_manager.attach_source(agent_id=agent_id, source_id=default_source.id, actor=actor) + + # Create some source passages + source_passages = [] + for i in range(3): + passage = server.passage_manager.create_passage( + PydanticPassage( + organization_id=actor.organization_id, + source_id=default_source.id, + text=f"Source passage {i}", + embedding=[0.1], # Default OpenAI embedding size + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=actor + ) + source_passages.append(passage) + + # Create some agent passages + agent_passages = [] + for i in range(2): + passage = server.passage_manager.create_passage( + PydanticPassage( + organization_id=actor.organization_id, + agent_id=agent_id, + text=f"Agent passage {i}", + embedding=[0.1], # Default OpenAI embedding size + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=actor + ) + agent_passages.append(passage) + + yield agent_passages, source_passages + + # Cleanup 
+
 
 # ======================================================================================================================
 # AgentManager Tests - Basic
 # ======================================================================================================================
@@ -849,6 +936,199 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
     assert block.label == default_block.label
 
 
+# ======================================================================================================================
+# Agent Manager - Passages Tests
+# ======================================================================================================================
+
+def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
+    """Test basic listing functionality of agent passages"""
+
+    all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
+    assert len(all_passages) == 5  # 3 source + 2 agent passages
+
+
+def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
+    """Test ordering of agent passages"""
+
+    # Test ascending order
+    asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
+    assert len(asc_passages) == 5
+    for i in range(1, len(asc_passages)):
+        assert asc_passages[i-1].created_at <= asc_passages[i].created_at
+
+    # Test descending order
+    desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
+    assert len(desc_passages) == 5
+    for i in range(1, len(desc_passages)):
+        assert desc_passages[i-1].created_at >= desc_passages[i].created_at
+
+
+def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
+    """Test pagination of agent passages"""
+
+    # Test limit
+    limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
+    assert len(limited_passages) == 3
+
+    # Test cursor-based pagination
+    first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
+    assert len(first_page) == 2
+
+    second_page = server.agent_manager.list_passages(
+        actor=default_user,
+        agent_id=sarah_agent.id,
+        cursor=first_page[-1].id,
+        limit=2,
+        ascending=True
+    )
+    assert len(second_page) == 2
+    assert first_page[-1].id != second_page[0].id
+    assert first_page[-1].created_at <= second_page[0].created_at
+
+
+def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
+    """Test text search functionality of agent passages"""
+
+    # Test text search for source passages
+    source_text_passages = server.agent_manager.list_passages(
+        actor=default_user,
+        agent_id=sarah_agent.id,
+        query_text="Source passage"
+    )
+    assert len(source_text_passages) == 3
+
+    # Test text search for agent passages
+    agent_text_passages = server.agent_manager.list_passages(
+        actor=default_user,
+        agent_id=sarah_agent.id,
+        query_text="Agent passage"
+    )
+    assert len(agent_text_passages) == 2
+
+
+def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
+    """Test that agent_only=True restricts listing to agent-owned passages"""
+
+    # List only the passages owned by the agent itself
+    agent_text_passages = server.agent_manager.list_passages(
+        actor=default_user,
+        agent_id=sarah_agent.id,
+        agent_only=True
+    )
+    assert len(agent_text_passages) == 2
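Under the split schema, list_passages is conceptually the union of two queries, and agent_only=True collapses it to one. A rough sketch of that semantic, assuming plain in-memory lists rather than the real SQLAlchemy statements:

    def list_passages_sketch(agent_rows, source_rows, agent_only=False, ascending=True):
        # agent_rows come from agent_passages; source_rows come from
        # source_passages joined through the agent's attached sources.
        rows = list(agent_rows) + ([] if agent_only else list(source_rows))
        return sorted(rows, key=lambda p: p.created_at, reverse=not ascending)

With agent_only=True only the two agent-owned rows survive, which is what the assertion above checks.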
+
+
+def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
+    """Test filtering functionality of agent passages"""
+
+    # Test source filtering
+    source_filtered = server.agent_manager.list_passages(
+        actor=default_user,
+        agent_id=sarah_agent.id,
+        source_id=default_source.id
+    )
+    assert len(source_filtered) == 3
+
+    # Test date filtering
+    now = datetime.utcnow()
+    future_date = now + timedelta(days=1)
+    past_date = now - timedelta(days=1)
+
+    date_filtered = server.agent_manager.list_passages(
+        actor=default_user,
+        agent_id=sarah_agent.id,
+        start_date=past_date,
+        end_date=future_date
+    )
+    assert len(date_filtered) == 5
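The date filter above is a window test on created_at. A plausible reading, sketched in plain Python (the real implementation is a SQL WHERE clause, and boundary inclusivity is an assumption here):

    from datetime import datetime

    def in_window(created_at: datetime, start_date=None, end_date=None) -> bool:
        # None on either side leaves that side of the window open
        if start_date is not None and created_at < start_date:
            return False
        if end_date is not None and created_at > end_date:
            return False
        return True

All five passages were just created, so every one falls between past_date and future_date.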
+
+
+def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source):
+    """Test vector search functionality of agent passages"""
+    embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
+
+    # Create passages with known embeddings, alternating between
+    # agent-owned and source-owned rows
+    passages = []
+    test_passages = [
+        "I like red",
+        "random text",
+        "blue shoes",
+    ]
+
+    server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
+
+    for i, text in enumerate(test_passages):
+        embedding = embed_model.get_text_embedding(text)
+        if i % 2 == 0:
+            passage = PydanticPassage(
+                text=text,
+                organization_id=default_user.organization_id,
+                agent_id=sarah_agent.id,
+                embedding_config=DEFAULT_EMBEDDING_CONFIG,
+                embedding=embedding
+            )
+        else:
+            passage = PydanticPassage(
+                text=text,
+                organization_id=default_user.organization_id,
+                source_id=default_source.id,
+                embedding_config=DEFAULT_EMBEDDING_CONFIG,
+                embedding=embedding
+            )
+        created_passage = server.passage_manager.create_passage(passage, default_user)
+        passages.append(created_passage)
+
+    # Query whose embedding should be closest to the "I like red" passage
+    query_key = "What's my favorite color?"
+
+    # Test vector search with all passages
+    results = server.agent_manager.list_passages(
+        actor=default_user,
+        agent_id=sarah_agent.id,
+        query_text=query_key,
+        embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        embed_query=True,
+    )
+
+    # Verify results are ordered by similarity
+    assert len(results) == 3
+    assert results[0].text == "I like red"
+    assert "random" in results[1].text or "random" in results[2].text
+    assert "blue" in results[1].text or "blue" in results[2].text
+
+    # Test vector search with agent_only=True
+    agent_only_results = server.agent_manager.list_passages(
+        actor=default_user,
+        agent_id=sarah_agent.id,
+        query_text=query_key,
+        embedding_config=DEFAULT_EMBEDDING_CONFIG,
+        embed_query=True,
+        agent_only=True
+    )
+
+    # Verify agent-only results: only the two agent-owned passages remain
+    assert len(agent_only_results) == 2
+    assert agent_only_results[0].text == "I like red"
+    assert agent_only_results[1].text == "blue shoes"
+
+
+def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
+    """Test listing passages from a source without specifying an agent."""
+
+    # List passages by source_id without agent_id
+    source_passages = server.agent_manager.list_passages(
+        actor=default_user,
+        source_id=default_source.id,
+    )
+
+    # Verify we get only source passages (3 from agent_passages_setup)
+    assert len(source_passages) == 3
+    assert all(p.source_id == default_source.id for p in source_passages)
+    assert all(p.agent_id is None for p in source_passages)
+
+
 # ======================================================================================================================
 # Organization Manager Tests
 # ======================================================================================================================
@@ -900,266 +1180,86 @@ def test_list_organizations_pagination(server: SyncServer):
 # Passage Manager Tests
 # ======================================================================================================================
 
-
-def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user):
-    """Test creating a passage using hello_world_passage_fixture fixture"""
-    assert hello_world_passage_fixture.id is not None
-    assert hello_world_passage_fixture.text == "Hello, world!"
+def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
+    """Test creating an agent passage via the agent_passage_fixture"""
+    assert agent_passage_fixture.id is not None
+    assert agent_passage_fixture.text == "Hello, I am an agent passage"
 
     # Verify we can retrieve it
     retrieved = server.passage_manager.get_passage_by_id(
-        hello_world_passage_fixture.id,
+        agent_passage_fixture.id,
         actor=default_user,
     )
     assert retrieved is not None
-    assert retrieved.id == hello_world_passage_fixture.id
-    assert retrieved.text == hello_world_passage_fixture.text
+    assert retrieved.id == agent_passage_fixture.id
+    assert retrieved.text == agent_passage_fixture.text
 
 
-def test_passage_get_by_id(server: SyncServer, hello_world_passage_fixture, default_user):
-    """Test retrieving a passage by ID"""
-    retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
+def test_passage_create_source(server: SyncServer, source_passage_fixture, default_user):
+    """Test creating a source passage."""
+    assert source_passage_fixture is not None
+    assert source_passage_fixture.text == "Hello, I am a source passage"
+
+    # Verify we can retrieve it
+    retrieved = server.passage_manager.get_passage_by_id(
+        source_passage_fixture.id,
+        actor=default_user,
+    )
     assert retrieved is not None
-    assert retrieved.id == hello_world_passage_fixture.id
-    assert retrieved.text == hello_world_passage_fixture.text
+    assert retrieved.id == source_passage_fixture.id
+    assert retrieved.text == source_passage_fixture.text
 
 
-def test_passage_update(server: SyncServer, hello_world_passage_fixture, default_user):
-    """Test updating a passage"""
-    new_text = "Updated text"
-    hello_world_passage_fixture.text = new_text
-    updated = server.passage_manager.update_passage_by_id(hello_world_passage_fixture.id, hello_world_passage_fixture, actor=default_user)
-    assert updated is not None
-    assert updated.text == new_text
-    retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
-    assert retrieved.text == new_text
-
-
-def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user):
-    """Test deleting a passage"""
-    server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
-    with pytest.raises(NoResultFound):
-        server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
-
-
-def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
-    """Test counting passages with filters"""
-    base_passage = hello_world_passage_fixture
-
-    # Test total count
-    total = server.passage_manager.size(actor=default_user)
-    assert total == 5  # base passage + 4 test passages
-    # TODO: change login passage to be a system not user passage
-
-    # Test count with agent filter
-    agent_count = server.passage_manager.size(actor=default_user, agent_id=base_passage.agent_id)
-    assert agent_count == 5
-
-    # Test count with role filter
-    role_count = server.passage_manager.size(actor=default_user)
-    assert role_count == 5
-
-    # Test count with non-existent filter
-    empty_count = server.passage_manager.size(actor=default_user, agent_id="non-existent")
-    assert empty_count == 0
-
-
-def test_passage_listing_basic(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
-    """Test basic passage listing with limit"""
-    results = server.passage_manager.list_passages(actor=default_user, limit=3)
-    assert len(results) == 3
-
-
-def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
-    """Test cursor-based pagination functionality"""
-
-    # Make sure there are 5 passages
-    assert server.passage_manager.size(actor=default_user) == 5
-
-    # Get first page
-    first_page = server.passage_manager.list_passages(actor=default_user, limit=3)
-    assert len(first_page) == 3
-
-    last_id_on_first_page = first_page[-1].id
-
-    # Get second page
-    second_page = server.passage_manager.list_passages(actor=default_user, cursor=last_id_on_first_page, limit=3)
-    assert len(second_page) == 2  # Should have 2 remaining passages
-    assert all(r1.id != r2.id for r1 in first_page for r2 in second_page)
-
-
-def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
-    """Test filtering passages by agent ID"""
-    agent_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, limit=10)
-    assert len(agent_results) == 5  # base passage + 4 test passages
-    assert all(msg.agent_id == hello_world_passage_fixture.agent_id for msg in agent_results)
-
-
-def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
-    """Test searching passages by text content"""
-    search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10)
-    assert len(search_results) == 4
-    assert all("Test passage" in msg.text for msg in search_results)
-
-    # Test no results
-    search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10)
-    assert len(search_results) == 0
-
-
-def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_passage_fixture, default_user, default_file, sarah_agent):
-    """Test filtering passages by date range with various scenarios"""
-    # Set up test data with known dates
-    base_time = datetime.utcnow()
-
-    # Create passages at different times
-    passages = []
-    time_offsets = [
-        timedelta(days=-2),  # 2 days ago
-        timedelta(days=-1),  # Yesterday
-        timedelta(hours=-2),  # 2 hours ago
-        timedelta(minutes=-30),  # 30 minutes ago
-        timedelta(minutes=-1),  # 1 minute ago
-        timedelta(minutes=0),  # Now
-    ]
-
-    for i, offset in enumerate(time_offsets):
-        timestamp = base_time + offset
-        passage = server.passage_manager.create_passage(
+def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
+    """Test that a passage with both agent_id and source_id is rejected."""
+    assert agent_passage_fixture is not None
+    assert agent_passage_fixture.text == "Hello, I am an agent passage"
+
+    # Try to create an invalid passage (with both agent_id and source_id)
+    with pytest.raises(AssertionError):
+        server.passage_manager.create_passage(
             PydanticPassage(
+                text="Invalid passage",
+                agent_id="123",
+                source_id="456",
                 organization_id=default_user.organization_id,
-                agent_id=sarah_agent.id,
-                file_id=default_file.id,
-                text=f"Test passage {i}",
-                embedding=[0.1, 0.2, 0.3],
+                embedding=[0.1] * 1024,
                 embedding_config=DEFAULT_EMBEDDING_CONFIG,
-                created_at=timestamp,
             ),
-            actor=default_user,
-        )
-        passages.append(passage)
-
-    # Test cases
-    test_cases = [
-        {
-            "name": "Recent passages (last hour)",
-            "start_date": base_time - timedelta(hours=1),
-            "end_date": base_time + timedelta(minutes=1),
-            "expected_count": 1 + 3,  # Should include base + -30min, -1min, and now
-        },
-        {
-            "name": "Yesterday's passages",
-            "start_date": base_time - timedelta(days=1, hours=12),
-            "end_date": base_time - timedelta(hours=12),
-            "expected_count": 1,  # Should only include yesterday's passage
-        },
-        {
-            "name": "Future time range",
-            "start_date": base_time + timedelta(days=1),
-            "end_date": base_time + timedelta(days=2),
-            "expected_count": 0,  # Should find no passages
-        },
-        {
-            "name": "All time",
-            "start_date": base_time - timedelta(days=3),
-            "end_date": base_time + timedelta(days=1),
-            "expected_count": 1 + len(passages),  # Should find all passages
-        },
-        {
-            "name": "Exact timestamp match",
-            "start_date": passages[0].created_at - timedelta(microseconds=1),
-            "end_date": passages[0].created_at + timedelta(microseconds=1),
-            "expected_count": 1,  # Should find exactly one passage
-        },
-        {
-            "name": "Small time window",
-            "start_date": base_time - timedelta(seconds=30),
-            "end_date": base_time + timedelta(seconds=30),
-            "expected_count": 1 + 1,  # date + "now"
-        },
-    ]
-
-    # Run test cases
-    for case in test_cases:
-        results = server.passage_manager.list_passages(
-            agent_id=sarah_agent.id, actor=default_user, start_date=case["start_date"], end_date=case["end_date"], limit=10
+            actor=default_user
         )
 
-        # Verify count
-        assert (
-            len(results) == case["expected_count"]
-        ), f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}"
 
-    # Test edge cases
+def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user):
+    """Test retrieving a passage by ID"""
+    retrieved = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, actor=default_user)
+    assert retrieved is not None
+    assert retrieved.id == agent_passage_fixture.id
+    assert retrieved.text == agent_passage_fixture.text
 
-    # Test with start_date but no end_date
-    results_start_only = server.passage_manager.list_passages(
-        agent_id=sarah_agent.id, actor=default_user, start_date=base_time - timedelta(minutes=2), end_date=None, limit=10
-    )
-    assert len(results_start_only) >= 2, "Should find passages after start_date"
-
-    # Test with end_date but no start_date
-    results_end_only = server.passage_manager.list_passages(
-        agent_id=sarah_agent.id, actor=default_user, start_date=None, end_date=base_time - timedelta(days=1), limit=10
-    )
-    assert len(results_end_only) >= 1, "Should find passages before end_date"
-
-    # Test limit enforcement
-    limited_results = server.passage_manager.list_passages(
-        agent_id=sarah_agent.id,
-        actor=default_user,
-        start_date=base_time - timedelta(days=3),
-        end_date=base_time + timedelta(days=1),
-        limit=3,
-    )
-    assert len(limited_results) <= 3, "Should respect the limit parameter"
+    retrieved = server.passage_manager.get_passage_by_id(source_passage_fixture.id, actor=default_user)
+    assert retrieved is not None
+    assert retrieved.id == source_passage_fixture.id
+    assert retrieved.text == source_passage_fixture.text
 
 
-def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent):
-    """Test vector search functionality for passages."""
-    passage_manager = server.passage_manager
-    embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
-
-    # Create passages with known embeddings
-    passages = []
-
-    # Create passages with different embeddings
-    test_passages = [
-        "I like red",
-        "random text",
-        "blue shoes",
-    ]
-
-    for text in test_passages:
-        embedding = embed_model.get_text_embedding(text)
-        passage = PydanticPassage(
-            text=text,
-            organization_id=default_user.organization_id,
-            agent_id=sarah_agent.id,
-            embedding_config=DEFAULT_EMBEDDING_CONFIG,
-            embedding=embedding,
-        )
-        created_passage = passage_manager.create_passage(passage, default_user)
-        passages.append(created_passage)
-    assert passage_manager.size(actor=default_user) == len(passages)
-
-    # Query vector similar to "cats" embedding
-    query_key = "What's my favorite color?"
-
-    # List passages with vector search
-    results = passage_manager.list_passages(
-        actor=default_user,
-        agent_id=sarah_agent.id,
-        query_text=query_key,
-        limit=3,
-        embedding_config=DEFAULT_EMBEDDING_CONFIG,
-        embed_query=True,
-    )
-
-    # Verify results are ordered by similarity
-    assert len(results) == 3
-    assert results[0].text == "I like red"
-    assert results[1].text == "random text"  # For some reason the embedding model doesn't like "blue shoes"
-    assert results[2].text == "blue shoes"
+def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
+    """Test that passages are deleted when their parent (agent or source) is deleted."""
+    # Verify passages exist
+    agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user)
+    source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
+    assert agent_passage is not None
+    assert source_passage is not None
+
+    # Delete agent and verify its passages are deleted
+    server.agent_manager.delete_agent(sarah_agent.id, default_user)
+    agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
+    assert len(agentic_passages) == 0
+
+    # Delete source and verify its passages are deleted
+    server.source_manager.delete_source(default_source.id, default_user)
+    with pytest.raises(NoResultFound):
+        server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
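The cascade test works because deletion happens at the database level: the new passage tables hang off their parents with ON DELETE CASCADE foreign keys, so no application-side cleanup is needed. A generic SQLAlchemy sketch of that wiring (class and table names here are illustrative, not the ORM's actual definitions):

    from sqlalchemy import Column, ForeignKey, String
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class AgentPassageSketch(Base):
        __tablename__ = "agent_passages_sketch"
        id = Column(String, primary_key=True)
        # Deleting the parent agent row removes its passages in the same statement
        agent_id = Column(String, ForeignKey("agents.id", ondelete="CASCADE"), nullable=False)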
 
 
 # ======================================================================================================================
@@ -1220,6 +1320,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
     assert print_tool.organization_id == default_organization.id
 
 
+
 @pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
 def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
     data = print_tool.model_dump(exclude=["id"])
@@ -1787,6 +1888,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
 # ======================================================================================================================
 # Source Manager Tests - Files
 # ======================================================================================================================
+
 def test_get_file_by_id(server: SyncServer, default_user, default_source):
     """Test retrieving a file by ID."""
     file_metadata = PydanticFileMetadata(
@@ -1857,6 +1959,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
 # ======================================================================================================================
 # SandboxConfigManager Tests - Sandbox Configs
 # ======================================================================================================================
+
 def test_create_or_update_sandbox_config(server: SyncServer, default_user):
     sandbox_config_create = SandboxConfigCreate(
         config=E2BSandboxConfig(),
@@ -1935,6 +2038,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
 # ======================================================================================================================
 # SandboxConfigManager Tests - Environment Variables
 # ======================================================================================================================
+
 def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
     env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
     created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
@@ -2007,7 +2111,6 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
 # JobManager Tests
 # ======================================================================================================================
 
-
 def test_create_job(server: SyncServer, default_user):
     """Test creating a job."""
     job_data = PydanticJob(
diff --git a/tests/test_server.py b/tests/test_server.py
index 09dfb94c..0718e6ce 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -390,12 +390,16 @@ def test_user_message_memory(server, user_id, agent_id):
 
 @pytest.mark.order(3)
 def test_load_data(server, user_id, agent_id):
+    user = server.user_manager.get_user_or_default(user_id=user_id)
+
     # create source
-    passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
+    passages_before = server.agent_manager.list_passages(
+        actor=user, agent_id=agent_id, cursor=None, limit=10000
+    )
     assert len(passages_before) == 0
 
     source = server.source_manager.create_source(
-        PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=server.default_user
+        PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
     )
 
     # load data
@@ -409,15 +413,11 @@ def test_load_data(server, user_id, agent_id):
     connector = DummyDataConnector(archival_memories)
     server.load_data(user_id, connector, source.name)
 
-    # @pytest.mark.order(3)
-    # def test_attach_source_to_agent(server, user_id, agent_id):
-    # check archival memory size
-
     # attach source
     server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
 
     # check archival memory size
-    passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
+    passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
     assert len(passages_after) == 5
 
@@ -465,7 +465,7 @@ def test_get_archival_memory(server, user_id, agent_id):
     user = server.user_manager.get_user_by_id(user_id=user_id)
 
     # List latest 2 passages
-    passages_1 = server.passage_manager.list_passages(
+    passages_1 = server.agent_manager.list_passages(
         actor=user,
         agent_id=agent_id,
         ascending=False,
@@ -475,7 +475,7 @@ def test_get_archival_memory(server, user_id, agent_id):
 
     # List next 3 passages (earliest 3)
     cursor1 = passages_1[-1].id
-    passages_2 = server.passage_manager.list_passages(
+    passages_2 = server.agent_manager.list_passages(
         actor=user,
         agent_id=agent_id,
         ascending=False,
@@ -484,24 +484,28 @@ def test_get_archival_memory(server, user_id, agent_id):
 
     # List all 5
     cursor2 = passages_1[0].created_at
-    passages_3 = server.passage_manager.list_passages(
+    passages_3 = server.agent_manager.list_passages(
         actor=user,
         agent_id=agent_id,
         ascending=False,
         end_date=cursor2,
         limit=1000,
     )
-    # assert passages_1[0].text == "Cinderella wore a blue dress"
passages_1[0].text == "Cinderella wore a blue dress" assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test + latest = passages_1[0] + earliest = passages_2[-1] + # test archival memory - passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1) + passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True) assert len(passage_1) == 1 - passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000) + assert passage_1[0].text == "alpha" + passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True) assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test + assert all("alpha" not in passage.text for passage in passage_2) # test safe empty return - passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000) + passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True) assert len(passage_none) == 0 @@ -955,6 +959,14 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path): actor = server.user_manager.get_user_or_default(user_id) + existing_sources = server.source_manager.list_sources(actor=actor) + if len(existing_sources) > 0: + for source in existing_sources: + server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor) + initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor) + assert initial_passage_count == 0 + + # Create a source source = server.source_manager.create_source( PydanticSource( @@ -973,10 +985,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot # Attach source to agent first server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor) - # Get initial passage count - initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id) - assert initial_passage_count == 0 - # Create a job for loading the first file job = server.job_manager.create_job( PydanticJob( @@ -1001,7 +1009,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot assert job.metadata_["num_documents"] == 1 # Verify passages were added - first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id) + first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor) assert first_file_passage_count > initial_passage_count # Create a second test file with different content @@ -1032,14 +1040,13 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot assert job2.metadata_["num_documents"] == 1 # Verify passages were appended (not replaced) - final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id) + final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor) assert final_passage_count > first_file_passage_count # Verify both old and new content is searchable - passages = server.passage_manager.list_passages( - actor=actor, + passages = 
         agent_id=agent_id,
-        source_id=source.id,
+        actor=actor,
         query_text="what does Timber like to eat",
         embedding_config=EmbeddingConfig.default_config(provider="openai"),
         embed_query=True,
@@ -1048,35 +1055,27 @@
     assert any("chicken" in passage.text.lower() for passage in passages)
     assert any("Anna".lower() in passage.text.lower() for passage in passages)
 
-    # TODO: Add this test back in after separation of `Passage tables` (LET-449)
-    # # Load second agent
-    # agent2 = server.load_agent(agent_id=other_agent_id)
+    # Initially should have no passages
+    initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
+    assert initial_agent2_passages == 0
 
-    # # Initially should have no passages
-    # initial_agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
-    # assert initial_agent2_passages == 0
+    # Attach source to second agent
+    server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
 
-    # # Attach source to second agent
-    # agent2.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
+    # Verify second agent has same number of passages as first agent
+    agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
+    agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
+    assert agent2_passages == agent1_passages
 
-    # # Verify second agent has same number of passages as first agent
-    # agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
-    # agent1_passages = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
-    # assert agent2_passages == agent1_passages
-
-    # # Verify second agent can query the same content
-    # passages2 = server.passage_manager.list_passages(
-    #     actor=user,
-    #     agent_id=other_agent_id,
-    #     source_id=source.id,
-    #     query_text="what does Timber like to eat",
-    #     embedding_config=EmbeddingConfig.default_config(provider="openai"),
-    #     embed_query=True,
-    #     limit=10,
-    # )
-    # assert len(passages2) == len(passages)
-    # assert any("chicken" in passage.text.lower() for passage in passages2)
-    # assert any("sleep" in passage.text.lower() for passage in passages2)
-
-    # # Cleanup
-    # server.delete_agent(user_id=user_id, agent_id=agent2_state.id)
+    # Verify second agent can query the same content
+    passages2 = server.agent_manager.list_passages(
+        actor=actor,
+        agent_id=other_agent_id,
+        source_id=source.id,
+        query_text="what does Timber like to eat",
+        embedding_config=EmbeddingConfig.default_config(provider="openai"),
+        embed_query=True,
+    )
+    assert len(passages2) == len(passages)
+    assert any("chicken" in passage.text.lower() for passage in passages2)
+    assert any("Anna".lower() in passage.text.lower() for passage in passages2)
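The searchability assertions in this test rely on embed_query=True: the query text is embedded with the source's embedding model and passages come back ranked by vector similarity. A minimal, self-contained sketch of that ranking, assuming cosine similarity over plain Python lists (the real implementation delegates to pgvector):

    import math

    def cosine_similarity(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        return dot / (norm_a * norm_b)

    def rank_by_similarity(query_embedding, passages):
        # passages: objects with an .embedding list; most similar first
        return sorted(passages, key=lambda p: cosine_similarity(query_embedding, p.embedding), reverse=True)

This is the ordering property the earlier vector-search test asserts when it expects "I like red" to rank first for a color question.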