feat: separate Passages tables (#2245)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2024-12-16 15:24:20 -08:00
committed by GitHub
parent 10e610bb95
commit e2d916148e
19 changed files with 1026 additions and 546 deletions

View File

@@ -0,0 +1,105 @@
"""divide passage table into SourcePassages and AgentPassages
Revision ID: 54dec07619c4
Revises: 4e88e702f85e
Create Date: 2024-12-14 17:23:08.772554
"""
from typing import Sequence, Union
from alembic import op
from pgvector.sqlalchemy import Vector
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from letta.orm.custom_columns import EmbeddingConfigColumn
# revision identifiers, used by Alembic.
# This migration splits the legacy `passages` table into `agent_passages`
# and `source_passages` (see module docstring above).
revision: str = '54dec07619c4'
down_revision: Union[str, None] = '4e88e702f85e'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Split the legacy ``passages`` table into two dedicated tables.

    Creates ``agent_passages`` (archival memories created by an agent) and
    ``source_passages`` (passages derived from files belonging to a data
    source), then drops the combined legacy table and tightens two foreign
    keys with ``ON DELETE CASCADE``.

    WARNING: ``op.drop_table('passages')`` discards any rows remaining in the
    legacy table — data must be copied into the new tables beforehand if it
    should survive this migration.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Passages created directly by an agent (archival memory).
    op.create_table(
        'agent_passages',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('text', sa.String(), nullable=False),
        sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
        sa.Column('metadata_', sa.JSON(), nullable=False),
        sa.Column('embedding', Vector(dim=4096), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
        sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
        sa.Column('_created_by_id', sa.String(), nullable=True),
        sa.Column('_last_updated_by_id', sa.String(), nullable=True),
        sa.Column('organization_id', sa.String(), nullable=False),
        sa.Column('agent_id', sa.String(), nullable=False),
        sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['organization_id'], ['organizations.id']),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False)

    # Passages derived from files attached to a data source.
    op.create_table(
        'source_passages',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('text', sa.String(), nullable=False),
        sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
        sa.Column('metadata_', sa.JSON(), nullable=False),
        sa.Column('embedding', Vector(dim=4096), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
        sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
        sa.Column('_created_by_id', sa.String(), nullable=True),
        sa.Column('_last_updated_by_id', sa.String(), nullable=True),
        sa.Column('organization_id', sa.String(), nullable=False),
        sa.Column('file_id', sa.String(), nullable=True),
        sa.Column('source_id', sa.String(), nullable=False),
        sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['organization_id'], ['organizations.id']),
        sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False)

    # Drop the legacy combined table (destructive — see docstring).
    op.drop_table('passages')

    # Recreate two FKs with ON DELETE CASCADE. Fix: name the new constraints
    # explicitly instead of passing None, so downgrade() can drop them by name.
    # The names match PostgreSQL's default convention (<table>_<col>_fkey), so
    # the resulting schema is identical to the auto-named version.
    op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey')
    op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE')
    op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey')
    op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE')
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert the split: restore the single legacy ``passages`` table.

    Recreates ``passages`` with its original constraints, drops the two new
    tables, and restores the non-cascading foreign keys on ``messages`` and
    ``files``. Rows inserted into the split tables are NOT copied back.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Bug fix: the autogenerated code called op.drop_constraint(None, ...),
    # which raises because Alembic requires a constraint name for a drop.
    # upgrade() recreated these FKs under PostgreSQL's default names
    # (<table>_<col>_fkey), so we drop them by those names here.
    # NOTE(review): assumes a PostgreSQL backend for the default naming —
    # confirm if other dialects must be supported.
    op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey')
    op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id'])
    op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey')
    op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id'])

    # Recreate the legacy combined table exactly as it existed pre-migration.
    op.create_table(
        'passages',
        sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True),
        sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
        sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
        sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
        sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True),
        sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
        sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'),
        sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'),
        sa.PrimaryKeyConstraint('id', name='passages_pkey'),
    )

    # Drop the split tables and their org indexes (indexes first).
    op.drop_index('source_passages_org_idx', table_name='source_passages')
    op.drop_table('source_passages')
    op.drop_index('agent_passages_org_idx', table_name='agent_passages')
    op.drop_table('agent_passages')
    # ### end Alembic commands ###

View File

@@ -41,7 +41,6 @@ from letta.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage,
)
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.passage import Passage
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
@@ -82,7 +81,7 @@ def compile_memory_metadata_block(
actor: PydanticUser,
agent_id: str,
memory_edit_timestamp: datetime.datetime,
passage_manager: Optional[PassageManager] = None,
agent_manager: Optional[AgentManager] = None,
message_manager: Optional[MessageManager] = None,
) -> str:
# Put the timestamp in the local timezone (mimicking get_local_time())
@@ -93,7 +92,7 @@ def compile_memory_metadata_block(
[
f"### Memory [last modified: {timestamp_str}]",
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
f"{agent_manager.passage_size(actor=actor, agent_id=agent_id) if agent_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
]
)
@@ -106,7 +105,7 @@ def compile_system_message(
in_context_memory: Memory,
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
actor: PydanticUser,
passage_manager: Optional[PassageManager] = None,
agent_manager: Optional[AgentManager] = None,
message_manager: Optional[MessageManager] = None,
user_defined_variables: Optional[dict] = None,
append_icm_if_missing: bool = True,
@@ -135,7 +134,7 @@ def compile_system_message(
actor=actor,
agent_id=agent_id,
memory_edit_timestamp=in_context_memory_last_edit,
passage_manager=passage_manager,
agent_manager=agent_manager,
message_manager=message_manager,
)
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
@@ -172,7 +171,7 @@ def initialize_message_sequence(
agent_id: str,
memory: Memory,
actor: PydanticUser,
passage_manager: Optional[PassageManager] = None,
agent_manager: Optional[AgentManager] = None,
message_manager: Optional[MessageManager] = None,
memory_edit_timestamp: Optional[datetime.datetime] = None,
include_initial_boot_message: bool = True,
@@ -181,7 +180,7 @@ def initialize_message_sequence(
memory_edit_timestamp = get_local_time()
# full_system_message = construct_system_with_memory(
# system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory
# system, memory, memory_edit_timestamp, agent_manager=agent_manager, recall_memory=recall_memory
# )
full_system_message = compile_system_message(
agent_id=agent_id,
@@ -189,7 +188,7 @@ def initialize_message_sequence(
in_context_memory=memory,
in_context_memory_last_edit=memory_edit_timestamp,
actor=actor,
passage_manager=passage_manager,
agent_manager=agent_manager,
message_manager=message_manager,
user_defined_variables=None,
append_icm_if_missing=True,
@@ -291,8 +290,9 @@ class Agent(BaseAgent):
self.interface = interface
# Create the persistence manager object based on the AgentState info
self.passage_manager = PassageManager()
self.message_manager = MessageManager()
self.passage_manager = PassageManager()
self.agent_manager = AgentManager()
# State needed for heartbeat pausing
self.pause_heartbeats_start = None
@@ -322,7 +322,7 @@ class Agent(BaseAgent):
agent_id=self.agent_state.id,
memory=self.agent_state.memory,
actor=self.user,
passage_manager=None,
agent_manager=None,
message_manager=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
@@ -347,7 +347,7 @@ class Agent(BaseAgent):
memory=self.agent_state.memory,
agent_id=self.agent_state.id,
actor=self.user,
passage_manager=None,
agent_manager=None,
message_manager=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
@@ -1297,7 +1297,7 @@ class Agent(BaseAgent):
in_context_memory=self.agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
actor=self.user,
passage_manager=self.passage_manager,
agent_manager=self.agent_manager,
message_manager=self.message_manager,
user_defined_variables=None,
append_icm_if_missing=True,
@@ -1368,33 +1368,24 @@ class Agent(BaseAgent):
source_id: str,
source_manager: SourceManager,
agent_manager: AgentManager,
page_size: Optional[int] = None,
):
"""Attach data with name `source_name` to the agent from source_connector."""
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
for passage in passages:
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
passage.agent_id = self.agent_state.id
self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)
agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
assert len(agents_passages) == passage_size # sanity check
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
# attach to agent
"""Attach a source to the agent using the SourcesAgents ORM relationship.
Args:
user: User performing the action
source_id: ID of the source to attach
source_manager: SourceManager instance to verify source exists
agent_manager: AgentManager instance to manage agent-source relationship
"""
# Verify source exists and user has permission to access it
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
assert source is not None, f"Source {source_id} not found in metadata store"
assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})"
# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
# TODO: delete @matt and remove
# Use the agent_manager to create the relationship
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)
printd(
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
f"Attached data source {source.name} to agent {self.agent_state.name}.",
)
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
@@ -1550,13 +1541,13 @@ class Agent(BaseAgent):
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
)
passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id)
agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id)
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
external_memory_summary = compile_memory_metadata_block(
actor=self.user,
agent_id=self.agent_state.id,
memory_edit_timestamp=get_utc_time(), # dummy timestamp
passage_manager=self.passage_manager,
agent_manager=self.agent_manager,
message_manager=self.message_manager,
)
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
@@ -1582,7 +1573,7 @@ class Agent(BaseAgent):
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(self._messages),
num_archival_memory=passage_manager_size,
num_archival_memory=agent_manager_passage_size,
num_recall_memory=message_manager_size,
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
# top-level information

View File

@@ -3,6 +3,7 @@ from typing import Optional
from letta.agent import Agent
from letta.constants import MAX_PAUSE_HEARTBEATS
from letta.services.agent_manager import AgentManager
# import math
# from letta.utils import json_dumps
@@ -200,8 +201,9 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
try:
# Get results using passage manager
all_results = self.passage_manager.list_passages(
all_results = self.agent_manager.list_passages(
actor=self.user,
agent_id=self.agent_state.id,
query_text=query,
limit=count + start, # Request enough results to handle offset
embedding_config=self.agent_state.embedding_config,

View File

@@ -312,11 +312,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
for param in sig.parameters.values():
# Exclude 'self' parameter
# TODO: eventually remove this (only applies to BASE_TOOLS)
if param.name == "self":
continue
# exclude 'agent_state' parameter
if param.name == "agent_state":
if param.name in ["self", "agent_state"]: # Add agent_manager to excluded
continue
# Assert that the parameter has a type annotation

View File

@@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata
from letta.orm.job import Job
from letta.orm.message import Message
from letta.orm.organization import Organization
from letta.orm.passage import Passage
from letta.orm.passage import BasePassage, AgentPassage, SourcePassage
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents

View File

@@ -82,7 +82,25 @@ class Agent(SqlalchemyBase, OrganizationMixin):
lazy="selectin",
doc="Tags associated with the agent.",
)
# passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="agent", lazy="selectin")
source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage",
secondary="sources_agents", # The join table for Agent -> Source
primaryjoin="Agent.id == sources_agents.c.agent_id",
secondaryjoin="and_(SourcePassage.source_id == sources_agents.c.source_id)",
lazy="selectin",
order_by="SourcePassage.created_at.desc()",
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
doc="All passages derived from sources associated with this agent.",
)
agent_passages: Mapped[List["AgentPassage"]] = relationship(
"AgentPassage",
back_populates="agent",
lazy="selectin",
order_by="AgentPassage.created_at.desc()",
cascade="all, delete-orphan",
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
doc="All passages derived created by this agent.",
)
def to_pydantic(self) -> PydanticAgentState:
"""converts to the basic pydantic model counterpart"""

View File

@@ -9,7 +9,8 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.source import Source
from letta.orm.passage import SourcePassage
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
"""Represents metadata for an uploaded file."""
@@ -27,4 +28,4 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")

View File

@@ -31,30 +31,19 @@ class UserMixin(Base):
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
class FileMixin(Base):
"""Mixin for models that belong to a file."""
__abstract__ = True
file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
class AgentMixin(Base):
"""Mixin for models that belong to an agent."""
__abstract__ = True
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"))
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
class FileMixin(Base):
"""Mixin for models that belong to a file."""
__abstract__ = True
file_id: Mapped[Optional[str]] = mapped_column(
String,
ForeignKey("files.id", ondelete="CASCADE"),
nullable=True
)
file_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("files.id", ondelete="CASCADE"))
class SourceMixin(Base):
@@ -62,7 +51,7 @@ class SourceMixin(Base):
__abstract__ = True
source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"))
source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), nullable=False)
class SandboxConfigMixin(Base):

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Union
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -35,6 +35,22 @@ class Organization(SqlalchemyBase):
)
# relationships
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage",
back_populates="organization",
cascade="all, delete-orphan"
)
agent_passages: Mapped[List["AgentPassage"]] = relationship(
"AgentPassage",
back_populates="organization",
cascade="all, delete-orphan"
)
@property
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
"""Convenience property to get all passages"""
return self.source_passages + self.agent_passages

View File

@@ -1,39 +1,35 @@
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from sqlalchemy import Column, JSON, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import FileMixin, OrganizationMixin
from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.orm.custom_columns import CommonVector
from letta.orm.mixins import FileMixin, OrganizationMixin
from letta.orm.source import EmbeddingConfigColumn
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
config = LettaConfig()
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.agent import Agent
# TODO: After migration to Passage, will need to manually delete passages where files
# are deleted on web
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
"""Defines data model for storing Passages"""
__tablename__ = "passages"
__table_args__ = {"extend_existing": True}
class BasePassage(SqlalchemyBase, OrganizationMixin):
"""Base class for all passage types with common fields"""
__abstract__ = True
__pydantic_model__ = PydanticPassage
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
text: Mapped[str] = mapped_column(doc="Passage text content")
source_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Source identifier")
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
# Vector embedding field based on database type
if settings.letta_pg_uri_no_default:
from pgvector.sqlalchemy import Vector
@@ -41,9 +37,49 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
else:
embedding = Column(CommonVector)
# Foreign keys
agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id"), nullable=True)
@declared_attr
def organization(cls) -> Mapped["Organization"]:
"""Relationship to organization"""
return relationship("Organization", back_populates="passages", lazy="selectin")
# Relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="passages", lazy="selectin")
file: Mapped["FileMetadata"] = relationship("FileMetadata", back_populates="passages", lazy="selectin")
@declared_attr
def __table_args__(cls):
if settings.letta_pg_uri_no_default:
return (
Index(f'{cls.__tablename__}_org_idx', 'organization_id'),
{"extend_existing": True}
)
return ({"extend_existing": True},)
class SourcePassage(BasePassage, FileMixin, SourceMixin):
"""Passages derived from external files/sources"""
__tablename__ = "source_passages"
@declared_attr
def file(cls) -> Mapped["FileMetadata"]:
"""Relationship to file"""
return relationship("FileMetadata", back_populates="source_passages", lazy="selectin")
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="source_passages", lazy="selectin")
@declared_attr
def source(cls) -> Mapped["Source"]:
"""Relationship to source"""
return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True)
class AgentPassage(BasePassage, AgentMixin):
"""Passages created by agents as archival memories"""
__tablename__ = "agent_passages"
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="agent_passages", lazy="selectin")
@declared_attr
def agent(cls) -> Mapped["Agent"]:
"""Relationship to agent"""
return relationship("Agent", back_populates="agent_passages", lazy="selectin", passive_deletes=True)

View File

@@ -12,6 +12,9 @@ from letta.schemas.source import Source as PydanticSource
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.file import FileMetadata
from letta.orm.passage import SourcePassage
from letta.orm.agent import Agent
class Source(SqlalchemyBase, OrganizationMixin):
@@ -28,4 +31,5 @@ class Source(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")

View File

@@ -3,7 +3,7 @@ from enum import Enum
from typing import TYPE_CHECKING, List, Literal, Optional
from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError
from sqlalchemy.exc import DBAPIError, IntegrityError
from sqlalchemy.orm import Mapped, Session, mapped_column
from letta.log import get_logger
@@ -242,7 +242,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
session.commit()
session.refresh(self)
return self
except DBAPIError as e:
except (DBAPIError, IntegrityError) as e:
self._handle_dbapi_error(e)
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":

View File

@@ -10,7 +10,7 @@ from letta.utils import get_utc_time
class PassageBase(OrmMetadataBase):
__id_prefix__ = "passage_legacy"
__id_prefix__ = "passage"
is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")

View File

@@ -932,7 +932,7 @@ class SyncServer(Server):
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
agent = self.load_agent(agent_id=agent_id, actor=actor)
return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user))
return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
agent = self.load_agent(agent_id=agent_id, actor=actor)
@@ -949,18 +949,9 @@ class SyncServer(Server):
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
passages = self.agent_manager.list_passages(agent_id=agent_id, actor=actor)
# iterate over records
records = letta_agent.passage_manager.list_passages(
actor=actor,
agent_id=agent_id,
cursor=cursor,
limit=limit,
)
return records
return passages
def get_agent_archival_cursor(
self,
@@ -974,15 +965,13 @@ class SyncServer(Server):
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
# iterate over records
records = letta_agent.passage_manager.list_passages(
actor=self.default_user,
records = self.agent_manager.list_passages(
actor=actor,
agent_id=agent_id,
cursor=cursor,
limit=limit,
ascending=not reverse,
)
return records
@@ -1098,7 +1087,8 @@ class SyncServer(Server):
self.source_manager.delete_source(source_id=source_id, actor=actor)
# delete data from passage store
self.passage_manager.delete_passages(actor=actor, limit=None, source_id=source_id)
passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)
# TODO: delete data from agent passage stores (?)
@@ -1129,9 +1119,11 @@ class SyncServer(Server):
for agent_state in agent_states:
agent_id = agent_state.id
agent = self.load_agent(agent_id=agent_id, actor=actor)
curr_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
# Attach source to agent
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
new_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
assert new_passage_size >= curr_passage_size # in case empty files are added
return job
@@ -1195,14 +1187,9 @@ class SyncServer(Server):
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
elif source_name:
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
source_id = source.id
else:
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
source_id = source.id
# TODO: This should be done with the ORM?
# delete all Passage objects with source_id==source_id from agent's archival memory
agent = self.load_agent(agent_id=agent_id, actor=actor)
agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id)
# delete agent-source mapping
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
@@ -1224,7 +1211,7 @@ class SyncServer(Server):
for source in sources:
# count number of passages
num_passages = self.passage_manager.size(actor=actor, source_id=source.id)
num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id)
# TODO: add when files table implemented
## count number of files

View File

@@ -1,17 +1,26 @@
from typing import Dict, List, Optional
from datetime import datetime
import numpy as np
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from sqlalchemy import select, union_all, literal, func, Select
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import Tool as ToolModel
from letta.orm import AgentPassage, SourcePassage
from letta.orm import SourcesAgents
from letta.orm.errors import NoResultFound
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
@@ -21,9 +30,9 @@ from letta.services.helpers.agent_manager_helper import (
_process_tags,
derive_system_message,
)
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import settings
from letta.utils import enforce_types
logger = get_logger(__name__)
@@ -229,13 +238,6 @@ class AgentManager:
with self.session_maker() as session:
# Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# TODO: @mindy delete this piece when we have a proper passages/sources implementation
# TODO: This is done very hacky on purpose
# TODO: 1000 limit is also wack
passage_manager = PassageManager()
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
agent_state = agent.to_pydantic()
agent.hard_delete(session)
return agent_state
@@ -407,6 +409,262 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
# ======================================================================================================================
# Passage Management
# ======================================================================================================================
def _build_passage_query(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> Select:
    """Helper function to build the base passage query with all filters applied.
    Returns the query before any limit or count operations are applied.

    The query UNION ALLs two sources of passages into one CTE:
      * SourcePassage rows (documents), joined through SourcesAgents when an
        agent_id is given, with a NULL ``agent_id`` column appended, and
      * AgentPassage rows (archival memory), with NULL ``file_id``/``source_id``
        columns appended,
    so both row shapes expose the same columns and can be filtered/ordered
    uniformly.

    Args:
        actor: User whose organization scopes every sub-query.
        agent_id: Restrict to passages visible to this agent (its archival
            passages plus passages of sources attached to it).
        file_id: Restrict source passages to this file.
        query_text: Plain-text substring filter (case-insensitive) unless
            embed_query is set, in which case it is the text to embed.
        start_date / end_date: Inclusive created_at bounds.
        cursor: Passage id; results are restricted to rows strictly after
            (ascending) or before (descending) that row's created_at.
        source_id: Restrict source passages to this source.
        embed_query: If True, order results by vector similarity to
            query_text (requires embedding_config and query_text).
        ascending: Sort direction for created_at ordering.
        embedding_config: Embedding model config used when embed_query is set.
        agent_only: If True, exclude source passages entirely.

    Raises:
        ValueError: if neither source nor agent passages can be queried
            (agent_only=True with no agent_id).
    """
    embedded_text = None
    if embed_query:
        assert embedding_config is not None, "embedding_config must be specified for vector search"
        assert query_text is not None, "query_text must be specified for vector search"
        # Embed once, then zero-pad to the fixed column width so vectors from
        # models with different dimensions remain comparable in one column.
        embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
        embedded_text = np.array(embedded_text)
        embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
    # NOTE(review): this session is never used below — the function only builds
    # a Select; consider dropping the context manager. Confirm before removing.
    with self.session_maker() as session:
        # Start with base query for source passages
        source_passages = None
        if not agent_only:  # Include source passages
            if agent_id is not None:
                # Only passages of sources attached to this agent.
                source_passages = (
                    select(
                        SourcePassage,
                        literal(None).label('agent_id')
                    )
                    .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
                    .where(SourcesAgents.agent_id == agent_id)
                    .where(SourcePassage.organization_id == actor.organization_id)
                )
            else:
                # No agent scope: all source passages in the org.
                source_passages = (
                    select(
                        SourcePassage,
                        literal(None).label('agent_id')
                    )
                    .where(SourcePassage.organization_id == actor.organization_id)
                )
            if source_id:
                source_passages = source_passages.where(SourcePassage.source_id == source_id)
            if file_id:
                source_passages = source_passages.where(SourcePassage.file_id == file_id)
        # Add agent passages query
        agent_passages = None
        if agent_id is not None:
            # Columns are listed explicitly (instead of the entity) so the
            # UNION column order matches the source-passages select above,
            # with NULL file_id/source_id fillers.
            agent_passages = (
                select(
                    AgentPassage.id,
                    AgentPassage.text,
                    AgentPassage.embedding_config,
                    AgentPassage.metadata_,
                    AgentPassage.embedding,
                    AgentPassage.created_at,
                    AgentPassage.updated_at,
                    AgentPassage.is_deleted,
                    AgentPassage._created_by_id,
                    AgentPassage._last_updated_by_id,
                    AgentPassage.organization_id,
                    literal(None).label('file_id'),
                    literal(None).label('source_id'),
                    AgentPassage.agent_id
                )
                .where(AgentPassage.agent_id == agent_id)
                .where(AgentPassage.organization_id == actor.organization_id)
            )
        # Combine queries
        if source_passages is not None and agent_passages is not None:
            combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
        elif agent_passages is not None:
            combined_query = agent_passages.cte('combined_passages')
        elif source_passages is not None:
            combined_query = source_passages.cte('combined_passages')
        else:
            raise ValueError("No passages found")
        # Build main query from combined CTE
        main_query = select(combined_query)
        # Apply filters
        if start_date:
            main_query = main_query.where(combined_query.c.created_at >= start_date)
        if end_date:
            main_query = main_query.where(combined_query.c.created_at <= end_date)
        # source_id/file_id were already applied to the source sub-query;
        # re-applying on the CTE also excludes agent passages (NULL columns).
        if source_id:
            main_query = main_query.where(combined_query.c.source_id == source_id)
        if file_id:
            main_query = main_query.where(combined_query.c.file_id == file_id)
        # Vector search
        if embedded_text:
            if settings.letta_pg_uri_no_default:
                # PostgreSQL with pgvector
                # NOTE(review): `ascending` is not consulted in this branch,
                # unlike the SQLite branch below — confirm this is intended.
                main_query = main_query.order_by(
                    combined_query.c.embedding.cosine_distance(embedded_text).asc()
                )
            else:
                # SQLite with custom vector type
                query_embedding_binary = adapt_array(embedded_text)
                if ascending:
                    main_query = main_query.order_by(
                        func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
                        combined_query.c.created_at.asc(),
                        combined_query.c.id.asc()
                    )
                else:
                    main_query = main_query.order_by(
                        func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
                        combined_query.c.created_at.desc(),
                        combined_query.c.id.asc()
                    )
        else:
            # Plain substring match when not doing vector search.
            if query_text:
                main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))
        # Handle cursor-based pagination: anchor on the cursor row's created_at.
        if cursor:
            cursor_query = select(combined_query.c.created_at).where(
                combined_query.c.id == cursor
            ).scalar_subquery()
            if ascending:
                main_query = main_query.where(
                    combined_query.c.created_at > cursor_query
                )
            else:
                main_query = main_query.where(
                    combined_query.c.created_at < cursor_query
                )
        # Add ordering if not already ordered by similarity
        if not embed_query:
            if ascending:
                main_query = main_query.order_by(
                    combined_query.c.created_at.asc(),
                    combined_query.c.id.asc(),
                )
            else:
                main_query = main_query.order_by(
                    combined_query.c.created_at.desc(),
                    combined_query.c.id.asc(),
                )
        return main_query
@enforce_types
def list_passages(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False
) -> List[PydanticPassage]:
    """Lists all passages attached to an agent.

    Builds the combined source+agent passage query via _build_passage_query,
    applies the optional limit, and maps each result row back onto the
    concrete ORM class (AgentPassage when the row carries an agent_id,
    SourcePassage otherwise) before converting to Pydantic models.
    """
    with self.session_maker() as session:
        query = self._build_passage_query(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            cursor=cursor,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )
        if limit:
            query = query.limit(limit)

        pydantic_passages: List[PydanticPassage] = []
        for row in session.execute(query):
            fields = dict(row._mapping)
            if fields['agent_id'] is None:
                # Source passage row: the agent_id column is just a UNION filler.
                fields.pop('agent_id', None)
                orm_passage = SourcePassage(**fields)
            else:
                # Agent passage row: file/source columns are UNION fillers.
                fields.pop('source_id', None)
                fields.pop('file_id', None)
                orm_passage = AgentPassage(**fields)
            pydantic_passages.append(orm_passage.to_pydantic())
        return pydantic_passages
@enforce_types
def passage_size(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False
) -> int:
    """Returns the count of passages matching the given criteria.

    Wraps the same filtered query as list_passages in SELECT COUNT(*),
    so the two stay consistent; returns 0 when nothing matches.
    """
    with self.session_maker() as session:
        filtered = self._build_passage_query(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            cursor=cursor,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )
        # Count over the filtered query as a subquery.
        total = session.scalar(select(func.count()).select_from(filtered.subquery()))
        return total or 0
# ======================================================================================================================
# Tool Management
# ======================================================================================================================

View File

@@ -1,12 +1,13 @@
from datetime import datetime
from typing import List, Optional
from datetime import datetime
import numpy as np
from sqlalchemy import select, union_all, literal
from letta.constants import MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model, parse_and_chunk_text
from letta.orm.errors import NoResultFound
from letta.orm.passage import Passage as PassageModel
from letta.orm.passage import AgentPassage, SourcePassage
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.passage import Passage as PydanticPassage
@@ -14,6 +15,7 @@ from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class PassageManager:
"""Manager class to handle business logic related to Passages."""
@@ -26,14 +28,51 @@ class PassageManager:
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
    """Fetch a passage by ID from either passage table.

    Tries the source-passages table first, then falls back to the agent
    (archival) passages table.

    Raises:
        NoResultFound: if the id exists in neither table.
    """
    # Fix: removed leftover pre-refactor lines that read via the deleted
    # PassageModel before the SourcePassage/AgentPassage lookup could run.
    with self.session_maker() as session:
        # Try source passages first
        try:
            passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
            return passage.to_pydantic()
        except NoResultFound:
            # Try archival (agent) passages
            try:
                passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
                return passage.to_pydantic()
            except NoResultFound:
                raise NoResultFound(f"Passage with id {passage_id} not found in database.")
@enforce_types
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
    """Create a new passage in the appropriate table based on whether it has agent_id or source_id.

    Exactly one of agent_id / source_id must be set on the passage; the row
    goes to AgentPassage or SourcePassage accordingly.

    Raises:
        ValueError: if the passage carries neither agent_id nor source_id.
    """
    # Fix: removed the stale `passage = PassageModel(**pydantic_passage.model_dump())`
    # line left over from before the table split — it referenced a deleted class
    # and would have clobbered the correctly constructed passage below.
    # Common fields for both passage types
    data = pydantic_passage.model_dump()
    common_fields = {
        "id": data.get("id"),
        "text": data["text"],
        "embedding": data["embedding"],
        "embedding_config": data["embedding_config"],
        "organization_id": data["organization_id"],
        "metadata_": data.get("metadata_", {}),
        "is_deleted": data.get("is_deleted", False),
        "created_at": data.get("created_at", datetime.utcnow()),
    }
    if "agent_id" in data and data["agent_id"]:
        assert not data.get("source_id"), "Passage cannot have both agent_id and source_id"
        agent_fields = {
            "agent_id": data["agent_id"],
        }
        passage = AgentPassage(**common_fields, **agent_fields)
    elif "source_id" in data and data["source_id"]:
        assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id"
        source_fields = {
            "source_id": data["source_id"],
            "file_id": data.get("file_id"),
        }
        passage = SourcePassage(**common_fields, **source_fields)
    else:
        raise ValueError("Passage must have either agent_id or source_id")
    with self.session_maker() as session:
        passage.create(session, actor=actor)
        return passage.to_pydantic()
@@ -93,14 +132,23 @@ class PassageManager:
raise ValueError("Passage ID must be provided.")
with self.session_maker() as session:
# Fetch existing message from database
curr_passage = PassageModel.read(
db_session=session,
identifier=passage_id,
actor=actor,
)
if not curr_passage:
raise ValueError(f"Passage with id {passage_id} does not exist.")
# Try source passages first
try:
curr_passage = SourcePassage.read(
db_session=session,
identifier=passage_id,
actor=actor,
)
except NoResultFound:
# Try agent passages
try:
curr_passage = AgentPassage.read(
db_session=session,
identifier=passage_id,
actor=actor,
)
except NoResultFound:
raise ValueError(f"Passage with id {passage_id} does not exist.")
# Update the database record with values from the provided record
update_data = passage.model_dump(exclude_unset=True, exclude_none=True)
@@ -113,104 +161,32 @@ class PassageManager:
@enforce_types
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
    """Delete a passage from either source or archival passages.

    Tries the source-passages table first, then the agent (archival) table.

    Returns:
        True when a row was deleted.

    Raises:
        ValueError: if passage_id is empty.
        NoResultFound: if the id exists in neither table.
    """
    # Fix: removed the dead pre-refactor code (PassageModel read, and the old
    # list_passages/size implementations) that was spliced into this body and
    # referenced classes deleted by the table split.
    if not passage_id:
        raise ValueError("Passage ID must be provided.")
    with self.session_maker() as session:
        # Try source passages first
        try:
            passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
            passage.hard_delete(session, actor=actor)
            return True
        except NoResultFound:
            # Try archival (agent) passages
            try:
                passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
                passage.hard_delete(session, actor=actor)
                return True
            except NoResultFound:
                raise NoResultFound(f"Passage with id {passage_id} not found.")
def delete_passages(
    self,
    actor: PydanticUser,
    passages: List[PydanticPassage],
) -> bool:
    """Delete the given passages one by one.

    Each passage is removed via delete_passage_by_id, so rows are resolved
    against whichever table (source or agent) they live in.

    Returns:
        True when all passages were deleted.
    """
    # Fix: removed the leftover old filter-style parameters (agent_id, file_id,
    # dates, cursor, ...) and the stale self.list_passages(...) call that were
    # interleaved with the new `passages` parameter — the old signature and the
    # new one cannot coexist in one def.
    # TODO: This is very inefficient
    # TODO: We should have a base `delete_all_matching_filters`-esque function
    for passage in passages:
        self.delete_passage_by_id(passage_id=passage.id, actor=actor)
    return True

View File

@@ -482,7 +482,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# check agent archival memory size
archival_memories = client.get_archival_memory(agent_id=agent.id)
print(archival_memories)
assert len(archival_memories) == 0
# load a file into a source (non-blocking job)

View File

@@ -2,6 +2,8 @@ import os
import time
from datetime import datetime, timedelta
from httpx._transports import default
from numpy import source
import pytest
from sqlalchemy import delete
from sqlalchemy.exc import IntegrityError
@@ -17,7 +19,8 @@ from letta.orm import (
Job,
Message,
Organization,
Passage,
AgentPassage,
SourcePassage,
SandboxConfig,
SandboxEnvironmentVariable,
Source,
@@ -82,7 +85,8 @@ def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Message))
session.execute(delete(Passage))
session.execute(delete(AgentPassage))
session.execute(delete(SourcePassage))
session.execute(delete(Job))
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
session.execute(delete(BlocksAgents))
@@ -189,39 +193,79 @@ def print_tool(server: SyncServer, default_user, default_organization):
@pytest.fixture
def hello_world_passage_fixture(server: SyncServer, default_user, default_file, sarah_agent):
"""Fixture to create a tool with default settings and clean up after the test."""
# Set up passage
dummy_embedding = [0.0] * 2
message = PydanticPassage(
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
file_id=default_file.id,
text="Hello, world!",
embedding=dummy_embedding,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
"""Fixture to create an agent passage."""
passage = server.passage_manager.create_passage(
PydanticPassage(
text="Hello, I am an agent passage",
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
),
actor=default_user
)
msg = server.passage_manager.create_passage(message, actor=default_user)
yield msg
yield passage
@pytest.fixture
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent) -> list[PydanticPassage]:
"""Helper function to create test passages for all tests"""
dummy_embedding = [0] * 2
passages = [
def source_passage_fixture(server: SyncServer, default_user, default_file, default_source):
"""Fixture to create a source passage."""
passage = server.passage_manager.create_passage(
PydanticPassage(
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
text="Hello, I am a source passage",
source_id=default_source.id,
file_id=default_file.id,
text=f"Test passage {i}",
embedding=dummy_embedding,
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
),
actor=default_user
)
yield passage
@pytest.fixture
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent, default_source):
"""Helper function to create test passages for all tests."""
# Create agent passages
passages = []
for i in range(5):
passage = server.passage_manager.create_passage(
PydanticPassage(
text=f"Agent passage {i}",
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
),
actor=default_user
)
for i in range(4)
]
server.passage_manager.create_many_passages(passages, actor=default_user)
passages.append(passage)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
# Create source passages
for i in range(5):
passage = server.passage_manager.create_passage(
PydanticPassage(
text=f"Source passage {i}",
source_id=default_source.id,
file_id=default_file.id,
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
),
actor=default_user
)
passages.append(passage)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
return passages
@@ -389,6 +433,49 @@ def server():
return server
@pytest.fixture
def agent_passages_setup(server, default_source, default_user, sarah_agent):
    """Setup fixture for agent passages tests"""
    actor = default_user
    agent_id = sarah_agent.id
    server.agent_manager.attach_source(agent_id=agent_id, source_id=default_source.id, actor=actor)

    def _create(text, **owner_fields):
        # Shared helper: build + persist one passage with a stub embedding.
        return server.passage_manager.create_passage(
            PydanticPassage(
                organization_id=actor.organization_id,
                text=text,
                embedding=[0.1],  # stub embedding; real dim not needed here
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                **owner_fields,
            ),
            actor=actor,
        )

    # Three passages owned by the source, two archival passages owned by the agent.
    source_passages = [_create(f"Source passage {i}", source_id=default_source.id) for i in range(3)]
    agent_passages = [_create(f"Agent passage {i}", agent_id=agent_id) for i in range(2)]

    yield agent_passages, source_passages

    # Cleanup
    server.source_manager.delete_source(default_source.id, actor=actor)
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -849,6 +936,199 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
assert block.label == default_block.label
# ======================================================================================================================
# Agent Manager - Passages Tests
# ======================================================================================================================
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
    """Test basic listing functionality of agent passages"""
    passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
    # Setup fixture creates 3 source passages and 2 agent passages.
    assert len(passages) == 5
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
    """Test ordering of agent passages"""

    def _check(ascending, reverse):
        # Result must contain all 5 passages with created_at monotone in the
        # requested direction (ties allowed, hence the sorted() comparison).
        result = server.agent_manager.list_passages(
            actor=default_user, agent_id=sarah_agent.id, ascending=ascending
        )
        assert len(result) == 5
        stamps = [p.created_at for p in result]
        assert stamps == sorted(stamps, reverse=reverse)

    _check(ascending=True, reverse=False)
    _check(ascending=False, reverse=True)
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
    """Test pagination of agent passages"""
    manager = server.agent_manager

    # A limit caps the result size.
    assert len(manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)) == 3

    # Cursor-based pagination: the second page resumes strictly after the
    # last item of the first page.
    page_one = manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
    assert len(page_one) == 2
    page_two = manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        cursor=page_one[-1].id,
        limit=2,
        ascending=True,
    )
    assert len(page_two) == 2
    assert page_one[-1].id != page_two[0].id
    assert page_one[-1].created_at <= page_two[0].created_at
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
    """Test text search functionality of agent passages"""

    def _search(needle):
        return server.agent_manager.list_passages(
            actor=default_user,
            agent_id=sarah_agent.id,
            query_text=needle,
        )

    # The setup fixture names its passages "Source passage N" / "Agent passage N".
    assert len(_search("Source passage")) == 3
    assert len(_search("Agent passage")) == 2
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
    """Test that agent_only=True restricts results to the agent's archival passages."""
    # Only the 2 agent passages from the setup fixture should come back;
    # the 3 source passages are excluded by agent_only=True.
    agent_text_passages = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        agent_only=True
    )
    assert len(agent_text_passages) == 2
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
    """Test filtering functionality of agent passages"""
    # Filtering on the source keeps only its 3 passages.
    by_source = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        source_id=default_source.id,
    )
    assert len(by_source) == 3

    # A window from yesterday to tomorrow spans every passage created in setup.
    now = datetime.utcnow()
    by_date = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        start_date=now - timedelta(days=1),
        end_date=now + timedelta(days=1),
    )
    assert len(by_date) == 5
def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source):
    """Test vector search functionality of agent passages"""
    embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
    # Create passages with known embeddings
    passages = []
    # Create passages with different embeddings; even indices become agent
    # (archival) passages, odd indices become source passages, so the search
    # later exercises both tables.
    test_passages = [
        "I like red",
        "random text",
        "blue shoes",
    ]
    server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
    for i, text in enumerate(test_passages):
        embedding = embed_model.get_text_embedding(text)
        if i % 2 == 0:
            # "I like red" and "blue shoes" -> agent passages
            passage = PydanticPassage(
                text=text,
                organization_id=default_user.organization_id,
                agent_id=sarah_agent.id,
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                embedding=embedding
            )
        else:
            # "random text" -> source passage
            passage = PydanticPassage(
                text=text,
                organization_id=default_user.organization_id,
                source_id=default_source.id,
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                embedding=embedding
            )
        created_passage = server.passage_manager.create_passage(passage, default_user)
        passages.append(created_passage)
    # Query vector similar to "red" embedding
    query_key = "What's my favorite color?"
    # Test vector search with all passages
    results = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        query_text=query_key,
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        embed_query=True,
    )
    # Verify results are ordered by similarity: the color question should be
    # closest to "I like red"; the remaining two may appear in either order.
    assert len(results) == 3
    assert results[0].text == "I like red"
    assert "random" in results[1].text or "random" in results[2].text
    assert "blue" in results[1].text or "blue" in results[2].text
    # Test vector search with agent_only=True
    agent_only_results = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        query_text=query_key,
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        embed_query=True,
        agent_only=True
    )
    # Verify agent-only results: only the two agent passages, still
    # similarity-ordered with "I like red" first.
    assert len(agent_only_results) == 2
    assert agent_only_results[0].text == "I like red"
    assert agent_only_results[1].text == "blue shoes"
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
    """Test listing passages from a source without specifying an agent."""
    results = server.agent_manager.list_passages(
        actor=default_user,
        source_id=default_source.id,
    )

    # Only the 3 source passages from the setup fixture; every row belongs to
    # the source and none carries an agent_id.
    assert len(results) == 3
    for passage in results:
        assert passage.source_id == default_source.id
        assert passage.agent_id is None
# ======================================================================================================================
# Organization Manager Tests
# ======================================================================================================================
@@ -900,266 +1180,86 @@ def test_list_organizations_pagination(server: SyncServer):
# Passage Manager Tests
# ======================================================================================================================
def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user):
"""Test creating a passage using hello_world_passage_fixture fixture"""
assert hello_world_passage_fixture.id is not None
assert hello_world_passage_fixture.text == "Hello, world!"
def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
"""Test creating a passage using agent_passage_fixture fixture"""
assert agent_passage_fixture.id is not None
assert agent_passage_fixture.text == "Hello, I am an agent passage"
# Verify we can retrieve it
retrieved = server.passage_manager.get_passage_by_id(
hello_world_passage_fixture.id,
agent_passage_fixture.id,
actor=default_user,
)
assert retrieved is not None
assert retrieved.id == hello_world_passage_fixture.id
assert retrieved.text == hello_world_passage_fixture.text
assert retrieved.id == agent_passage_fixture.id
assert retrieved.text == agent_passage_fixture.text
def test_passage_get_by_id(server: SyncServer, hello_world_passage_fixture, default_user):
"""Test retrieving a passage by ID"""
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
def test_passage_create_source(server: SyncServer, source_passage_fixture, default_user):
"""Test creating a source passage."""
assert source_passage_fixture is not None
assert source_passage_fixture.text == "Hello, I am a source passage"
# Verify we can retrieve it
retrieved = server.passage_manager.get_passage_by_id(
source_passage_fixture.id,
actor=default_user,
)
assert retrieved is not None
assert retrieved.id == hello_world_passage_fixture.id
assert retrieved.text == hello_world_passage_fixture.text
assert retrieved.id == source_passage_fixture.id
assert retrieved.text == source_passage_fixture.text
def test_passage_update(server: SyncServer, hello_world_passage_fixture, default_user):
"""Test updating a passage"""
new_text = "Updated text"
hello_world_passage_fixture.text = new_text
updated = server.passage_manager.update_passage_by_id(hello_world_passage_fixture.id, hello_world_passage_fixture, actor=default_user)
assert updated is not None
assert updated.text == new_text
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
assert retrieved.text == new_text
def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user):
"""Test deleting a passage"""
server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
with pytest.raises(NoResultFound):
server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
"""Test counting passages with filters"""
base_passage = hello_world_passage_fixture
# Test total count
total = server.passage_manager.size(actor=default_user)
assert total == 5 # base passage + 4 test passages
# TODO: change login passage to be a system not user passage
# Test count with agent filter
agent_count = server.passage_manager.size(actor=default_user, agent_id=base_passage.agent_id)
assert agent_count == 5
# Test count with role filter
role_count = server.passage_manager.size(actor=default_user)
assert role_count == 5
# Test count with non-existent filter
empty_count = server.passage_manager.size(actor=default_user, agent_id="non-existent")
assert empty_count == 0
def test_passage_listing_basic(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
"""Test basic passage listing with limit"""
results = server.passage_manager.list_passages(actor=default_user, limit=3)
assert len(results) == 3
def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
"""Test cursor-based pagination functionality"""
# Make sure there are 5 passages
assert server.passage_manager.size(actor=default_user) == 5
# Get first page
first_page = server.passage_manager.list_passages(actor=default_user, limit=3)
assert len(first_page) == 3
last_id_on_first_page = first_page[-1].id
# Get second page
second_page = server.passage_manager.list_passages(actor=default_user, cursor=last_id_on_first_page, limit=3)
assert len(second_page) == 2 # Should have 2 remaining passages
assert all(r1.id != r2.id for r1 in first_page for r2 in second_page)
def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
    """Listing can be restricted to a single agent's passages."""
    filtered = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, limit=10)
    # Base fixture passage + the 4 generated test passages all belong to this agent.
    assert len(filtered) == 5
    expected_agent_id = hello_world_passage_fixture.agent_id
    assert all(passage.agent_id == expected_agent_id for passage in filtered)
def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
    """query_text narrows listings down to passages containing the phrase."""
    manager = server.passage_manager

    # All four generated test passages contain the phrase "Test passage".
    hits = manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10)
    assert len(hits) == 4
    assert all("Test passage" in passage.text for passage in hits)

    # A phrase that appears in no passage yields an empty result.
    misses = manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10)
    assert len(misses) == 0
def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_passage_fixture, default_user, default_file, sarah_agent):
"""Test filtering passages by date range with various scenarios"""
# Set up test data with known dates
base_time = datetime.utcnow()
# Create passages at different times
passages = []
time_offsets = [
timedelta(days=-2), # 2 days ago
timedelta(days=-1), # Yesterday
timedelta(hours=-2), # 2 hours ago
timedelta(minutes=-30), # 30 minutes ago
timedelta(minutes=-1), # 1 minute ago
timedelta(minutes=0), # Now
]
for i, offset in enumerate(time_offsets):
timestamp = base_time + offset
passage = server.passage_manager.create_passage(
def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
"""Test creating an agent passage."""
assert agent_passage_fixture is not None
assert agent_passage_fixture.text == "Hello, I am an agent passage"
# Try to create an invalid passage (with both agent_id and source_id)
with pytest.raises(AssertionError):
server.passage_manager.create_passage(
PydanticPassage(
text="Invalid passage",
agent_id="123",
source_id="456",
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
file_id=default_file.id,
text=f"Test passage {i}",
embedding=[0.1, 0.2, 0.3],
embedding=[0.1] * 1024,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
created_at=timestamp,
),
actor=default_user,
)
passages.append(passage)
# Test cases
test_cases = [
{
"name": "Recent passages (last hour)",
"start_date": base_time - timedelta(hours=1),
"end_date": base_time + timedelta(minutes=1),
"expected_count": 1 + 3, # Should include base + -30min, -1min, and now
},
{
"name": "Yesterday's passages",
"start_date": base_time - timedelta(days=1, hours=12),
"end_date": base_time - timedelta(hours=12),
"expected_count": 1, # Should only include yesterday's passage
},
{
"name": "Future time range",
"start_date": base_time + timedelta(days=1),
"end_date": base_time + timedelta(days=2),
"expected_count": 0, # Should find no passages
},
{
"name": "All time",
"start_date": base_time - timedelta(days=3),
"end_date": base_time + timedelta(days=1),
"expected_count": 1 + len(passages), # Should find all passages
},
{
"name": "Exact timestamp match",
"start_date": passages[0].created_at - timedelta(microseconds=1),
"end_date": passages[0].created_at + timedelta(microseconds=1),
"expected_count": 1, # Should find exactly one passage
},
{
"name": "Small time window",
"start_date": base_time - timedelta(seconds=30),
"end_date": base_time + timedelta(seconds=30),
"expected_count": 1 + 1, # date + "now"
},
]
# Run test cases
for case in test_cases:
results = server.passage_manager.list_passages(
agent_id=sarah_agent.id, actor=default_user, start_date=case["start_date"], end_date=case["end_date"], limit=10
actor=default_user
)
# Verify count
assert (
len(results) == case["expected_count"]
), f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}"
# Test edge cases
def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user):
"""Test retrieving a passage by ID"""
retrieved = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, actor=default_user)
assert retrieved is not None
assert retrieved.id == agent_passage_fixture.id
assert retrieved.text == agent_passage_fixture.text
# Test with start_date but no end_date
results_start_only = server.passage_manager.list_passages(
agent_id=sarah_agent.id, actor=default_user, start_date=base_time - timedelta(minutes=2), end_date=None, limit=10
)
assert len(results_start_only) >= 2, "Should find passages after start_date"
# Test with end_date but no start_date
results_end_only = server.passage_manager.list_passages(
agent_id=sarah_agent.id, actor=default_user, start_date=None, end_date=base_time - timedelta(days=1), limit=10
)
assert len(results_end_only) >= 1, "Should find passages before end_date"
# Test limit enforcement
limited_results = server.passage_manager.list_passages(
agent_id=sarah_agent.id,
actor=default_user,
start_date=base_time - timedelta(days=3),
end_date=base_time + timedelta(days=1),
limit=3,
)
assert len(limited_results) <= 3, "Should respect the limit parameter"
retrieved = server.passage_manager.get_passage_by_id(source_passage_fixture.id, actor=default_user)
assert retrieved is not None
assert retrieved.id == source_passage_fixture.id
assert retrieved.text == source_passage_fixture.text
def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent):
    """Vector search ranks passages by embedding similarity to the query."""
    manager = server.passage_manager
    embedder = embedding_model(DEFAULT_EMBEDDING_CONFIG)

    # Seed passages whose embeddings come from the real embedding model.
    seed_texts = [
        "I like red",
        "random text",
        "blue shoes",
    ]
    created = []
    for seed in seed_texts:
        pydantic_passage = PydanticPassage(
            text=seed,
            organization_id=default_user.organization_id,
            agent_id=sarah_agent.id,
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            embedding=embedder.get_text_embedding(seed),
        )
        created.append(manager.create_passage(pydantic_passage, default_user))
    assert manager.size(actor=default_user) == len(created)

    # Ask a color-related question; similarity should rank "I like red" first.
    results = manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        query_text="What's my favorite color?",
        limit=3,
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        embed_query=True,
    )
    assert len(results) == 3
    assert results[0].text == "I like red"
    assert results[1].text == "random text"  # For some reason the embedding model doesn't like "blue shoes"
    assert results[2].text == "blue shoes"
def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
    """Deleting an agent or a source also deletes the passages attached to it."""
    # Both fixture passages must exist up front.
    assert server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user) is not None
    assert server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user) is not None

    # Removing the agent wipes its agent-scoped passages.
    server.agent_manager.delete_agent(sarah_agent.id, default_user)
    remaining = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
    assert len(remaining) == 0

    # Removing the source wipes its source-scoped passages too.
    server.source_manager.delete_source(default_source.id, default_user)
    with pytest.raises(NoResultFound):
        server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
# ======================================================================================================================
@@ -1220,6 +1320,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
assert print_tool.organization_id == default_organization.id
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
data = print_tool.model_dump(exclude=["id"])
@@ -1787,6 +1888,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
# ======================================================================================================================
# Source Manager Tests - Files
# ======================================================================================================================
def test_get_file_by_id(server: SyncServer, default_user, default_source):
"""Test retrieving a file by ID."""
file_metadata = PydanticFileMetadata(
@@ -1857,6 +1959,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
# ======================================================================================================================
# SandboxConfigManager Tests - Sandbox Configs
# ======================================================================================================================
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
sandbox_config_create = SandboxConfigCreate(
config=E2BSandboxConfig(),
@@ -1935,6 +2038,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
# ======================================================================================================================
# SandboxConfigManager Tests - Environment Variables
# ======================================================================================================================
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
@@ -2007,7 +2111,6 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
# JobManager Tests
# ======================================================================================================================
def test_create_job(server: SyncServer, default_user):
"""Test creating a job."""
job_data = PydanticJob(

View File

@@ -390,12 +390,16 @@ def test_user_message_memory(server, user_id, agent_id):
@pytest.mark.order(3)
def test_load_data(server, user_id, agent_id):
user = server.user_manager.get_user_or_default(user_id=user_id)
# create source
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
passages_before = server.agent_manager.list_passages(
actor=user, agent_id=agent_id, cursor=None, limit=10000
)
assert len(passages_before) == 0
source = server.source_manager.create_source(
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=server.default_user
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
)
# load data
@@ -409,15 +413,11 @@ def test_load_data(server, user_id, agent_id):
connector = DummyDataConnector(archival_memories)
server.load_data(user_id, connector, source.name)
# @pytest.mark.order(3)
# def test_attach_source_to_agent(server, user_id, agent_id):
# check archival memory size
# attach source
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
# check archival memory size
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
assert len(passages_after) == 5
@@ -465,7 +465,7 @@ def test_get_archival_memory(server, user_id, agent_id):
user = server.user_manager.get_user_by_id(user_id=user_id)
# List latest 2 passages
passages_1 = server.passage_manager.list_passages(
passages_1 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
@@ -475,7 +475,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.passage_manager.list_passages(
passages_2 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
@@ -484,24 +484,28 @@ def test_get_archival_memory(server, user_id, agent_id):
# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.passage_manager.list_passages(
passages_3 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
limit=1000,
)
# assert passages_1[0].text == "Cinderella wore a blue dress"
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
latest = passages_1[0]
earliest = passages_2[-1]
# test archival memory
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1)
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000)
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000)
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0
@@ -955,6 +959,14 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
actor = server.user_manager.get_user_or_default(user_id)
existing_sources = server.source_manager.list_sources(actor=actor)
if len(existing_sources) > 0:
for source in existing_sources:
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert initial_passage_count == 0
# Create a source
source = server.source_manager.create_source(
PydanticSource(
@@ -973,10 +985,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
# Attach source to agent first
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
# Get initial passage count
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
assert initial_passage_count == 0
# Create a job for loading the first file
job = server.job_manager.create_job(
PydanticJob(
@@ -1001,7 +1009,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert job.metadata_["num_documents"] == 1
# Verify passages were added
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert first_file_passage_count > initial_passage_count
# Create a second test file with different content
@@ -1032,14 +1040,13 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert job2.metadata_["num_documents"] == 1
# Verify passages were appended (not replaced)
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert final_passage_count > first_file_passage_count
# Verify both old and new content is searchable
passages = server.passage_manager.list_passages(
actor=actor,
passages = server.agent_manager.list_passages(
agent_id=agent_id,
source_id=source.id,
actor=actor,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
@@ -1048,35 +1055,27 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert any("chicken" in passage.text.lower() for passage in passages)
assert any("Anna".lower() in passage.text.lower() for passage in passages)
# TODO: Add this test back in after separation of `Passage tables` (LET-449)
# # Load second agent
# agent2 = server.load_agent(agent_id=other_agent_id)
# Initially should have no passages
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
assert initial_agent2_passages == 0
# # Initially should have no passages
# initial_agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
# assert initial_agent2_passages == 0
# Attach source to second agent
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
# # Attach source to second agent
# agent2.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
# Verify second agent has same number of passages as first agent
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
assert agent2_passages == agent1_passages
# # Verify second agent has same number of passages as first agent
# agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
# agent1_passages = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
# assert agent2_passages == agent1_passages
# # Verify second agent can query the same content
# passages2 = server.passage_manager.list_passages(
# actor=user,
# agent_id=other_agent_id,
# source_id=source.id,
# query_text="what does Timber like to eat",
# embedding_config=EmbeddingConfig.default_config(provider="openai"),
# embed_query=True,
# limit=10,
# )
# assert len(passages2) == len(passages)
# assert any("chicken" in passage.text.lower() for passage in passages2)
# assert any("sleep" in passage.text.lower() for passage in passages2)
# # Cleanup
# server.delete_agent(user_id=user_id, agent_id=agent2_state.id)
# Verify second agent can query the same content
passages2 = server.agent_manager.list_passages(
actor=actor,
agent_id=other_agent_id,
source_id=source.id,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
)
assert len(passages2) == len(passages)
assert any("chicken" in passage.text.lower() for passage in passages2)
assert any("Anna".lower() in passage.text.lower() for passage in passages2)