feat: separate Passages tables (#2245)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
105
alembic/versions/54dec07619c4_divide_passage_table_into_.py
Normal file
105
alembic/versions/54dec07619c4_divide_passage_table_into_.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""divide passage table into SourcePassages and AgentPassages
|
||||
|
||||
Revision ID: 54dec07619c4
|
||||
Revises: 4e88e702f85e
|
||||
Create Date: 2024-12-14 17:23:08.772554
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
from pgvector.sqlalchemy import Vector
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from letta.orm.custom_columns import EmbeddingConfigColumn
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '54dec07619c4'
|
||||
down_revision: Union[str, None] = '4e88e702f85e'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
'agent_passages',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('text', sa.String(), nullable=False),
|
||||
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
|
||||
sa.Column('metadata_', sa.JSON(), nullable=False),
|
||||
sa.Column('embedding', Vector(dim=4096), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||
sa.Column('organization_id', sa.String(), nullable=False),
|
||||
sa.Column('agent_id', sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False)
|
||||
op.create_table(
|
||||
'source_passages',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('text', sa.String(), nullable=False),
|
||||
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
|
||||
sa.Column('metadata_', sa.JSON(), nullable=False),
|
||||
sa.Column('embedding', Vector(dim=4096), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||
sa.Column('organization_id', sa.String(), nullable=False),
|
||||
sa.Column('file_id', sa.String(), nullable=True),
|
||||
sa.Column('source_id', sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||
sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False)
|
||||
op.drop_table('passages')
|
||||
op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey')
|
||||
op.create_foreign_key(None, 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, 'messages', type_='foreignkey')
|
||||
op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id'])
|
||||
op.drop_constraint(None, 'files', type_='foreignkey')
|
||||
op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id'])
|
||||
op.create_table(
|
||||
'passages',
|
||||
sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True),
|
||||
sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
|
||||
sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True),
|
||||
sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
|
||||
sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'),
|
||||
sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'),
|
||||
sa.PrimaryKeyConstraint('id', name='passages_pkey')
|
||||
)
|
||||
op.drop_index('source_passages_org_idx', table_name='source_passages')
|
||||
op.drop_table('source_passages')
|
||||
op.drop_index('agent_passages_org_idx', table_name='agent_passages')
|
||||
op.drop_table('agent_passages')
|
||||
# ### end Alembic commands ###
|
||||
@@ -41,7 +41,6 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
Message as ChatCompletionMessage,
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
@@ -82,7 +81,7 @@ def compile_memory_metadata_block(
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
memory_edit_timestamp: datetime.datetime,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
agent_manager: Optional[AgentManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
) -> str:
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
@@ -93,7 +92,7 @@ def compile_memory_metadata_block(
|
||||
[
|
||||
f"### Memory [last modified: {timestamp_str}]",
|
||||
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
f"{agent_manager.passage_size(actor=actor, agent_id=agent_id) if agent_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
]
|
||||
)
|
||||
@@ -106,7 +105,7 @@ def compile_system_message(
|
||||
in_context_memory: Memory,
|
||||
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
||||
actor: PydanticUser,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
agent_manager: Optional[AgentManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
@@ -135,7 +134,7 @@ def compile_system_message(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
passage_manager=passage_manager,
|
||||
agent_manager=agent_manager,
|
||||
message_manager=message_manager,
|
||||
)
|
||||
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
||||
@@ -172,7 +171,7 @@ def initialize_message_sequence(
|
||||
agent_id: str,
|
||||
memory: Memory,
|
||||
actor: PydanticUser,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
agent_manager: Optional[AgentManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
||||
include_initial_boot_message: bool = True,
|
||||
@@ -181,7 +180,7 @@ def initialize_message_sequence(
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
# full_system_message = construct_system_with_memory(
|
||||
# system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory
|
||||
# system, memory, memory_edit_timestamp, agent_manager=agent_manager, recall_memory=recall_memory
|
||||
# )
|
||||
full_system_message = compile_system_message(
|
||||
agent_id=agent_id,
|
||||
@@ -189,7 +188,7 @@ def initialize_message_sequence(
|
||||
in_context_memory=memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=actor,
|
||||
passage_manager=passage_manager,
|
||||
agent_manager=agent_manager,
|
||||
message_manager=message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
@@ -291,8 +290,9 @@ class Agent(BaseAgent):
|
||||
self.interface = interface
|
||||
|
||||
# Create the persistence manager object based on the AgentState info
|
||||
self.passage_manager = PassageManager()
|
||||
self.message_manager = MessageManager()
|
||||
self.passage_manager = PassageManager()
|
||||
self.agent_manager = AgentManager()
|
||||
|
||||
# State needed for heartbeat pausing
|
||||
self.pause_heartbeats_start = None
|
||||
@@ -322,7 +322,7 @@ class Agent(BaseAgent):
|
||||
agent_id=self.agent_state.id,
|
||||
memory=self.agent_state.memory,
|
||||
actor=self.user,
|
||||
passage_manager=None,
|
||||
agent_manager=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
@@ -347,7 +347,7 @@ class Agent(BaseAgent):
|
||||
memory=self.agent_state.memory,
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
passage_manager=None,
|
||||
agent_manager=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
@@ -1297,7 +1297,7 @@ class Agent(BaseAgent):
|
||||
in_context_memory=self.agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=self.user,
|
||||
passage_manager=self.passage_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
message_manager=self.message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
@@ -1368,33 +1368,24 @@ class Agent(BaseAgent):
|
||||
source_id: str,
|
||||
source_manager: SourceManager,
|
||||
agent_manager: AgentManager,
|
||||
page_size: Optional[int] = None,
|
||||
):
|
||||
"""Attach data with name `source_name` to the agent from source_connector."""
|
||||
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
|
||||
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
|
||||
|
||||
for passage in passages:
|
||||
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
|
||||
passage.agent_id = self.agent_state.id
|
||||
self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)
|
||||
|
||||
agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
|
||||
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
|
||||
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
|
||||
assert len(agents_passages) == passage_size # sanity check
|
||||
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
|
||||
|
||||
# attach to agent
|
||||
"""Attach a source to the agent using the SourcesAgents ORM relationship.
|
||||
|
||||
Args:
|
||||
user: User performing the action
|
||||
source_id: ID of the source to attach
|
||||
source_manager: SourceManager instance to verify source exists
|
||||
agent_manager: AgentManager instance to manage agent-source relationship
|
||||
"""
|
||||
# Verify source exists and user has permission to access it
|
||||
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
|
||||
assert source is not None, f"Source {source_id} not found in metadata store"
|
||||
assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})"
|
||||
|
||||
# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
|
||||
# TODO: delete @matt and remove
|
||||
# Use the agent_manager to create the relationship
|
||||
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)
|
||||
|
||||
printd(
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}.",
|
||||
)
|
||||
|
||||
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
|
||||
@@ -1550,13 +1541,13 @@ class Agent(BaseAgent):
|
||||
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
|
||||
)
|
||||
|
||||
passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id)
|
||||
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
external_memory_summary = compile_memory_metadata_block(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
memory_edit_timestamp=get_utc_time(), # dummy timestamp
|
||||
passage_manager=self.passage_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
message_manager=self.message_manager,
|
||||
)
|
||||
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
||||
@@ -1582,7 +1573,7 @@ class Agent(BaseAgent):
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(self._messages),
|
||||
num_archival_memory=passage_manager_size,
|
||||
num_archival_memory=agent_manager_passage_size,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
# top-level information
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.constants import MAX_PAUSE_HEARTBEATS
|
||||
from letta.services.agent_manager import AgentManager
|
||||
|
||||
# import math
|
||||
# from letta.utils import json_dumps
|
||||
@@ -200,8 +201,9 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
|
||||
|
||||
try:
|
||||
# Get results using passage manager
|
||||
all_results = self.passage_manager.list_passages(
|
||||
all_results = self.agent_manager.list_passages(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
query_text=query,
|
||||
limit=count + start, # Request enough results to handle offset
|
||||
embedding_config=self.agent_state.embedding_config,
|
||||
|
||||
@@ -312,11 +312,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
||||
for param in sig.parameters.values():
|
||||
# Exclude 'self' parameter
|
||||
# TODO: eventually remove this (only applies to BASE_TOOLS)
|
||||
if param.name == "self":
|
||||
continue
|
||||
|
||||
# exclude 'agent_state' parameter
|
||||
if param.name == "agent_state":
|
||||
if param.name in ["self", "agent_state"]: # Add agent_manager to excluded
|
||||
continue
|
||||
|
||||
# Assert that the parameter has a type annotation
|
||||
|
||||
@@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import Passage
|
||||
from letta.orm.passage import BasePassage, AgentPassage, SourcePassage
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
|
||||
@@ -82,7 +82,25 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
lazy="selectin",
|
||||
doc="Tags associated with the agent.",
|
||||
)
|
||||
# passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="agent", lazy="selectin")
|
||||
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||
"SourcePassage",
|
||||
secondary="sources_agents", # The join table for Agent -> Source
|
||||
primaryjoin="Agent.id == sources_agents.c.agent_id",
|
||||
secondaryjoin="and_(SourcePassage.source_id == sources_agents.c.source_id)",
|
||||
lazy="selectin",
|
||||
order_by="SourcePassage.created_at.desc()",
|
||||
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
|
||||
doc="All passages derived from sources associated with this agent.",
|
||||
)
|
||||
agent_passages: Mapped[List["AgentPassage"]] = relationship(
|
||||
"AgentPassage",
|
||||
back_populates="agent",
|
||||
lazy="selectin",
|
||||
order_by="AgentPassage.created_at.desc()",
|
||||
cascade="all, delete-orphan",
|
||||
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
|
||||
doc="All passages derived created by this agent.",
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> PydanticAgentState:
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
|
||||
@@ -9,7 +9,8 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.passage import SourcePassage
|
||||
|
||||
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
||||
"""Represents metadata for an uploaded file."""
|
||||
@@ -27,4 +28,4 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
||||
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
|
||||
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
|
||||
source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
|
||||
|
||||
@@ -31,30 +31,19 @@ class UserMixin(Base):
|
||||
|
||||
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
||||
|
||||
class FileMixin(Base):
|
||||
"""Mixin for models that belong to a file."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
|
||||
|
||||
class AgentMixin(Base):
|
||||
"""Mixin for models that belong to an agent."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"))
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
|
||||
|
||||
class FileMixin(Base):
|
||||
"""Mixin for models that belong to a file."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
file_id: Mapped[Optional[str]] = mapped_column(
|
||||
String,
|
||||
ForeignKey("files.id", ondelete="CASCADE"),
|
||||
nullable=True
|
||||
)
|
||||
file_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("files.id", ondelete="CASCADE"))
|
||||
|
||||
|
||||
class SourceMixin(Base):
|
||||
@@ -62,7 +51,7 @@ class SourceMixin(Base):
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"))
|
||||
source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
|
||||
class SandboxConfigMixin(Base):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -35,6 +35,22 @@ class Organization(SqlalchemyBase):
|
||||
)
|
||||
|
||||
# relationships
|
||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
|
||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
||||
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||
"SourcePassage",
|
||||
back_populates="organization",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
agent_passages: Mapped[List["AgentPassage"]] = relationship(
|
||||
"AgentPassage",
|
||||
back_populates="organization",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||
"""Convenience property to get all passages"""
|
||||
return self.source_passages + self.agent_passages
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +1,35 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from sqlalchemy import Column, JSON, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from letta.orm.mixins import FileMixin, OrganizationMixin
|
||||
from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.settings import settings
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.orm.custom_columns import CommonVector
|
||||
from letta.orm.mixins import FileMixin, OrganizationMixin
|
||||
from letta.orm.source import EmbeddingConfigColumn
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.settings import settings
|
||||
|
||||
config = LettaConfig()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.agent import Agent
|
||||
|
||||
|
||||
# TODO: After migration to Passage, will need to manually delete passages where files
|
||||
# are deleted on web
|
||||
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
||||
"""Defines data model for storing Passages"""
|
||||
|
||||
__tablename__ = "passages"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
class BasePassage(SqlalchemyBase, OrganizationMixin):
|
||||
"""Base class for all passage types with common fields"""
|
||||
__abstract__ = True
|
||||
__pydantic_model__ = PydanticPassage
|
||||
|
||||
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
|
||||
text: Mapped[str] = mapped_column(doc="Passage text content")
|
||||
source_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Source identifier")
|
||||
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
|
||||
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
|
||||
# Vector embedding field based on database type
|
||||
if settings.letta_pg_uri_no_default:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
@@ -41,9 +37,49 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
||||
else:
|
||||
embedding = Column(CommonVector)
|
||||
|
||||
# Foreign keys
|
||||
agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id"), nullable=True)
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
"""Relationship to organization"""
|
||||
return relationship("Organization", back_populates="passages", lazy="selectin")
|
||||
|
||||
# Relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="passages", lazy="selectin")
|
||||
file: Mapped["FileMetadata"] = relationship("FileMetadata", back_populates="passages", lazy="selectin")
|
||||
@declared_attr
|
||||
def __table_args__(cls):
|
||||
if settings.letta_pg_uri_no_default:
|
||||
return (
|
||||
Index(f'{cls.__tablename__}_org_idx', 'organization_id'),
|
||||
{"extend_existing": True}
|
||||
)
|
||||
return ({"extend_existing": True},)
|
||||
|
||||
|
||||
class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
"""Passages derived from external files/sources"""
|
||||
__tablename__ = "source_passages"
|
||||
|
||||
@declared_attr
|
||||
def file(cls) -> Mapped["FileMetadata"]:
|
||||
"""Relationship to file"""
|
||||
return relationship("FileMetadata", back_populates="source_passages", lazy="selectin")
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
return relationship("Organization", back_populates="source_passages", lazy="selectin")
|
||||
|
||||
@declared_attr
|
||||
def source(cls) -> Mapped["Source"]:
|
||||
"""Relationship to source"""
|
||||
return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True)
|
||||
|
||||
|
||||
class AgentPassage(BasePassage, AgentMixin):
|
||||
"""Passages created by agents as archival memories"""
|
||||
__tablename__ = "agent_passages"
|
||||
|
||||
@declared_attr
|
||||
def organization(cls) -> Mapped["Organization"]:
|
||||
return relationship("Organization", back_populates="agent_passages", lazy="selectin")
|
||||
|
||||
@declared_attr
|
||||
def agent(cls) -> Mapped["Agent"]:
|
||||
"""Relationship to agent"""
|
||||
return relationship("Agent", back_populates="agent_passages", lazy="selectin", passive_deletes=True)
|
||||
|
||||
@@ -12,6 +12,9 @@ from letta.schemas.source import Source as PydanticSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.passage import SourcePassage
|
||||
from letta.orm.agent import Agent
|
||||
|
||||
|
||||
class Source(SqlalchemyBase, OrganizationMixin):
|
||||
@@ -28,4 +31,5 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
|
||||
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
|
||||
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
|
||||
|
||||
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from sqlalchemy import String, desc, func, or_, select
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
from sqlalchemy.exc import DBAPIError, IntegrityError
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
@@ -242,7 +242,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
session.commit()
|
||||
session.refresh(self)
|
||||
return self
|
||||
except DBAPIError as e:
|
||||
except (DBAPIError, IntegrityError) as e:
|
||||
self._handle_dbapi_error(e)
|
||||
|
||||
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
|
||||
@@ -10,7 +10,7 @@ from letta.utils import get_utc_time
|
||||
|
||||
|
||||
class PassageBase(OrmMetadataBase):
|
||||
__id_prefix__ = "passage_legacy"
|
||||
__id_prefix__ = "passage"
|
||||
|
||||
is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")
|
||||
|
||||
|
||||
@@ -932,7 +932,7 @@ class SyncServer(Server):
|
||||
|
||||
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user))
|
||||
return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
|
||||
|
||||
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
@@ -949,18 +949,9 @@ class SyncServer(Server):
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
passages = self.agent_manager.list_passages(agent_id=agent_id, actor=actor)
|
||||
|
||||
# iterate over records
|
||||
records = letta_agent.passage_manager.list_passages(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return records
|
||||
return passages
|
||||
|
||||
def get_agent_archival_cursor(
|
||||
self,
|
||||
@@ -974,15 +965,13 @@ class SyncServer(Server):
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# iterate over records
|
||||
records = letta_agent.passage_manager.list_passages(
|
||||
actor=self.default_user,
|
||||
records = self.agent_manager.list_passages(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
ascending=not reverse,
|
||||
)
|
||||
return records
|
||||
|
||||
@@ -1098,7 +1087,8 @@ class SyncServer(Server):
|
||||
self.source_manager.delete_source(source_id=source_id, actor=actor)
|
||||
|
||||
# delete data from passage store
|
||||
self.passage_manager.delete_passages(actor=actor, limit=None, source_id=source_id)
|
||||
passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
|
||||
self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)
|
||||
|
||||
# TODO: delete data from agent passage stores (?)
|
||||
|
||||
@@ -1129,9 +1119,11 @@ class SyncServer(Server):
|
||||
for agent_state in agent_states:
|
||||
agent_id = agent_state.id
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
curr_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
|
||||
|
||||
# Attach source to agent
|
||||
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
||||
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
|
||||
new_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
|
||||
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
||||
assert new_passage_size >= curr_passage_size # in case empty files are added
|
||||
|
||||
return job
|
||||
@@ -1195,14 +1187,9 @@ class SyncServer(Server):
|
||||
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
elif source_name:
|
||||
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
||||
source_id = source.id
|
||||
else:
|
||||
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
|
||||
source_id = source.id
|
||||
|
||||
# TODO: This should be done with the ORM?
|
||||
# delete all Passage objects with source_id==source_id from agent's archival memory
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id)
|
||||
|
||||
# delete agent-source mapping
|
||||
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
@@ -1224,7 +1211,7 @@ class SyncServer(Server):
|
||||
for source in sources:
|
||||
|
||||
# count number of passages
|
||||
num_passages = self.passage_manager.size(actor=actor, source_id=source.id)
|
||||
num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id)
|
||||
|
||||
# TODO: add when files table implemented
|
||||
## count number of files
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from sqlalchemy import select, union_all, literal, func, Select
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import Block as BlockModel
|
||||
from letta.orm import Source as SourceModel
|
||||
from letta.orm import Tool as ToolModel
|
||||
from letta.orm import AgentPassage, SourcePassage
|
||||
from letta.orm import SourcesAgents
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -21,9 +30,9 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
_process_tags,
|
||||
derive_system_message,
|
||||
)
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -229,13 +238,6 @@ class AgentManager:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# TODO: @mindy delete this piece when we have a proper passages/sources implementation
|
||||
# TODO: This is done very hacky on purpose
|
||||
# TODO: 1000 limit is also wack
|
||||
passage_manager = PassageManager()
|
||||
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
|
||||
|
||||
agent_state = agent.to_pydantic()
|
||||
agent.hard_delete(session)
|
||||
return agent_state
|
||||
@@ -407,6 +409,262 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Passage Management
|
||||
# ======================================================================================================================
|
||||
def _build_passage_query(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
cursor: Optional[str] = None,
|
||||
source_id: Optional[str] = None,
|
||||
embed_query: bool = False,
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
agent_only: bool = False,
|
||||
) -> Select:
|
||||
"""Helper function to build the base passage query with all filters applied.
|
||||
|
||||
Returns the query before any limit or count operations are applied.
|
||||
"""
|
||||
embedded_text = None
|
||||
if embed_query:
|
||||
assert embedding_config is not None, "embedding_config must be specified for vector search"
|
||||
assert query_text is not None, "query_text must be specified for vector search"
|
||||
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
|
||||
embedded_text = np.array(embedded_text)
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Start with base query for source passages
|
||||
source_passages = None
|
||||
if not agent_only: # Include source passages
|
||||
if agent_id is not None:
|
||||
source_passages = (
|
||||
select(
|
||||
SourcePassage,
|
||||
literal(None).label('agent_id')
|
||||
)
|
||||
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
|
||||
.where(SourcesAgents.agent_id == agent_id)
|
||||
.where(SourcePassage.organization_id == actor.organization_id)
|
||||
)
|
||||
else:
|
||||
source_passages = (
|
||||
select(
|
||||
SourcePassage,
|
||||
literal(None).label('agent_id')
|
||||
)
|
||||
.where(SourcePassage.organization_id == actor.organization_id)
|
||||
)
|
||||
|
||||
if source_id:
|
||||
source_passages = source_passages.where(SourcePassage.source_id == source_id)
|
||||
if file_id:
|
||||
source_passages = source_passages.where(SourcePassage.file_id == file_id)
|
||||
|
||||
# Add agent passages query
|
||||
agent_passages = None
|
||||
if agent_id is not None:
|
||||
agent_passages = (
|
||||
select(
|
||||
AgentPassage.id,
|
||||
AgentPassage.text,
|
||||
AgentPassage.embedding_config,
|
||||
AgentPassage.metadata_,
|
||||
AgentPassage.embedding,
|
||||
AgentPassage.created_at,
|
||||
AgentPassage.updated_at,
|
||||
AgentPassage.is_deleted,
|
||||
AgentPassage._created_by_id,
|
||||
AgentPassage._last_updated_by_id,
|
||||
AgentPassage.organization_id,
|
||||
literal(None).label('file_id'),
|
||||
literal(None).label('source_id'),
|
||||
AgentPassage.agent_id
|
||||
)
|
||||
.where(AgentPassage.agent_id == agent_id)
|
||||
.where(AgentPassage.organization_id == actor.organization_id)
|
||||
)
|
||||
|
||||
# Combine queries
|
||||
if source_passages is not None and agent_passages is not None:
|
||||
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
|
||||
elif agent_passages is not None:
|
||||
combined_query = agent_passages.cte('combined_passages')
|
||||
elif source_passages is not None:
|
||||
combined_query = source_passages.cte('combined_passages')
|
||||
else:
|
||||
raise ValueError("No passages found")
|
||||
|
||||
# Build main query from combined CTE
|
||||
main_query = select(combined_query)
|
||||
|
||||
# Apply filters
|
||||
if start_date:
|
||||
main_query = main_query.where(combined_query.c.created_at >= start_date)
|
||||
if end_date:
|
||||
main_query = main_query.where(combined_query.c.created_at <= end_date)
|
||||
if source_id:
|
||||
main_query = main_query.where(combined_query.c.source_id == source_id)
|
||||
if file_id:
|
||||
main_query = main_query.where(combined_query.c.file_id == file_id)
|
||||
|
||||
# Vector search
|
||||
if embedded_text:
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# PostgreSQL with pgvector
|
||||
main_query = main_query.order_by(
|
||||
combined_query.c.embedding.cosine_distance(embedded_text).asc()
|
||||
)
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
query_embedding_binary = adapt_array(embedded_text)
|
||||
if ascending:
|
||||
main_query = main_query.order_by(
|
||||
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
||||
combined_query.c.created_at.asc(),
|
||||
combined_query.c.id.asc()
|
||||
)
|
||||
else:
|
||||
main_query = main_query.order_by(
|
||||
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
||||
combined_query.c.created_at.desc(),
|
||||
combined_query.c.id.asc()
|
||||
)
|
||||
else:
|
||||
if query_text:
|
||||
main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))
|
||||
|
||||
# Handle cursor-based pagination
|
||||
if cursor:
|
||||
cursor_query = select(combined_query.c.created_at).where(
|
||||
combined_query.c.id == cursor
|
||||
).scalar_subquery()
|
||||
|
||||
if ascending:
|
||||
main_query = main_query.where(
|
||||
combined_query.c.created_at > cursor_query
|
||||
)
|
||||
else:
|
||||
main_query = main_query.where(
|
||||
combined_query.c.created_at < cursor_query
|
||||
)
|
||||
|
||||
# Add ordering if not already ordered by similarity
|
||||
if not embed_query:
|
||||
if ascending:
|
||||
main_query = main_query.order_by(
|
||||
combined_query.c.created_at.asc(),
|
||||
combined_query.c.id.asc(),
|
||||
)
|
||||
else:
|
||||
main_query = main_query.order_by(
|
||||
combined_query.c.created_at.desc(),
|
||||
combined_query.c.id.asc(),
|
||||
)
|
||||
|
||||
return main_query
|
||||
|
||||
@enforce_types
|
||||
def list_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
cursor: Optional[str] = None,
|
||||
source_id: Optional[str] = None,
|
||||
embed_query: bool = False,
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
agent_only: bool = False
|
||||
) -> List[PydanticPassage]:
|
||||
"""Lists all passages attached to an agent."""
|
||||
with self.session_maker() as session:
|
||||
main_query = self._build_passage_query(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
query_text=query_text,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
cursor=cursor,
|
||||
source_id=source_id,
|
||||
embed_query=embed_query,
|
||||
ascending=ascending,
|
||||
embedding_config=embedding_config,
|
||||
agent_only=agent_only,
|
||||
)
|
||||
|
||||
# Add limit
|
||||
if limit:
|
||||
main_query = main_query.limit(limit)
|
||||
|
||||
# Execute query
|
||||
results = list(session.execute(main_query))
|
||||
|
||||
passages = []
|
||||
for row in results:
|
||||
data = dict(row._mapping)
|
||||
if data['agent_id'] is not None:
|
||||
# This is an AgentPassage - remove source fields
|
||||
data.pop('source_id', None)
|
||||
data.pop('file_id', None)
|
||||
passage = AgentPassage(**data)
|
||||
else:
|
||||
# This is a SourcePassage - remove agent field
|
||||
data.pop('agent_id', None)
|
||||
passage = SourcePassage(**data)
|
||||
passages.append(passage)
|
||||
|
||||
return [p.to_pydantic() for p in passages]
|
||||
|
||||
|
||||
@enforce_types
|
||||
def passage_size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
cursor: Optional[str] = None,
|
||||
source_id: Optional[str] = None,
|
||||
embed_query: bool = False,
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
agent_only: bool = False
|
||||
) -> int:
|
||||
"""Returns the count of passages matching the given criteria."""
|
||||
with self.session_maker() as session:
|
||||
main_query = self._build_passage_query(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
query_text=query_text,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
cursor=cursor,
|
||||
source_id=source_id,
|
||||
embed_query=embed_query,
|
||||
ascending=ascending,
|
||||
embedding_config=embedding_config,
|
||||
agent_only=agent_only,
|
||||
)
|
||||
|
||||
# Convert to count query
|
||||
count_query = select(func.count()).select_from(main_query.subquery())
|
||||
return session.scalar(count_query) or 0
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tool Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from sqlalchemy import select, union_all, literal
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import embedding_model, parse_and_chunk_text
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.passage import Passage as PassageModel
|
||||
from letta.orm.passage import AgentPassage, SourcePassage
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
@@ -14,6 +15,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
|
||||
class PassageManager:
|
||||
"""Manager class to handle business logic related to Passages."""
|
||||
|
||||
@@ -26,14 +28,51 @@ class PassageManager:
|
||||
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||
"""Fetch a passage by ID."""
|
||||
with self.session_maker() as session:
|
||||
passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
# Try source passages first
|
||||
try:
|
||||
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
except NoResultFound:
|
||||
# Try archival passages
|
||||
try:
|
||||
passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
|
||||
|
||||
@enforce_types
|
||||
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new passage."""
|
||||
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
||||
# Common fields for both passage types
|
||||
data = pydantic_passage.model_dump()
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata_", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.utcnow()),
|
||||
}
|
||||
|
||||
if "agent_id" in data and data["agent_id"]:
|
||||
assert not data.get("source_id"), "Passage cannot have both agent_id and source_id"
|
||||
agent_fields = {
|
||||
"agent_id": data["agent_id"],
|
||||
}
|
||||
passage = AgentPassage(**common_fields, **agent_fields)
|
||||
elif "source_id" in data and data["source_id"]:
|
||||
assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id"
|
||||
source_fields = {
|
||||
"source_id": data["source_id"],
|
||||
"file_id": data.get("file_id"),
|
||||
}
|
||||
passage = SourcePassage(**common_fields, **source_fields)
|
||||
else:
|
||||
raise ValueError("Passage must have either agent_id or source_id")
|
||||
|
||||
with self.session_maker() as session:
|
||||
passage = PassageModel(**pydantic_passage.model_dump())
|
||||
passage.create(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
|
||||
@@ -93,14 +132,23 @@ class PassageManager:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Fetch existing message from database
|
||||
curr_passage = PassageModel.read(
|
||||
db_session=session,
|
||||
identifier=passage_id,
|
||||
actor=actor,
|
||||
)
|
||||
if not curr_passage:
|
||||
raise ValueError(f"Passage with id {passage_id} does not exist.")
|
||||
# Try source passages first
|
||||
try:
|
||||
curr_passage = SourcePassage.read(
|
||||
db_session=session,
|
||||
identifier=passage_id,
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
# Try agent passages
|
||||
try:
|
||||
curr_passage = AgentPassage.read(
|
||||
db_session=session,
|
||||
identifier=passage_id,
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Passage with id {passage_id} does not exist.")
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
update_data = passage.model_dump(exclude_unset=True, exclude_none=True)
|
||||
@@ -113,104 +161,32 @@ class PassageManager:
|
||||
|
||||
@enforce_types
|
||||
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete a passage."""
|
||||
"""Delete a passage from either source or archival passages."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Try source passages first
|
||||
try:
|
||||
passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
passage.hard_delete(session, actor=actor)
|
||||
return True
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Passage with id {passage_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def list_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
ascending: bool = True,
|
||||
source_id: Optional[str] = None,
|
||||
embed_query: bool = False,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
"""List passages with pagination."""
|
||||
with self.session_maker() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if agent_id:
|
||||
filters["agent_id"] = agent_id
|
||||
if file_id:
|
||||
filters["file_id"] = file_id
|
||||
if source_id:
|
||||
filters["source_id"] = source_id
|
||||
|
||||
embedded_text = None
|
||||
if embed_query:
|
||||
assert embedding_config is not None
|
||||
|
||||
# Embed the text
|
||||
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
|
||||
|
||||
# Pad the embedding with zeros
|
||||
embedded_text = np.array(embedded_text)
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
results = PassageModel.list(
|
||||
db_session=session,
|
||||
cursor=cursor,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
ascending=ascending,
|
||||
query_text=query_text if not embedded_text else None,
|
||||
query_embedding=embedded_text,
|
||||
**filters,
|
||||
)
|
||||
return [p.to_pydantic() for p in results]
|
||||
|
||||
@enforce_types
|
||||
def size(self, actor: PydanticUser, agent_id: Optional[str] = None, **kwargs) -> int:
|
||||
"""Get the total count of messages with optional filters.
|
||||
|
||||
Args:
|
||||
actor : The user requesting the count
|
||||
agent_id: The agent ID
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs)
|
||||
# Try archival passages
|
||||
try:
|
||||
passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
passage.hard_delete(session, actor=actor)
|
||||
return True
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Passage with id {passage_id} not found.")
|
||||
|
||||
def delete_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
cursor: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
source_id: Optional[str] = None,
|
||||
passages: List[PydanticPassage],
|
||||
) -> bool:
|
||||
|
||||
passages = self.list_passages(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
query_text=query_text,
|
||||
source_id=source_id,
|
||||
)
|
||||
|
||||
# TODO: This is very inefficient
|
||||
# TODO: We should have a base `delete_all_matching_filters`-esque function
|
||||
for passage in passages:
|
||||
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
||||
return True
|
||||
|
||||
@@ -482,7 +482,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
# check agent archival memory size
|
||||
archival_memories = client.get_archival_memory(agent_id=agent.id)
|
||||
print(archival_memories)
|
||||
assert len(archival_memories) == 0
|
||||
|
||||
# load a file into a source (non-blocking job)
|
||||
|
||||
@@ -2,6 +2,8 @@ import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from httpx._transports import default
|
||||
from numpy import source
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -17,7 +19,8 @@ from letta.orm import (
|
||||
Job,
|
||||
Message,
|
||||
Organization,
|
||||
Passage,
|
||||
AgentPassage,
|
||||
SourcePassage,
|
||||
SandboxConfig,
|
||||
SandboxEnvironmentVariable,
|
||||
Source,
|
||||
@@ -82,7 +85,8 @@ def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Message))
|
||||
session.execute(delete(Passage))
|
||||
session.execute(delete(AgentPassage))
|
||||
session.execute(delete(SourcePassage))
|
||||
session.execute(delete(Job))
|
||||
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||
session.execute(delete(BlocksAgents))
|
||||
@@ -189,39 +193,79 @@ def print_tool(server: SyncServer, default_user, default_organization):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hello_world_passage_fixture(server: SyncServer, default_user, default_file, sarah_agent):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
# Set up passage
|
||||
dummy_embedding = [0.0] * 2
|
||||
message = PydanticPassage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
text="Hello, world!",
|
||||
embedding=dummy_embedding,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
"""Fixture to create an agent passage."""
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text="Hello, I am an agent passage",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
|
||||
msg = server.passage_manager.create_passage(message, actor=default_user)
|
||||
yield msg
|
||||
yield passage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent) -> list[PydanticPassage]:
|
||||
"""Helper function to create test passages for all tests"""
|
||||
dummy_embedding = [0] * 2
|
||||
passages = [
|
||||
def source_passage_fixture(server: SyncServer, default_user, default_file, default_source):
|
||||
"""Fixture to create a source passage."""
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
text="Hello, I am a source passage",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
text=f"Test passage {i}",
|
||||
embedding=dummy_embedding,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
yield passage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent, default_source):
|
||||
"""Helper function to create test passages for all tests."""
|
||||
# Create agent passages
|
||||
passages = []
|
||||
for i in range(5):
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text=f"Agent passage {i}",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
server.passage_manager.create_many_passages(passages, actor=default_user)
|
||||
passages.append(passage)
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
|
||||
# Create source passages
|
||||
for i in range(5):
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text=f"Source passage {i}",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
passages.append(passage)
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
|
||||
return passages
|
||||
|
||||
|
||||
@@ -389,6 +433,49 @@ def server():
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
"""Setup fixture for agent passages tests"""
|
||||
agent_id = sarah_agent.id
|
||||
actor = default_user
|
||||
|
||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=default_source.id, actor=actor)
|
||||
|
||||
# Create some source passages
|
||||
source_passages = []
|
||||
for i in range(3):
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
source_id=default_source.id,
|
||||
text=f"Source passage {i}",
|
||||
embedding=[0.1], # Default OpenAI embedding size
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=actor
|
||||
)
|
||||
source_passages.append(passage)
|
||||
|
||||
# Create some agent passages
|
||||
agent_passages = []
|
||||
for i in range(2):
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
text=f"Agent passage {i}",
|
||||
embedding=[0.1], # Default OpenAI embedding size
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=actor
|
||||
)
|
||||
agent_passages.append(passage)
|
||||
|
||||
yield agent_passages, source_passages
|
||||
|
||||
# Cleanup
|
||||
server.source_manager.delete_source(default_source.id, actor=actor)
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -849,6 +936,199 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
|
||||
assert block.label == default_block.label
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Agent Manager - Passages Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test basic listing functionality of agent passages"""
|
||||
|
||||
all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
|
||||
assert len(all_passages) == 5 # 3 source + 2 agent passages
|
||||
|
||||
|
||||
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test ordering of agent passages"""
|
||||
|
||||
# Test ascending order
|
||||
asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
|
||||
assert len(asc_passages) == 5
|
||||
for i in range(1, len(asc_passages)):
|
||||
assert asc_passages[i-1].created_at <= asc_passages[i].created_at
|
||||
|
||||
# Test descending order
|
||||
desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
|
||||
assert len(desc_passages) == 5
|
||||
for i in range(1, len(desc_passages)):
|
||||
assert desc_passages[i-1].created_at >= desc_passages[i].created_at
|
||||
|
||||
|
||||
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test pagination of agent passages"""
|
||||
|
||||
# Test limit
|
||||
limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
|
||||
assert len(limited_passages) == 3
|
||||
|
||||
# Test cursor-based pagination
|
||||
first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
|
||||
assert len(first_page) == 2
|
||||
|
||||
second_page = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
cursor=first_page[-1].id,
|
||||
limit=2,
|
||||
ascending=True
|
||||
)
|
||||
assert len(second_page) == 2
|
||||
assert first_page[-1].id != second_page[0].id
|
||||
assert first_page[-1].created_at <= second_page[0].created_at
|
||||
|
||||
|
||||
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test text search functionality of agent passages"""
|
||||
|
||||
# Test text search for source passages
|
||||
source_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="Source passage"
|
||||
)
|
||||
assert len(source_text_passages) == 3
|
||||
|
||||
# Test text search for agent passages
|
||||
agent_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="Agent passage"
|
||||
)
|
||||
assert len(agent_text_passages) == 2
|
||||
|
||||
|
||||
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test text search functionality of agent passages"""
|
||||
|
||||
# Test text search for agent passages
|
||||
agent_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
agent_only=True
|
||||
)
|
||||
assert len(agent_text_passages) == 2
|
||||
|
||||
|
||||
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
|
||||
"""Test filtering functionality of agent passages"""
|
||||
|
||||
# Test source filtering
|
||||
source_filtered = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
source_id=default_source.id
|
||||
)
|
||||
assert len(source_filtered) == 3
|
||||
|
||||
# Test date filtering
|
||||
now = datetime.utcnow()
|
||||
future_date = now + timedelta(days=1)
|
||||
past_date = now - timedelta(days=1)
|
||||
|
||||
date_filtered = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
start_date=past_date,
|
||||
end_date=future_date
|
||||
)
|
||||
assert len(date_filtered) == 5
|
||||
|
||||
|
||||
def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source):
|
||||
"""Test vector search functionality of agent passages"""
|
||||
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
||||
|
||||
# Create passages with known embeddings
|
||||
passages = []
|
||||
|
||||
# Create passages with different embeddings
|
||||
test_passages = [
|
||||
"I like red",
|
||||
"random text",
|
||||
"blue shoes",
|
||||
]
|
||||
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
|
||||
for i, text in enumerate(test_passages):
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
if i % 2 == 0:
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding
|
||||
)
|
||||
else:
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
source_id=default_source.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding
|
||||
)
|
||||
created_passage = server.passage_manager.create_passage(passage, default_user)
|
||||
passages.append(created_passage)
|
||||
|
||||
# Query vector similar to "red" embedding
|
||||
query_key = "What's my favorite color?"
|
||||
|
||||
# Test vector search with all passages
|
||||
results = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text=query_key,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
)
|
||||
|
||||
# Verify results are ordered by similarity
|
||||
assert len(results) == 3
|
||||
assert results[0].text == "I like red"
|
||||
assert "random" in results[1].text or "random" in results[2].text
|
||||
assert "blue" in results[1].text or "blue" in results[2].text
|
||||
|
||||
# Test vector search with agent_only=True
|
||||
agent_only_results = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text=query_key,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
agent_only=True
|
||||
)
|
||||
|
||||
# Verify agent-only results
|
||||
assert len(agent_only_results) == 2
|
||||
assert agent_only_results[0].text == "I like red"
|
||||
assert agent_only_results[1].text == "blue shoes"
|
||||
|
||||
|
||||
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
|
||||
"""Test listing passages from a source without specifying an agent."""
|
||||
|
||||
# List passages by source_id without agent_id
|
||||
source_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
source_id=default_source.id,
|
||||
)
|
||||
|
||||
# Verify we get only source passages (3 from agent_passages_setup)
|
||||
assert len(source_passages) == 3
|
||||
assert all(p.source_id == default_source.id for p in source_passages)
|
||||
assert all(p.agent_id is None for p in source_passages)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Organization Manager Tests
|
||||
# ======================================================================================================================
|
||||
@@ -900,266 +1180,86 @@ def test_list_organizations_pagination(server: SyncServer):
|
||||
# Passage Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test creating a passage using hello_world_passage_fixture fixture"""
|
||||
assert hello_world_passage_fixture.id is not None
|
||||
assert hello_world_passage_fixture.text == "Hello, world!"
|
||||
def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
|
||||
"""Test creating a passage using agent_passage_fixture fixture"""
|
||||
assert agent_passage_fixture.id is not None
|
||||
assert agent_passage_fixture.text == "Hello, I am an agent passage"
|
||||
|
||||
# Verify we can retrieve it
|
||||
retrieved = server.passage_manager.get_passage_by_id(
|
||||
hello_world_passage_fixture.id,
|
||||
agent_passage_fixture.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == hello_world_passage_fixture.id
|
||||
assert retrieved.text == hello_world_passage_fixture.text
|
||||
assert retrieved.id == agent_passage_fixture.id
|
||||
assert retrieved.text == agent_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_get_by_id(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test retrieving a passage by ID"""
|
||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
def test_passage_create_source(server: SyncServer, source_passage_fixture, default_user):
|
||||
"""Test creating a source passage."""
|
||||
assert source_passage_fixture is not None
|
||||
assert source_passage_fixture.text == "Hello, I am a source passage"
|
||||
|
||||
# Verify we can retrieve it
|
||||
retrieved = server.passage_manager.get_passage_by_id(
|
||||
source_passage_fixture.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == hello_world_passage_fixture.id
|
||||
assert retrieved.text == hello_world_passage_fixture.text
|
||||
assert retrieved.id == source_passage_fixture.id
|
||||
assert retrieved.text == source_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_update(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test updating a passage"""
|
||||
new_text = "Updated text"
|
||||
hello_world_passage_fixture.text = new_text
|
||||
updated = server.passage_manager.update_passage_by_id(hello_world_passage_fixture.id, hello_world_passage_fixture, actor=default_user)
|
||||
assert updated is not None
|
||||
assert updated.text == new_text
|
||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
assert retrieved.text == new_text
|
||||
|
||||
|
||||
def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test deleting a passage"""
|
||||
server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
with pytest.raises(NoResultFound):
|
||||
server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
|
||||
|
||||
def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test counting passages with filters"""
|
||||
base_passage = hello_world_passage_fixture
|
||||
|
||||
# Test total count
|
||||
total = server.passage_manager.size(actor=default_user)
|
||||
assert total == 5 # base passage + 4 test passages
|
||||
# TODO: change login passage to be a system not user passage
|
||||
|
||||
# Test count with agent filter
|
||||
agent_count = server.passage_manager.size(actor=default_user, agent_id=base_passage.agent_id)
|
||||
assert agent_count == 5
|
||||
|
||||
# Test count with role filter
|
||||
role_count = server.passage_manager.size(actor=default_user)
|
||||
assert role_count == 5
|
||||
|
||||
# Test count with non-existent filter
|
||||
empty_count = server.passage_manager.size(actor=default_user, agent_id="non-existent")
|
||||
assert empty_count == 0
|
||||
|
||||
|
||||
def test_passage_listing_basic(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test basic passage listing with limit"""
|
||||
results = server.passage_manager.list_passages(actor=default_user, limit=3)
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test cursor-based pagination functionality"""
|
||||
|
||||
# Make sure there are 5 passages
|
||||
assert server.passage_manager.size(actor=default_user) == 5
|
||||
|
||||
# Get first page
|
||||
first_page = server.passage_manager.list_passages(actor=default_user, limit=3)
|
||||
assert len(first_page) == 3
|
||||
|
||||
last_id_on_first_page = first_page[-1].id
|
||||
|
||||
# Get second page
|
||||
second_page = server.passage_manager.list_passages(actor=default_user, cursor=last_id_on_first_page, limit=3)
|
||||
assert len(second_page) == 2 # Should have 2 remaining passages
|
||||
assert all(r1.id != r2.id for r1 in first_page for r2 in second_page)
|
||||
|
||||
|
||||
def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
|
||||
"""Test filtering passages by agent ID"""
|
||||
agent_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, limit=10)
|
||||
assert len(agent_results) == 5 # base passage + 4 test passages
|
||||
assert all(msg.agent_id == hello_world_passage_fixture.agent_id for msg in agent_results)
|
||||
|
||||
|
||||
def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
|
||||
"""Test searching passages by text content"""
|
||||
search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10)
|
||||
assert len(search_results) == 4
|
||||
assert all("Test passage" in msg.text for msg in search_results)
|
||||
|
||||
# Test no results
|
||||
search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10)
|
||||
assert len(search_results) == 0
|
||||
|
||||
|
||||
def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_passage_fixture, default_user, default_file, sarah_agent):
|
||||
"""Test filtering passages by date range with various scenarios"""
|
||||
# Set up test data with known dates
|
||||
base_time = datetime.utcnow()
|
||||
|
||||
# Create passages at different times
|
||||
passages = []
|
||||
time_offsets = [
|
||||
timedelta(days=-2), # 2 days ago
|
||||
timedelta(days=-1), # Yesterday
|
||||
timedelta(hours=-2), # 2 hours ago
|
||||
timedelta(minutes=-30), # 30 minutes ago
|
||||
timedelta(minutes=-1), # 1 minute ago
|
||||
timedelta(minutes=0), # Now
|
||||
]
|
||||
|
||||
for i, offset in enumerate(time_offsets):
|
||||
timestamp = base_time + offset
|
||||
passage = server.passage_manager.create_passage(
|
||||
def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
|
||||
"""Test creating an agent passage."""
|
||||
assert agent_passage_fixture is not None
|
||||
assert agent_passage_fixture.text == "Hello, I am an agent passage"
|
||||
|
||||
# Try to create an invalid passage (with both agent_id and source_id)
|
||||
with pytest.raises(AssertionError):
|
||||
server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text="Invalid passage",
|
||||
agent_id="123",
|
||||
source_id="456",
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
text=f"Test passage {i}",
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
embedding=[0.1] * 1024,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
created_at=timestamp,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
# Test cases
|
||||
test_cases = [
|
||||
{
|
||||
"name": "Recent passages (last hour)",
|
||||
"start_date": base_time - timedelta(hours=1),
|
||||
"end_date": base_time + timedelta(minutes=1),
|
||||
"expected_count": 1 + 3, # Should include base + -30min, -1min, and now
|
||||
},
|
||||
{
|
||||
"name": "Yesterday's passages",
|
||||
"start_date": base_time - timedelta(days=1, hours=12),
|
||||
"end_date": base_time - timedelta(hours=12),
|
||||
"expected_count": 1, # Should only include yesterday's passage
|
||||
},
|
||||
{
|
||||
"name": "Future time range",
|
||||
"start_date": base_time + timedelta(days=1),
|
||||
"end_date": base_time + timedelta(days=2),
|
||||
"expected_count": 0, # Should find no passages
|
||||
},
|
||||
{
|
||||
"name": "All time",
|
||||
"start_date": base_time - timedelta(days=3),
|
||||
"end_date": base_time + timedelta(days=1),
|
||||
"expected_count": 1 + len(passages), # Should find all passages
|
||||
},
|
||||
{
|
||||
"name": "Exact timestamp match",
|
||||
"start_date": passages[0].created_at - timedelta(microseconds=1),
|
||||
"end_date": passages[0].created_at + timedelta(microseconds=1),
|
||||
"expected_count": 1, # Should find exactly one passage
|
||||
},
|
||||
{
|
||||
"name": "Small time window",
|
||||
"start_date": base_time - timedelta(seconds=30),
|
||||
"end_date": base_time + timedelta(seconds=30),
|
||||
"expected_count": 1 + 1, # date + "now"
|
||||
},
|
||||
]
|
||||
|
||||
# Run test cases
|
||||
for case in test_cases:
|
||||
results = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id, actor=default_user, start_date=case["start_date"], end_date=case["end_date"], limit=10
|
||||
actor=default_user
|
||||
)
|
||||
|
||||
# Verify count
|
||||
assert (
|
||||
len(results) == case["expected_count"]
|
||||
), f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}"
|
||||
|
||||
# Test edge cases
|
||||
def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user):
|
||||
"""Test retrieving a passage by ID"""
|
||||
retrieved = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == agent_passage_fixture.id
|
||||
assert retrieved.text == agent_passage_fixture.text
|
||||
|
||||
# Test with start_date but no end_date
|
||||
results_start_only = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id, actor=default_user, start_date=base_time - timedelta(minutes=2), end_date=None, limit=10
|
||||
)
|
||||
assert len(results_start_only) >= 2, "Should find passages after start_date"
|
||||
|
||||
# Test with end_date but no start_date
|
||||
results_end_only = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id, actor=default_user, start_date=None, end_date=base_time - timedelta(days=1), limit=10
|
||||
)
|
||||
assert len(results_end_only) >= 1, "Should find passages before end_date"
|
||||
|
||||
# Test limit enforcement
|
||||
limited_results = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
start_date=base_time - timedelta(days=3),
|
||||
end_date=base_time + timedelta(days=1),
|
||||
limit=3,
|
||||
)
|
||||
assert len(limited_results) <= 3, "Should respect the limit parameter"
|
||||
retrieved = server.passage_manager.get_passage_by_id(source_passage_fixture.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == source_passage_fixture.id
|
||||
assert retrieved.text == source_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent):
|
||||
"""Test vector search functionality for passages."""
|
||||
passage_manager = server.passage_manager
|
||||
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
||||
|
||||
# Create passages with known embeddings
|
||||
passages = []
|
||||
|
||||
# Create passages with different embeddings
|
||||
test_passages = [
|
||||
"I like red",
|
||||
"random text",
|
||||
"blue shoes",
|
||||
]
|
||||
|
||||
for text in test_passages:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding,
|
||||
)
|
||||
created_passage = passage_manager.create_passage(passage, default_user)
|
||||
passages.append(created_passage)
|
||||
assert passage_manager.size(actor=default_user) == len(passages)
|
||||
|
||||
# Query vector similar to "cats" embedding
|
||||
query_key = "What's my favorite color?"
|
||||
|
||||
# List passages with vector search
|
||||
results = passage_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text=query_key,
|
||||
limit=3,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
)
|
||||
|
||||
# Verify results are ordered by similarity
|
||||
assert len(results) == 3
|
||||
assert results[0].text == "I like red"
|
||||
assert results[1].text == "random text" # For some reason the embedding model doesn't like "blue shoes"
|
||||
assert results[2].text == "blue shoes"
|
||||
def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
|
||||
"""Test that passages are deleted when their parent (agent or source) is deleted."""
|
||||
# Verify passages exist
|
||||
agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user)
|
||||
source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
|
||||
assert agent_passage is not None
|
||||
assert source_passage is not None
|
||||
|
||||
# Delete agent and verify its passages are deleted
|
||||
server.agent_manager.delete_agent(sarah_agent.id, default_user)
|
||||
agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
|
||||
assert len(agentic_passages) == 0
|
||||
|
||||
# Delete source and verify its passages are deleted
|
||||
server.source_manager.delete_source(default_source.id, default_user)
|
||||
with pytest.raises(NoResultFound):
|
||||
server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -1220,6 +1320,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
|
||||
assert print_tool.organization_id == default_organization.id
|
||||
|
||||
|
||||
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
|
||||
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
|
||||
data = print_tool.model_dump(exclude=["id"])
|
||||
@@ -1787,6 +1888,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
|
||||
# ======================================================================================================================
|
||||
# Source Manager Tests - Files
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_get_file_by_id(server: SyncServer, default_user, default_source):
|
||||
"""Test retrieving a file by ID."""
|
||||
file_metadata = PydanticFileMetadata(
|
||||
@@ -1857,6 +1959,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
|
||||
# ======================================================================================================================
|
||||
# SandboxConfigManager Tests - Sandbox Configs
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
|
||||
sandbox_config_create = SandboxConfigCreate(
|
||||
config=E2BSandboxConfig(),
|
||||
@@ -1935,6 +2038,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
|
||||
# ======================================================================================================================
|
||||
# SandboxConfigManager Tests - Environment Variables
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
|
||||
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
|
||||
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
|
||||
@@ -2007,7 +2111,6 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
|
||||
# JobManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_job(server: SyncServer, default_user):
|
||||
"""Test creating a job."""
|
||||
job_data = PydanticJob(
|
||||
|
||||
@@ -390,12 +390,16 @@ def test_user_message_memory(server, user_id, agent_id):
|
||||
|
||||
@pytest.mark.order(3)
|
||||
def test_load_data(server, user_id, agent_id):
|
||||
user = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# create source
|
||||
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
|
||||
passages_before = server.agent_manager.list_passages(
|
||||
actor=user, agent_id=agent_id, cursor=None, limit=10000
|
||||
)
|
||||
assert len(passages_before) == 0
|
||||
|
||||
source = server.source_manager.create_source(
|
||||
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=server.default_user
|
||||
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
|
||||
)
|
||||
|
||||
# load data
|
||||
@@ -409,15 +413,11 @@ def test_load_data(server, user_id, agent_id):
|
||||
connector = DummyDataConnector(archival_memories)
|
||||
server.load_data(user_id, connector, source.name)
|
||||
|
||||
# @pytest.mark.order(3)
|
||||
# def test_attach_source_to_agent(server, user_id, agent_id):
|
||||
# check archival memory size
|
||||
|
||||
# attach source
|
||||
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
|
||||
|
||||
# check archival memory size
|
||||
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
|
||||
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
|
||||
assert len(passages_after) == 5
|
||||
|
||||
|
||||
@@ -465,7 +465,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
user = server.user_manager.get_user_by_id(user_id=user_id)
|
||||
|
||||
# List latest 2 passages
|
||||
passages_1 = server.passage_manager.list_passages(
|
||||
passages_1 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
@@ -475,7 +475,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
|
||||
# List next 3 passages (earliest 3)
|
||||
cursor1 = passages_1[-1].id
|
||||
passages_2 = server.passage_manager.list_passages(
|
||||
passages_2 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
@@ -484,24 +484,28 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
|
||||
# List all 5
|
||||
cursor2 = passages_1[0].created_at
|
||||
passages_3 = server.passage_manager.list_passages(
|
||||
passages_3 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
end_date=cursor2,
|
||||
limit=1000,
|
||||
)
|
||||
# assert passages_1[0].text == "Cinderella wore a blue dress"
|
||||
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
|
||||
latest = passages_1[0]
|
||||
earliest = passages_2[-1]
|
||||
|
||||
# test archival memory
|
||||
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1)
|
||||
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
|
||||
assert len(passage_1) == 1
|
||||
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000)
|
||||
assert passage_1[0].text == "alpha"
|
||||
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
|
||||
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert all("alpha" not in passage.text for passage in passage_2)
|
||||
# test safe empty return
|
||||
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000)
|
||||
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
@@ -955,6 +959,14 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
|
||||
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
existing_sources = server.source_manager.list_sources(actor=actor)
|
||||
if len(existing_sources) > 0:
|
||||
for source in existing_sources:
|
||||
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
|
||||
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
assert initial_passage_count == 0
|
||||
|
||||
|
||||
# Create a source
|
||||
source = server.source_manager.create_source(
|
||||
PydanticSource(
|
||||
@@ -973,10 +985,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
# Attach source to agent first
|
||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
|
||||
|
||||
# Get initial passage count
|
||||
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
assert initial_passage_count == 0
|
||||
|
||||
# Create a job for loading the first file
|
||||
job = server.job_manager.create_job(
|
||||
PydanticJob(
|
||||
@@ -1001,7 +1009,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert job.metadata_["num_documents"] == 1
|
||||
|
||||
# Verify passages were added
|
||||
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
assert first_file_passage_count > initial_passage_count
|
||||
|
||||
# Create a second test file with different content
|
||||
@@ -1032,14 +1040,13 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert job2.metadata_["num_documents"] == 1
|
||||
|
||||
# Verify passages were appended (not replaced)
|
||||
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
assert final_passage_count > first_file_passage_count
|
||||
|
||||
# Verify both old and new content is searchable
|
||||
passages = server.passage_manager.list_passages(
|
||||
actor=actor,
|
||||
passages = server.agent_manager.list_passages(
|
||||
agent_id=agent_id,
|
||||
source_id=source.id,
|
||||
actor=actor,
|
||||
query_text="what does Timber like to eat",
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
embed_query=True,
|
||||
@@ -1048,35 +1055,27 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert any("chicken" in passage.text.lower() for passage in passages)
|
||||
assert any("Anna".lower() in passage.text.lower() for passage in passages)
|
||||
|
||||
# TODO: Add this test back in after separation of `Passage tables` (LET-449)
|
||||
# # Load second agent
|
||||
# agent2 = server.load_agent(agent_id=other_agent_id)
|
||||
# Initially should have no passages
|
||||
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
|
||||
assert initial_agent2_passages == 0
|
||||
|
||||
# # Initially should have no passages
|
||||
# initial_agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
|
||||
# assert initial_agent2_passages == 0
|
||||
# Attach source to second agent
|
||||
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
|
||||
|
||||
# # Attach source to second agent
|
||||
# agent2.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
|
||||
# Verify second agent has same number of passages as first agent
|
||||
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
|
||||
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
|
||||
assert agent2_passages == agent1_passages
|
||||
|
||||
# # Verify second agent has same number of passages as first agent
|
||||
# agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
|
||||
# agent1_passages = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
|
||||
# assert agent2_passages == agent1_passages
|
||||
|
||||
# # Verify second agent can query the same content
|
||||
# passages2 = server.passage_manager.list_passages(
|
||||
# actor=user,
|
||||
# agent_id=other_agent_id,
|
||||
# source_id=source.id,
|
||||
# query_text="what does Timber like to eat",
|
||||
# embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
# embed_query=True,
|
||||
# limit=10,
|
||||
# )
|
||||
# assert len(passages2) == len(passages)
|
||||
# assert any("chicken" in passage.text.lower() for passage in passages2)
|
||||
# assert any("sleep" in passage.text.lower() for passage in passages2)
|
||||
|
||||
# # Cleanup
|
||||
# server.delete_agent(user_id=user_id, agent_id=agent2_state.id)
|
||||
# Verify second agent can query the same content
|
||||
passages2 = server.agent_manager.list_passages(
|
||||
actor=actor,
|
||||
agent_id=other_agent_id,
|
||||
source_id=source.id,
|
||||
query_text="what does Timber like to eat",
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
embed_query=True,
|
||||
)
|
||||
assert len(passages2) == len(passages)
|
||||
assert any("chicken" in passage.text.lower() for passage in passages2)
|
||||
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
|
||||
|
||||
Reference in New Issue
Block a user