feat: separate Passages tables (#2245)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
105
alembic/versions/54dec07619c4_divide_passage_table_into_.py
Normal file
105
alembic/versions/54dec07619c4_divide_passage_table_into_.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""divide passage table into SourcePassages and AgentPassages
|
||||||
|
|
||||||
|
Revision ID: 54dec07619c4
|
||||||
|
Revises: 4e88e702f85e
|
||||||
|
Create Date: 2024-12-14 17:23:08.772554
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
from letta.orm.custom_columns import EmbeddingConfigColumn
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '54dec07619c4'
|
||||||
|
down_revision: Union[str, None] = '4e88e702f85e'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
'agent_passages',
|
||||||
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
|
sa.Column('text', sa.String(), nullable=False),
|
||||||
|
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
|
||||||
|
sa.Column('metadata_', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('embedding', Vector(dim=4096), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||||
|
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||||
|
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||||
|
sa.Column('organization_id', sa.String(), nullable=False),
|
||||||
|
sa.Column('agent_id', sa.String(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ondelete='CASCADE'),
|
||||||
|
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index('agent_passages_org_idx', 'agent_passages', ['organization_id'], unique=False)
|
||||||
|
op.create_table(
|
||||||
|
'source_passages',
|
||||||
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
|
sa.Column('text', sa.String(), nullable=False),
|
||||||
|
sa.Column('embedding_config', EmbeddingConfigColumn(), nullable=False),
|
||||||
|
sa.Column('metadata_', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('embedding', Vector(dim=4096), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||||
|
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||||
|
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||||
|
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||||
|
sa.Column('organization_id', sa.String(), nullable=False),
|
||||||
|
sa.Column('file_id', sa.String(), nullable=True),
|
||||||
|
sa.Column('source_id', sa.String(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['file_id'], ['files.id'], ondelete='CASCADE'),
|
||||||
|
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['source_id'], ['sources.id'], ondelete='CASCADE'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index('source_passages_org_idx', 'source_passages', ['organization_id'], unique=False)
|
||||||
|
op.drop_table('passages')
|
||||||
|
op.drop_constraint('files_source_id_fkey', 'files', type_='foreignkey')
|
||||||
|
op.create_foreign_key(None, 'files', 'sources', ['source_id'], ['id'], ondelete='CASCADE')
|
||||||
|
op.drop_constraint('messages_agent_id_fkey', 'messages', type_='foreignkey')
|
||||||
|
op.create_foreign_key(None, 'messages', 'agents', ['agent_id'], ['id'], ondelete='CASCADE')
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_constraint(None, 'messages', type_='foreignkey')
|
||||||
|
op.create_foreign_key('messages_agent_id_fkey', 'messages', 'agents', ['agent_id'], ['id'])
|
||||||
|
op.drop_constraint(None, 'files', type_='foreignkey')
|
||||||
|
op.create_foreign_key('files_source_id_fkey', 'files', 'sources', ['source_id'], ['id'])
|
||||||
|
op.create_table(
|
||||||
|
'passages',
|
||||||
|
sa.Column('id', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('text', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('file_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('agent_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('source_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('embedding', Vector(dim=4096), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('embedding_config', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('metadata_', postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), server_default=sa.text('now()'), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('is_deleted', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
|
||||||
|
sa.Column('_created_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('_last_updated_by_id', sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('organization_id', sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], name='passages_agent_id_fkey'),
|
||||||
|
sa.ForeignKeyConstraint(['file_id'], ['files.id'], name='passages_file_id_fkey', ondelete='CASCADE'),
|
||||||
|
sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], name='passages_organization_id_fkey'),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='passages_pkey')
|
||||||
|
)
|
||||||
|
op.drop_index('source_passages_org_idx', table_name='source_passages')
|
||||||
|
op.drop_table('source_passages')
|
||||||
|
op.drop_index('agent_passages_org_idx', table_name='agent_passages')
|
||||||
|
op.drop_table('agent_passages')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -41,7 +41,6 @@ from letta.schemas.openai.chat_completion_response import (
|
|||||||
Message as ChatCompletionMessage,
|
Message as ChatCompletionMessage,
|
||||||
)
|
)
|
||||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||||
from letta.schemas.passage import Passage
|
|
||||||
from letta.schemas.tool import Tool
|
from letta.schemas.tool import Tool
|
||||||
from letta.schemas.tool_rule import TerminalToolRule
|
from letta.schemas.tool_rule import TerminalToolRule
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
@@ -82,7 +81,7 @@ def compile_memory_metadata_block(
|
|||||||
actor: PydanticUser,
|
actor: PydanticUser,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
memory_edit_timestamp: datetime.datetime,
|
memory_edit_timestamp: datetime.datetime,
|
||||||
passage_manager: Optional[PassageManager] = None,
|
agent_manager: Optional[AgentManager] = None,
|
||||||
message_manager: Optional[MessageManager] = None,
|
message_manager: Optional[MessageManager] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||||
@@ -93,7 +92,7 @@ def compile_memory_metadata_block(
|
|||||||
[
|
[
|
||||||
f"### Memory [last modified: {timestamp_str}]",
|
f"### Memory [last modified: {timestamp_str}]",
|
||||||
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||||
f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
|
f"{agent_manager.passage_size(actor=actor, agent_id=agent_id) if agent_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -106,7 +105,7 @@ def compile_system_message(
|
|||||||
in_context_memory: Memory,
|
in_context_memory: Memory,
|
||||||
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
||||||
actor: PydanticUser,
|
actor: PydanticUser,
|
||||||
passage_manager: Optional[PassageManager] = None,
|
agent_manager: Optional[AgentManager] = None,
|
||||||
message_manager: Optional[MessageManager] = None,
|
message_manager: Optional[MessageManager] = None,
|
||||||
user_defined_variables: Optional[dict] = None,
|
user_defined_variables: Optional[dict] = None,
|
||||||
append_icm_if_missing: bool = True,
|
append_icm_if_missing: bool = True,
|
||||||
@@ -135,7 +134,7 @@ def compile_system_message(
|
|||||||
actor=actor,
|
actor=actor,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
memory_edit_timestamp=in_context_memory_last_edit,
|
memory_edit_timestamp=in_context_memory_last_edit,
|
||||||
passage_manager=passage_manager,
|
agent_manager=agent_manager,
|
||||||
message_manager=message_manager,
|
message_manager=message_manager,
|
||||||
)
|
)
|
||||||
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
||||||
@@ -172,7 +171,7 @@ def initialize_message_sequence(
|
|||||||
agent_id: str,
|
agent_id: str,
|
||||||
memory: Memory,
|
memory: Memory,
|
||||||
actor: PydanticUser,
|
actor: PydanticUser,
|
||||||
passage_manager: Optional[PassageManager] = None,
|
agent_manager: Optional[AgentManager] = None,
|
||||||
message_manager: Optional[MessageManager] = None,
|
message_manager: Optional[MessageManager] = None,
|
||||||
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
||||||
include_initial_boot_message: bool = True,
|
include_initial_boot_message: bool = True,
|
||||||
@@ -181,7 +180,7 @@ def initialize_message_sequence(
|
|||||||
memory_edit_timestamp = get_local_time()
|
memory_edit_timestamp = get_local_time()
|
||||||
|
|
||||||
# full_system_message = construct_system_with_memory(
|
# full_system_message = construct_system_with_memory(
|
||||||
# system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory
|
# system, memory, memory_edit_timestamp, agent_manager=agent_manager, recall_memory=recall_memory
|
||||||
# )
|
# )
|
||||||
full_system_message = compile_system_message(
|
full_system_message = compile_system_message(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
@@ -189,7 +188,7 @@ def initialize_message_sequence(
|
|||||||
in_context_memory=memory,
|
in_context_memory=memory,
|
||||||
in_context_memory_last_edit=memory_edit_timestamp,
|
in_context_memory_last_edit=memory_edit_timestamp,
|
||||||
actor=actor,
|
actor=actor,
|
||||||
passage_manager=passage_manager,
|
agent_manager=agent_manager,
|
||||||
message_manager=message_manager,
|
message_manager=message_manager,
|
||||||
user_defined_variables=None,
|
user_defined_variables=None,
|
||||||
append_icm_if_missing=True,
|
append_icm_if_missing=True,
|
||||||
@@ -291,8 +290,9 @@ class Agent(BaseAgent):
|
|||||||
self.interface = interface
|
self.interface = interface
|
||||||
|
|
||||||
# Create the persistence manager object based on the AgentState info
|
# Create the persistence manager object based on the AgentState info
|
||||||
self.passage_manager = PassageManager()
|
|
||||||
self.message_manager = MessageManager()
|
self.message_manager = MessageManager()
|
||||||
|
self.passage_manager = PassageManager()
|
||||||
|
self.agent_manager = AgentManager()
|
||||||
|
|
||||||
# State needed for heartbeat pausing
|
# State needed for heartbeat pausing
|
||||||
self.pause_heartbeats_start = None
|
self.pause_heartbeats_start = None
|
||||||
@@ -322,7 +322,7 @@ class Agent(BaseAgent):
|
|||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
memory=self.agent_state.memory,
|
memory=self.agent_state.memory,
|
||||||
actor=self.user,
|
actor=self.user,
|
||||||
passage_manager=None,
|
agent_manager=None,
|
||||||
message_manager=None,
|
message_manager=None,
|
||||||
memory_edit_timestamp=get_utc_time(),
|
memory_edit_timestamp=get_utc_time(),
|
||||||
include_initial_boot_message=True,
|
include_initial_boot_message=True,
|
||||||
@@ -347,7 +347,7 @@ class Agent(BaseAgent):
|
|||||||
memory=self.agent_state.memory,
|
memory=self.agent_state.memory,
|
||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
actor=self.user,
|
actor=self.user,
|
||||||
passage_manager=None,
|
agent_manager=None,
|
||||||
message_manager=None,
|
message_manager=None,
|
||||||
memory_edit_timestamp=get_utc_time(),
|
memory_edit_timestamp=get_utc_time(),
|
||||||
include_initial_boot_message=True,
|
include_initial_boot_message=True,
|
||||||
@@ -1297,7 +1297,7 @@ class Agent(BaseAgent):
|
|||||||
in_context_memory=self.agent_state.memory,
|
in_context_memory=self.agent_state.memory,
|
||||||
in_context_memory_last_edit=memory_edit_timestamp,
|
in_context_memory_last_edit=memory_edit_timestamp,
|
||||||
actor=self.user,
|
actor=self.user,
|
||||||
passage_manager=self.passage_manager,
|
agent_manager=self.agent_manager,
|
||||||
message_manager=self.message_manager,
|
message_manager=self.message_manager,
|
||||||
user_defined_variables=None,
|
user_defined_variables=None,
|
||||||
append_icm_if_missing=True,
|
append_icm_if_missing=True,
|
||||||
@@ -1368,33 +1368,24 @@ class Agent(BaseAgent):
|
|||||||
source_id: str,
|
source_id: str,
|
||||||
source_manager: SourceManager,
|
source_manager: SourceManager,
|
||||||
agent_manager: AgentManager,
|
agent_manager: AgentManager,
|
||||||
page_size: Optional[int] = None,
|
|
||||||
):
|
):
|
||||||
"""Attach data with name `source_name` to the agent from source_connector."""
|
"""Attach a source to the agent using the SourcesAgents ORM relationship.
|
||||||
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
|
|
||||||
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
|
Args:
|
||||||
|
user: User performing the action
|
||||||
for passage in passages:
|
source_id: ID of the source to attach
|
||||||
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
|
source_manager: SourceManager instance to verify source exists
|
||||||
passage.agent_id = self.agent_state.id
|
agent_manager: AgentManager instance to manage agent-source relationship
|
||||||
self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)
|
"""
|
||||||
|
# Verify source exists and user has permission to access it
|
||||||
agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
|
|
||||||
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
|
|
||||||
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
|
|
||||||
assert len(agents_passages) == passage_size # sanity check
|
|
||||||
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
|
|
||||||
|
|
||||||
# attach to agent
|
|
||||||
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
|
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
|
||||||
assert source is not None, f"Source {source_id} not found in metadata store"
|
assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})"
|
||||||
|
|
||||||
# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
|
# Use the agent_manager to create the relationship
|
||||||
# TODO: delete @matt and remove
|
|
||||||
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)
|
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)
|
||||||
|
|
||||||
printd(
|
printd(
|
||||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
|
f"Attached data source {source.name} to agent {self.agent_state.name}.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
|
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
|
||||||
@@ -1550,13 +1541,13 @@ class Agent(BaseAgent):
|
|||||||
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
|
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id)
|
||||||
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||||
external_memory_summary = compile_memory_metadata_block(
|
external_memory_summary = compile_memory_metadata_block(
|
||||||
actor=self.user,
|
actor=self.user,
|
||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
memory_edit_timestamp=get_utc_time(), # dummy timestamp
|
memory_edit_timestamp=get_utc_time(), # dummy timestamp
|
||||||
passage_manager=self.passage_manager,
|
agent_manager=self.agent_manager,
|
||||||
message_manager=self.message_manager,
|
message_manager=self.message_manager,
|
||||||
)
|
)
|
||||||
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
||||||
@@ -1582,7 +1573,7 @@ class Agent(BaseAgent):
|
|||||||
return ContextWindowOverview(
|
return ContextWindowOverview(
|
||||||
# context window breakdown (in messages)
|
# context window breakdown (in messages)
|
||||||
num_messages=len(self._messages),
|
num_messages=len(self._messages),
|
||||||
num_archival_memory=passage_manager_size,
|
num_archival_memory=agent_manager_passage_size,
|
||||||
num_recall_memory=message_manager_size,
|
num_recall_memory=message_manager_size,
|
||||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||||
# top-level information
|
# top-level information
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from letta.agent import Agent
|
from letta.agent import Agent
|
||||||
from letta.constants import MAX_PAUSE_HEARTBEATS
|
from letta.constants import MAX_PAUSE_HEARTBEATS
|
||||||
|
from letta.services.agent_manager import AgentManager
|
||||||
|
|
||||||
# import math
|
# import math
|
||||||
# from letta.utils import json_dumps
|
# from letta.utils import json_dumps
|
||||||
@@ -200,8 +201,9 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Get results using passage manager
|
# Get results using passage manager
|
||||||
all_results = self.passage_manager.list_passages(
|
all_results = self.agent_manager.list_passages(
|
||||||
actor=self.user,
|
actor=self.user,
|
||||||
|
agent_id=self.agent_state.id,
|
||||||
query_text=query,
|
query_text=query,
|
||||||
limit=count + start, # Request enough results to handle offset
|
limit=count + start, # Request enough results to handle offset
|
||||||
embedding_config=self.agent_state.embedding_config,
|
embedding_config=self.agent_state.embedding_config,
|
||||||
|
|||||||
@@ -312,11 +312,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
|||||||
for param in sig.parameters.values():
|
for param in sig.parameters.values():
|
||||||
# Exclude 'self' parameter
|
# Exclude 'self' parameter
|
||||||
# TODO: eventually remove this (only applies to BASE_TOOLS)
|
# TODO: eventually remove this (only applies to BASE_TOOLS)
|
||||||
if param.name == "self":
|
if param.name in ["self", "agent_state"]: # Add agent_manager to excluded
|
||||||
continue
|
|
||||||
|
|
||||||
# exclude 'agent_state' parameter
|
|
||||||
if param.name == "agent_state":
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Assert that the parameter has a type annotation
|
# Assert that the parameter has a type annotation
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from letta.orm.file import FileMetadata
|
|||||||
from letta.orm.job import Job
|
from letta.orm.job import Job
|
||||||
from letta.orm.message import Message
|
from letta.orm.message import Message
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
from letta.orm.passage import Passage
|
from letta.orm.passage import BasePassage, AgentPassage, SourcePassage
|
||||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||||
from letta.orm.source import Source
|
from letta.orm.source import Source
|
||||||
from letta.orm.sources_agents import SourcesAgents
|
from letta.orm.sources_agents import SourcesAgents
|
||||||
|
|||||||
@@ -82,7 +82,25 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
|||||||
lazy="selectin",
|
lazy="selectin",
|
||||||
doc="Tags associated with the agent.",
|
doc="Tags associated with the agent.",
|
||||||
)
|
)
|
||||||
# passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="agent", lazy="selectin")
|
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||||
|
"SourcePassage",
|
||||||
|
secondary="sources_agents", # The join table for Agent -> Source
|
||||||
|
primaryjoin="Agent.id == sources_agents.c.agent_id",
|
||||||
|
secondaryjoin="and_(SourcePassage.source_id == sources_agents.c.source_id)",
|
||||||
|
lazy="selectin",
|
||||||
|
order_by="SourcePassage.created_at.desc()",
|
||||||
|
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
|
||||||
|
doc="All passages derived from sources associated with this agent.",
|
||||||
|
)
|
||||||
|
agent_passages: Mapped[List["AgentPassage"]] = relationship(
|
||||||
|
"AgentPassage",
|
||||||
|
back_populates="agent",
|
||||||
|
lazy="selectin",
|
||||||
|
order_by="AgentPassage.created_at.desc()",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
|
||||||
|
doc="All passages derived created by this agent.",
|
||||||
|
)
|
||||||
|
|
||||||
def to_pydantic(self) -> PydanticAgentState:
|
def to_pydantic(self) -> PydanticAgentState:
|
||||||
"""converts to the basic pydantic model counterpart"""
|
"""converts to the basic pydantic model counterpart"""
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
|
from letta.orm.source import Source
|
||||||
|
from letta.orm.passage import SourcePassage
|
||||||
|
|
||||||
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
||||||
"""Represents metadata for an uploaded file."""
|
"""Represents metadata for an uploaded file."""
|
||||||
@@ -27,4 +28,4 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
|||||||
# relationships
|
# relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
||||||
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
|
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
|
||||||
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
|
source_passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
|
||||||
|
|||||||
@@ -31,30 +31,19 @@ class UserMixin(Base):
|
|||||||
|
|
||||||
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
||||||
|
|
||||||
class FileMixin(Base):
|
|
||||||
"""Mixin for models that belong to a file."""
|
|
||||||
|
|
||||||
__abstract__ = True
|
|
||||||
|
|
||||||
file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
|
|
||||||
|
|
||||||
class AgentMixin(Base):
|
class AgentMixin(Base):
|
||||||
"""Mixin for models that belong to an agent."""
|
"""Mixin for models that belong to an agent."""
|
||||||
|
|
||||||
__abstract__ = True
|
__abstract__ = True
|
||||||
|
|
||||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"))
|
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"))
|
||||||
|
|
||||||
class FileMixin(Base):
|
class FileMixin(Base):
|
||||||
"""Mixin for models that belong to a file."""
|
"""Mixin for models that belong to a file."""
|
||||||
|
|
||||||
__abstract__ = True
|
__abstract__ = True
|
||||||
|
|
||||||
file_id: Mapped[Optional[str]] = mapped_column(
|
file_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("files.id", ondelete="CASCADE"))
|
||||||
String,
|
|
||||||
ForeignKey("files.id", ondelete="CASCADE"),
|
|
||||||
nullable=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SourceMixin(Base):
|
class SourceMixin(Base):
|
||||||
@@ -62,7 +51,7 @@ class SourceMixin(Base):
|
|||||||
|
|
||||||
__abstract__ = True
|
__abstract__ = True
|
||||||
|
|
||||||
source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"))
|
source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class SandboxConfigMixin(Base):
|
class SandboxConfigMixin(Base):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, List, Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@@ -35,6 +35,22 @@ class Organization(SqlalchemyBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# relationships
|
# relationships
|
||||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
|
||||||
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||||
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
|
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
||||||
|
source_passages: Mapped[List["SourcePassage"]] = relationship(
|
||||||
|
"SourcePassage",
|
||||||
|
back_populates="organization",
|
||||||
|
cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
agent_passages: Mapped[List["AgentPassage"]] = relationship(
|
||||||
|
"AgentPassage",
|
||||||
|
back_populates="organization",
|
||||||
|
cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def passages(self) -> List[Union["SourcePassage", "AgentPassage"]]:
|
||||||
|
"""Convenience property to get all passages"""
|
||||||
|
return self.source_passages + self.agent_passages
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,39 +1,35 @@
|
|||||||
from datetime import datetime
|
from typing import TYPE_CHECKING
|
||||||
from typing import TYPE_CHECKING, Optional
|
from sqlalchemy import Column, JSON, Index
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr
|
||||||
|
|
||||||
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
|
from letta.orm.mixins import FileMixin, OrganizationMixin
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
|
||||||
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||||
|
from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin
|
||||||
|
from letta.schemas.passage import Passage as PydanticPassage
|
||||||
|
from letta.settings import settings
|
||||||
|
|
||||||
from letta.config import LettaConfig
|
from letta.config import LettaConfig
|
||||||
from letta.constants import MAX_EMBEDDING_DIM
|
from letta.constants import MAX_EMBEDDING_DIM
|
||||||
from letta.orm.custom_columns import CommonVector
|
|
||||||
from letta.orm.mixins import FileMixin, OrganizationMixin
|
|
||||||
from letta.orm.source import EmbeddingConfigColumn
|
|
||||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
||||||
from letta.schemas.passage import Passage as PydanticPassage
|
|
||||||
from letta.settings import settings
|
|
||||||
|
|
||||||
config = LettaConfig()
|
config = LettaConfig()
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
|
from letta.orm.agent import Agent
|
||||||
|
|
||||||
|
|
||||||
# TODO: After migration to Passage, will need to manually delete passages where files
|
class BasePassage(SqlalchemyBase, OrganizationMixin):
|
||||||
# are deleted on web
|
"""Base class for all passage types with common fields"""
|
||||||
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
__abstract__ = True
|
||||||
"""Defines data model for storing Passages"""
|
|
||||||
|
|
||||||
__tablename__ = "passages"
|
|
||||||
__table_args__ = {"extend_existing": True}
|
|
||||||
__pydantic_model__ = PydanticPassage
|
__pydantic_model__ = PydanticPassage
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
|
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
|
||||||
text: Mapped[str] = mapped_column(doc="Passage text content")
|
text: Mapped[str] = mapped_column(doc="Passage text content")
|
||||||
source_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Source identifier")
|
|
||||||
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
|
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
|
||||||
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
|
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
|
||||||
|
# Vector embedding field based on database type
|
||||||
if settings.letta_pg_uri_no_default:
|
if settings.letta_pg_uri_no_default:
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
|
|
||||||
@@ -41,9 +37,49 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
|||||||
else:
|
else:
|
||||||
embedding = Column(CommonVector)
|
embedding = Column(CommonVector)
|
||||||
|
|
||||||
# Foreign keys
|
@declared_attr
|
||||||
agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id"), nullable=True)
|
def organization(cls) -> Mapped["Organization"]:
|
||||||
|
"""Relationship to organization"""
|
||||||
|
return relationship("Organization", back_populates="passages", lazy="selectin")
|
||||||
|
|
||||||
# Relationships
|
@declared_attr
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="passages", lazy="selectin")
|
def __table_args__(cls):
|
||||||
file: Mapped["FileMetadata"] = relationship("FileMetadata", back_populates="passages", lazy="selectin")
|
if settings.letta_pg_uri_no_default:
|
||||||
|
return (
|
||||||
|
Index(f'{cls.__tablename__}_org_idx', 'organization_id'),
|
||||||
|
{"extend_existing": True}
|
||||||
|
)
|
||||||
|
return ({"extend_existing": True},)
|
||||||
|
|
||||||
|
|
||||||
|
class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||||
|
"""Passages derived from external files/sources"""
|
||||||
|
__tablename__ = "source_passages"
|
||||||
|
|
||||||
|
@declared_attr
|
||||||
|
def file(cls) -> Mapped["FileMetadata"]:
|
||||||
|
"""Relationship to file"""
|
||||||
|
return relationship("FileMetadata", back_populates="source_passages", lazy="selectin")
|
||||||
|
|
||||||
|
@declared_attr
|
||||||
|
def organization(cls) -> Mapped["Organization"]:
|
||||||
|
return relationship("Organization", back_populates="source_passages", lazy="selectin")
|
||||||
|
|
||||||
|
@declared_attr
|
||||||
|
def source(cls) -> Mapped["Source"]:
|
||||||
|
"""Relationship to source"""
|
||||||
|
return relationship("Source", back_populates="passages", lazy="selectin", passive_deletes=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentPassage(BasePassage, AgentMixin):
|
||||||
|
"""Passages created by agents as archival memories"""
|
||||||
|
__tablename__ = "agent_passages"
|
||||||
|
|
||||||
|
@declared_attr
|
||||||
|
def organization(cls) -> Mapped["Organization"]:
|
||||||
|
return relationship("Organization", back_populates="agent_passages", lazy="selectin")
|
||||||
|
|
||||||
|
@declared_attr
|
||||||
|
def agent(cls) -> Mapped["Agent"]:
|
||||||
|
"""Relationship to agent"""
|
||||||
|
return relationship("Agent", back_populates="agent_passages", lazy="selectin", passive_deletes=True)
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ from letta.schemas.source import Source as PydanticSource
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.orm.organization import Organization
|
from letta.orm.organization import Organization
|
||||||
|
from letta.orm.file import FileMetadata
|
||||||
|
from letta.orm.passage import SourcePassage
|
||||||
|
from letta.orm.agent import Agent
|
||||||
|
|
||||||
|
|
||||||
class Source(SqlalchemyBase, OrganizationMixin):
|
class Source(SqlalchemyBase, OrganizationMixin):
|
||||||
@@ -28,4 +31,5 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
|||||||
# relationships
|
# relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
|
||||||
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
|
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
|
||||||
|
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
|
||||||
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
|
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from enum import Enum
|
|||||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||||
|
|
||||||
from sqlalchemy import String, desc, func, or_, select
|
from sqlalchemy import String, desc, func, or_, select
|
||||||
from sqlalchemy.exc import DBAPIError
|
from sqlalchemy.exc import DBAPIError, IntegrityError
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||||
|
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
@@ -242,7 +242,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|||||||
session.commit()
|
session.commit()
|
||||||
session.refresh(self)
|
session.refresh(self)
|
||||||
return self
|
return self
|
||||||
except DBAPIError as e:
|
except (DBAPIError, IntegrityError) as e:
|
||||||
self._handle_dbapi_error(e)
|
self._handle_dbapi_error(e)
|
||||||
|
|
||||||
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from letta.utils import get_utc_time
|
|||||||
|
|
||||||
|
|
||||||
class PassageBase(OrmMetadataBase):
|
class PassageBase(OrmMetadataBase):
|
||||||
__id_prefix__ = "passage_legacy"
|
__id_prefix__ = "passage"
|
||||||
|
|
||||||
is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")
|
is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")
|
||||||
|
|
||||||
|
|||||||
@@ -932,7 +932,7 @@ class SyncServer(Server):
|
|||||||
|
|
||||||
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
|
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
|
||||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||||
return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user))
|
return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
|
||||||
|
|
||||||
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
|
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
|
||||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||||
@@ -949,18 +949,9 @@ class SyncServer(Server):
|
|||||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||||
|
|
||||||
# Get the agent object (loaded in memory)
|
passages = self.agent_manager.list_passages(agent_id=agent_id, actor=actor)
|
||||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
||||||
|
|
||||||
# iterate over records
|
return passages
|
||||||
records = letta_agent.passage_manager.list_passages(
|
|
||||||
actor=actor,
|
|
||||||
agent_id=agent_id,
|
|
||||||
cursor=cursor,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
return records
|
|
||||||
|
|
||||||
def get_agent_archival_cursor(
|
def get_agent_archival_cursor(
|
||||||
self,
|
self,
|
||||||
@@ -974,15 +965,13 @@ class SyncServer(Server):
|
|||||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||||
|
|
||||||
# Get the agent object (loaded in memory)
|
|
||||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
||||||
|
|
||||||
# iterate over records
|
# iterate over records
|
||||||
records = letta_agent.passage_manager.list_passages(
|
records = self.agent_manager.list_passages(
|
||||||
actor=self.default_user,
|
actor=actor,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
cursor=cursor,
|
cursor=cursor,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
ascending=not reverse,
|
||||||
)
|
)
|
||||||
return records
|
return records
|
||||||
|
|
||||||
@@ -1098,7 +1087,8 @@ class SyncServer(Server):
|
|||||||
self.source_manager.delete_source(source_id=source_id, actor=actor)
|
self.source_manager.delete_source(source_id=source_id, actor=actor)
|
||||||
|
|
||||||
# delete data from passage store
|
# delete data from passage store
|
||||||
self.passage_manager.delete_passages(actor=actor, limit=None, source_id=source_id)
|
passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
|
||||||
|
self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)
|
||||||
|
|
||||||
# TODO: delete data from agent passage stores (?)
|
# TODO: delete data from agent passage stores (?)
|
||||||
|
|
||||||
@@ -1129,9 +1119,11 @@ class SyncServer(Server):
|
|||||||
for agent_state in agent_states:
|
for agent_state in agent_states:
|
||||||
agent_id = agent_state.id
|
agent_id = agent_state.id
|
||||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||||
curr_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
|
|
||||||
|
# Attach source to agent
|
||||||
|
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
||||||
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
|
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
|
||||||
new_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
|
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
|
||||||
assert new_passage_size >= curr_passage_size # in case empty files are added
|
assert new_passage_size >= curr_passage_size # in case empty files are added
|
||||||
|
|
||||||
return job
|
return job
|
||||||
@@ -1195,14 +1187,9 @@ class SyncServer(Server):
|
|||||||
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||||
elif source_name:
|
elif source_name:
|
||||||
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
||||||
|
source_id = source.id
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
|
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
|
||||||
source_id = source.id
|
|
||||||
|
|
||||||
# TODO: This should be done with the ORM?
|
|
||||||
# delete all Passage objects with source_id==source_id from agent's archival memory
|
|
||||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
|
||||||
agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id)
|
|
||||||
|
|
||||||
# delete agent-source mapping
|
# delete agent-source mapping
|
||||||
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||||
@@ -1224,7 +1211,7 @@ class SyncServer(Server):
|
|||||||
for source in sources:
|
for source in sources:
|
||||||
|
|
||||||
# count number of passages
|
# count number of passages
|
||||||
num_passages = self.passage_manager.size(actor=actor, source_id=source.id)
|
num_passages = self.agent_manager.passage_size(actor=actor, source_id=source.id)
|
||||||
|
|
||||||
# TODO: add when files table implemented
|
# TODO: add when files table implemented
|
||||||
## count number of files
|
## count number of files
|
||||||
|
|||||||
@@ -1,17 +1,26 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
from sqlalchemy import select, union_all, literal, func, Select
|
||||||
|
|
||||||
|
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
|
||||||
|
from letta.embeddings import embedding_model
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
from letta.orm import Agent as AgentModel
|
from letta.orm import Agent as AgentModel
|
||||||
from letta.orm import Block as BlockModel
|
from letta.orm import Block as BlockModel
|
||||||
from letta.orm import Source as SourceModel
|
from letta.orm import Source as SourceModel
|
||||||
from letta.orm import Tool as ToolModel
|
from letta.orm import Tool as ToolModel
|
||||||
|
from letta.orm import AgentPassage, SourcePassage
|
||||||
|
from letta.orm import SourcesAgents
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
|
from letta.orm.sqlite_functions import adapt_array
|
||||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||||
from letta.schemas.block import Block as PydanticBlock
|
from letta.schemas.block import Block as PydanticBlock
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
|
from letta.schemas.passage import Passage as PydanticPassage
|
||||||
from letta.schemas.source import Source as PydanticSource
|
from letta.schemas.source import Source as PydanticSource
|
||||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||||
from letta.schemas.user import User as PydanticUser
|
from letta.schemas.user import User as PydanticUser
|
||||||
@@ -21,9 +30,9 @@ from letta.services.helpers.agent_manager_helper import (
|
|||||||
_process_tags,
|
_process_tags,
|
||||||
derive_system_message,
|
derive_system_message,
|
||||||
)
|
)
|
||||||
from letta.services.passage_manager import PassageManager
|
|
||||||
from letta.services.source_manager import SourceManager
|
from letta.services.source_manager import SourceManager
|
||||||
from letta.services.tool_manager import ToolManager
|
from letta.services.tool_manager import ToolManager
|
||||||
|
from letta.settings import settings
|
||||||
from letta.utils import enforce_types
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -229,13 +238,6 @@ class AgentManager:
|
|||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
# Retrieve the agent
|
# Retrieve the agent
|
||||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||||
|
|
||||||
# TODO: @mindy delete this piece when we have a proper passages/sources implementation
|
|
||||||
# TODO: This is done very hacky on purpose
|
|
||||||
# TODO: 1000 limit is also wack
|
|
||||||
passage_manager = PassageManager()
|
|
||||||
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
|
|
||||||
|
|
||||||
agent_state = agent.to_pydantic()
|
agent_state = agent.to_pydantic()
|
||||||
agent.hard_delete(session)
|
agent.hard_delete(session)
|
||||||
return agent_state
|
return agent_state
|
||||||
@@ -407,6 +409,262 @@ class AgentManager:
|
|||||||
agent.update(session, actor=actor)
|
agent.update(session, actor=actor)
|
||||||
return agent.to_pydantic()
|
return agent.to_pydantic()
|
||||||
|
|
||||||
|
# ======================================================================================================================
|
||||||
|
# Passage Management
|
||||||
|
# ======================================================================================================================
|
||||||
|
def _build_passage_query(
|
||||||
|
self,
|
||||||
|
actor: PydanticUser,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
file_id: Optional[str] = None,
|
||||||
|
query_text: Optional[str] = None,
|
||||||
|
start_date: Optional[datetime] = None,
|
||||||
|
end_date: Optional[datetime] = None,
|
||||||
|
cursor: Optional[str] = None,
|
||||||
|
source_id: Optional[str] = None,
|
||||||
|
embed_query: bool = False,
|
||||||
|
ascending: bool = True,
|
||||||
|
embedding_config: Optional[EmbeddingConfig] = None,
|
||||||
|
agent_only: bool = False,
|
||||||
|
) -> Select:
|
||||||
|
"""Helper function to build the base passage query with all filters applied.
|
||||||
|
|
||||||
|
Returns the query before any limit or count operations are applied.
|
||||||
|
"""
|
||||||
|
embedded_text = None
|
||||||
|
if embed_query:
|
||||||
|
assert embedding_config is not None, "embedding_config must be specified for vector search"
|
||||||
|
assert query_text is not None, "query_text must be specified for vector search"
|
||||||
|
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
|
||||||
|
embedded_text = np.array(embedded_text)
|
||||||
|
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||||
|
|
||||||
|
with self.session_maker() as session:
|
||||||
|
# Start with base query for source passages
|
||||||
|
source_passages = None
|
||||||
|
if not agent_only: # Include source passages
|
||||||
|
if agent_id is not None:
|
||||||
|
source_passages = (
|
||||||
|
select(
|
||||||
|
SourcePassage,
|
||||||
|
literal(None).label('agent_id')
|
||||||
|
)
|
||||||
|
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
|
||||||
|
.where(SourcesAgents.agent_id == agent_id)
|
||||||
|
.where(SourcePassage.organization_id == actor.organization_id)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
source_passages = (
|
||||||
|
select(
|
||||||
|
SourcePassage,
|
||||||
|
literal(None).label('agent_id')
|
||||||
|
)
|
||||||
|
.where(SourcePassage.organization_id == actor.organization_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if source_id:
|
||||||
|
source_passages = source_passages.where(SourcePassage.source_id == source_id)
|
||||||
|
if file_id:
|
||||||
|
source_passages = source_passages.where(SourcePassage.file_id == file_id)
|
||||||
|
|
||||||
|
# Add agent passages query
|
||||||
|
agent_passages = None
|
||||||
|
if agent_id is not None:
|
||||||
|
agent_passages = (
|
||||||
|
select(
|
||||||
|
AgentPassage.id,
|
||||||
|
AgentPassage.text,
|
||||||
|
AgentPassage.embedding_config,
|
||||||
|
AgentPassage.metadata_,
|
||||||
|
AgentPassage.embedding,
|
||||||
|
AgentPassage.created_at,
|
||||||
|
AgentPassage.updated_at,
|
||||||
|
AgentPassage.is_deleted,
|
||||||
|
AgentPassage._created_by_id,
|
||||||
|
AgentPassage._last_updated_by_id,
|
||||||
|
AgentPassage.organization_id,
|
||||||
|
literal(None).label('file_id'),
|
||||||
|
literal(None).label('source_id'),
|
||||||
|
AgentPassage.agent_id
|
||||||
|
)
|
||||||
|
.where(AgentPassage.agent_id == agent_id)
|
||||||
|
.where(AgentPassage.organization_id == actor.organization_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine queries
|
||||||
|
if source_passages is not None and agent_passages is not None:
|
||||||
|
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
|
||||||
|
elif agent_passages is not None:
|
||||||
|
combined_query = agent_passages.cte('combined_passages')
|
||||||
|
elif source_passages is not None:
|
||||||
|
combined_query = source_passages.cte('combined_passages')
|
||||||
|
else:
|
||||||
|
raise ValueError("No passages found")
|
||||||
|
|
||||||
|
# Build main query from combined CTE
|
||||||
|
main_query = select(combined_query)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if start_date:
|
||||||
|
main_query = main_query.where(combined_query.c.created_at >= start_date)
|
||||||
|
if end_date:
|
||||||
|
main_query = main_query.where(combined_query.c.created_at <= end_date)
|
||||||
|
if source_id:
|
||||||
|
main_query = main_query.where(combined_query.c.source_id == source_id)
|
||||||
|
if file_id:
|
||||||
|
main_query = main_query.where(combined_query.c.file_id == file_id)
|
||||||
|
|
||||||
|
# Vector search
|
||||||
|
if embedded_text:
|
||||||
|
if settings.letta_pg_uri_no_default:
|
||||||
|
# PostgreSQL with pgvector
|
||||||
|
main_query = main_query.order_by(
|
||||||
|
combined_query.c.embedding.cosine_distance(embedded_text).asc()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# SQLite with custom vector type
|
||||||
|
query_embedding_binary = adapt_array(embedded_text)
|
||||||
|
if ascending:
|
||||||
|
main_query = main_query.order_by(
|
||||||
|
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
||||||
|
combined_query.c.created_at.asc(),
|
||||||
|
combined_query.c.id.asc()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
main_query = main_query.order_by(
|
||||||
|
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
|
||||||
|
combined_query.c.created_at.desc(),
|
||||||
|
combined_query.c.id.asc()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if query_text:
|
||||||
|
main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))
|
||||||
|
|
||||||
|
# Handle cursor-based pagination
|
||||||
|
if cursor:
|
||||||
|
cursor_query = select(combined_query.c.created_at).where(
|
||||||
|
combined_query.c.id == cursor
|
||||||
|
).scalar_subquery()
|
||||||
|
|
||||||
|
if ascending:
|
||||||
|
main_query = main_query.where(
|
||||||
|
combined_query.c.created_at > cursor_query
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
main_query = main_query.where(
|
||||||
|
combined_query.c.created_at < cursor_query
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add ordering if not already ordered by similarity
|
||||||
|
if not embed_query:
|
||||||
|
if ascending:
|
||||||
|
main_query = main_query.order_by(
|
||||||
|
combined_query.c.created_at.asc(),
|
||||||
|
combined_query.c.id.asc(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
main_query = main_query.order_by(
|
||||||
|
combined_query.c.created_at.desc(),
|
||||||
|
combined_query.c.id.asc(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return main_query
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def list_passages(
|
||||||
|
self,
|
||||||
|
actor: PydanticUser,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
file_id: Optional[str] = None,
|
||||||
|
limit: Optional[int] = 50,
|
||||||
|
query_text: Optional[str] = None,
|
||||||
|
start_date: Optional[datetime] = None,
|
||||||
|
end_date: Optional[datetime] = None,
|
||||||
|
cursor: Optional[str] = None,
|
||||||
|
source_id: Optional[str] = None,
|
||||||
|
embed_query: bool = False,
|
||||||
|
ascending: bool = True,
|
||||||
|
embedding_config: Optional[EmbeddingConfig] = None,
|
||||||
|
agent_only: bool = False
|
||||||
|
) -> List[PydanticPassage]:
|
||||||
|
"""Lists all passages attached to an agent."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
main_query = self._build_passage_query(
|
||||||
|
actor=actor,
|
||||||
|
agent_id=agent_id,
|
||||||
|
file_id=file_id,
|
||||||
|
query_text=query_text,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
cursor=cursor,
|
||||||
|
source_id=source_id,
|
||||||
|
embed_query=embed_query,
|
||||||
|
ascending=ascending,
|
||||||
|
embedding_config=embedding_config,
|
||||||
|
agent_only=agent_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add limit
|
||||||
|
if limit:
|
||||||
|
main_query = main_query.limit(limit)
|
||||||
|
|
||||||
|
# Execute query
|
||||||
|
results = list(session.execute(main_query))
|
||||||
|
|
||||||
|
passages = []
|
||||||
|
for row in results:
|
||||||
|
data = dict(row._mapping)
|
||||||
|
if data['agent_id'] is not None:
|
||||||
|
# This is an AgentPassage - remove source fields
|
||||||
|
data.pop('source_id', None)
|
||||||
|
data.pop('file_id', None)
|
||||||
|
passage = AgentPassage(**data)
|
||||||
|
else:
|
||||||
|
# This is a SourcePassage - remove agent field
|
||||||
|
data.pop('agent_id', None)
|
||||||
|
passage = SourcePassage(**data)
|
||||||
|
passages.append(passage)
|
||||||
|
|
||||||
|
return [p.to_pydantic() for p in passages]
|
||||||
|
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def passage_size(
|
||||||
|
self,
|
||||||
|
actor: PydanticUser,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
file_id: Optional[str] = None,
|
||||||
|
query_text: Optional[str] = None,
|
||||||
|
start_date: Optional[datetime] = None,
|
||||||
|
end_date: Optional[datetime] = None,
|
||||||
|
cursor: Optional[str] = None,
|
||||||
|
source_id: Optional[str] = None,
|
||||||
|
embed_query: bool = False,
|
||||||
|
ascending: bool = True,
|
||||||
|
embedding_config: Optional[EmbeddingConfig] = None,
|
||||||
|
agent_only: bool = False
|
||||||
|
) -> int:
|
||||||
|
"""Returns the count of passages matching the given criteria."""
|
||||||
|
with self.session_maker() as session:
|
||||||
|
main_query = self._build_passage_query(
|
||||||
|
actor=actor,
|
||||||
|
agent_id=agent_id,
|
||||||
|
file_id=file_id,
|
||||||
|
query_text=query_text,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
cursor=cursor,
|
||||||
|
source_id=source_id,
|
||||||
|
embed_query=embed_query,
|
||||||
|
ascending=ascending,
|
||||||
|
embedding_config=embedding_config,
|
||||||
|
agent_only=agent_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to count query
|
||||||
|
count_query = select(func.count()).select_from(main_query.subquery())
|
||||||
|
return session.scalar(count_query) or 0
|
||||||
|
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
# Tool Management
|
# Tool Management
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from sqlalchemy import select, union_all, literal
|
||||||
|
|
||||||
from letta.constants import MAX_EMBEDDING_DIM
|
from letta.constants import MAX_EMBEDDING_DIM
|
||||||
from letta.embeddings import embedding_model, parse_and_chunk_text
|
from letta.embeddings import embedding_model, parse_and_chunk_text
|
||||||
from letta.orm.errors import NoResultFound
|
from letta.orm.errors import NoResultFound
|
||||||
from letta.orm.passage import Passage as PassageModel
|
from letta.orm.passage import AgentPassage, SourcePassage
|
||||||
from letta.schemas.agent import AgentState
|
from letta.schemas.agent import AgentState
|
||||||
from letta.schemas.embedding_config import EmbeddingConfig
|
from letta.schemas.embedding_config import EmbeddingConfig
|
||||||
from letta.schemas.passage import Passage as PydanticPassage
|
from letta.schemas.passage import Passage as PydanticPassage
|
||||||
@@ -14,6 +15,7 @@ from letta.schemas.user import User as PydanticUser
|
|||||||
from letta.utils import enforce_types
|
from letta.utils import enforce_types
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
"""Manager class to handle business logic related to Passages."""
|
"""Manager class to handle business logic related to Passages."""
|
||||||
|
|
||||||
@@ -26,14 +28,51 @@ class PassageManager:
|
|||||||
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||||
"""Fetch a passage by ID."""
|
"""Fetch a passage by ID."""
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor)
|
# Try source passages first
|
||||||
return passage.to_pydantic()
|
try:
|
||||||
|
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||||
|
return passage.to_pydantic()
|
||||||
|
except NoResultFound:
|
||||||
|
# Try archival passages
|
||||||
|
try:
|
||||||
|
passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||||
|
return passage.to_pydantic()
|
||||||
|
except NoResultFound:
|
||||||
|
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||||
"""Create a new passage."""
|
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
||||||
|
# Common fields for both passage types
|
||||||
|
data = pydantic_passage.model_dump()
|
||||||
|
common_fields = {
|
||||||
|
"id": data.get("id"),
|
||||||
|
"text": data["text"],
|
||||||
|
"embedding": data["embedding"],
|
||||||
|
"embedding_config": data["embedding_config"],
|
||||||
|
"organization_id": data["organization_id"],
|
||||||
|
"metadata_": data.get("metadata_", {}),
|
||||||
|
"is_deleted": data.get("is_deleted", False),
|
||||||
|
"created_at": data.get("created_at", datetime.utcnow()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "agent_id" in data and data["agent_id"]:
|
||||||
|
assert not data.get("source_id"), "Passage cannot have both agent_id and source_id"
|
||||||
|
agent_fields = {
|
||||||
|
"agent_id": data["agent_id"],
|
||||||
|
}
|
||||||
|
passage = AgentPassage(**common_fields, **agent_fields)
|
||||||
|
elif "source_id" in data and data["source_id"]:
|
||||||
|
assert not data.get("agent_id"), "Passage cannot have both agent_id and source_id"
|
||||||
|
source_fields = {
|
||||||
|
"source_id": data["source_id"],
|
||||||
|
"file_id": data.get("file_id"),
|
||||||
|
}
|
||||||
|
passage = SourcePassage(**common_fields, **source_fields)
|
||||||
|
else:
|
||||||
|
raise ValueError("Passage must have either agent_id or source_id")
|
||||||
|
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
passage = PassageModel(**pydantic_passage.model_dump())
|
|
||||||
passage.create(session, actor=actor)
|
passage.create(session, actor=actor)
|
||||||
return passage.to_pydantic()
|
return passage.to_pydantic()
|
||||||
|
|
||||||
@@ -93,14 +132,23 @@ class PassageManager:
|
|||||||
raise ValueError("Passage ID must be provided.")
|
raise ValueError("Passage ID must be provided.")
|
||||||
|
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
# Fetch existing message from database
|
# Try source passages first
|
||||||
curr_passage = PassageModel.read(
|
try:
|
||||||
db_session=session,
|
curr_passage = SourcePassage.read(
|
||||||
identifier=passage_id,
|
db_session=session,
|
||||||
actor=actor,
|
identifier=passage_id,
|
||||||
)
|
actor=actor,
|
||||||
if not curr_passage:
|
)
|
||||||
raise ValueError(f"Passage with id {passage_id} does not exist.")
|
except NoResultFound:
|
||||||
|
# Try agent passages
|
||||||
|
try:
|
||||||
|
curr_passage = AgentPassage.read(
|
||||||
|
db_session=session,
|
||||||
|
identifier=passage_id,
|
||||||
|
actor=actor,
|
||||||
|
)
|
||||||
|
except NoResultFound:
|
||||||
|
raise ValueError(f"Passage with id {passage_id} does not exist.")
|
||||||
|
|
||||||
# Update the database record with values from the provided record
|
# Update the database record with values from the provided record
|
||||||
update_data = passage.model_dump(exclude_unset=True, exclude_none=True)
|
update_data = passage.model_dump(exclude_unset=True, exclude_none=True)
|
||||||
@@ -113,104 +161,32 @@ class PassageManager:
|
|||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
|
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||||
"""Delete a passage."""
|
"""Delete a passage from either source or archival passages."""
|
||||||
if not passage_id:
|
if not passage_id:
|
||||||
raise ValueError("Passage ID must be provided.")
|
raise ValueError("Passage ID must be provided.")
|
||||||
|
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
|
# Try source passages first
|
||||||
try:
|
try:
|
||||||
passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor)
|
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||||
passage.hard_delete(session, actor=actor)
|
passage.hard_delete(session, actor=actor)
|
||||||
|
return True
|
||||||
except NoResultFound:
|
except NoResultFound:
|
||||||
raise ValueError(f"Passage with id {passage_id} not found.")
|
# Try archival passages
|
||||||
|
try:
|
||||||
@enforce_types
|
passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||||
def list_passages(
|
passage.hard_delete(session, actor=actor)
|
||||||
self,
|
return True
|
||||||
actor: PydanticUser,
|
except NoResultFound:
|
||||||
agent_id: Optional[str] = None,
|
raise NoResultFound(f"Passage with id {passage_id} not found.")
|
||||||
file_id: Optional[str] = None,
|
|
||||||
cursor: Optional[str] = None,
|
|
||||||
limit: Optional[int] = 50,
|
|
||||||
query_text: Optional[str] = None,
|
|
||||||
start_date: Optional[datetime] = None,
|
|
||||||
end_date: Optional[datetime] = None,
|
|
||||||
ascending: bool = True,
|
|
||||||
source_id: Optional[str] = None,
|
|
||||||
embed_query: bool = False,
|
|
||||||
embedding_config: Optional[EmbeddingConfig] = None,
|
|
||||||
) -> List[PydanticPassage]:
|
|
||||||
"""List passages with pagination."""
|
|
||||||
with self.session_maker() as session:
|
|
||||||
filters = {"organization_id": actor.organization_id}
|
|
||||||
if agent_id:
|
|
||||||
filters["agent_id"] = agent_id
|
|
||||||
if file_id:
|
|
||||||
filters["file_id"] = file_id
|
|
||||||
if source_id:
|
|
||||||
filters["source_id"] = source_id
|
|
||||||
|
|
||||||
embedded_text = None
|
|
||||||
if embed_query:
|
|
||||||
assert embedding_config is not None
|
|
||||||
|
|
||||||
# Embed the text
|
|
||||||
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
|
|
||||||
|
|
||||||
# Pad the embedding with zeros
|
|
||||||
embedded_text = np.array(embedded_text)
|
|
||||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
|
||||||
|
|
||||||
results = PassageModel.list(
|
|
||||||
db_session=session,
|
|
||||||
cursor=cursor,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
limit=limit,
|
|
||||||
ascending=ascending,
|
|
||||||
query_text=query_text if not embedded_text else None,
|
|
||||||
query_embedding=embedded_text,
|
|
||||||
**filters,
|
|
||||||
)
|
|
||||||
return [p.to_pydantic() for p in results]
|
|
||||||
|
|
||||||
@enforce_types
|
|
||||||
def size(self, actor: PydanticUser, agent_id: Optional[str] = None, **kwargs) -> int:
|
|
||||||
"""Get the total count of messages with optional filters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
actor : The user requesting the count
|
|
||||||
agent_id: The agent ID
|
|
||||||
"""
|
|
||||||
with self.session_maker() as session:
|
|
||||||
return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs)
|
|
||||||
|
|
||||||
def delete_passages(
|
def delete_passages(
|
||||||
self,
|
self,
|
||||||
actor: PydanticUser,
|
actor: PydanticUser,
|
||||||
agent_id: Optional[str] = None,
|
passages: List[PydanticPassage],
|
||||||
file_id: Optional[str] = None,
|
|
||||||
start_date: Optional[datetime] = None,
|
|
||||||
end_date: Optional[datetime] = None,
|
|
||||||
limit: Optional[int] = 50,
|
|
||||||
cursor: Optional[str] = None,
|
|
||||||
query_text: Optional[str] = None,
|
|
||||||
source_id: Optional[str] = None,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|
||||||
passages = self.list_passages(
|
|
||||||
actor=actor,
|
|
||||||
agent_id=agent_id,
|
|
||||||
file_id=file_id,
|
|
||||||
cursor=cursor,
|
|
||||||
limit=limit,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
query_text=query_text,
|
|
||||||
source_id=source_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: This is very inefficient
|
# TODO: This is very inefficient
|
||||||
# TODO: We should have a base `delete_all_matching_filters`-esque function
|
# TODO: We should have a base `delete_all_matching_filters`-esque function
|
||||||
for passage in passages:
|
for passage in passages:
|
||||||
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
||||||
|
return True
|
||||||
|
|||||||
@@ -482,7 +482,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
|||||||
|
|
||||||
# check agent archival memory size
|
# check agent archival memory size
|
||||||
archival_memories = client.get_archival_memory(agent_id=agent.id)
|
archival_memories = client.get_archival_memory(agent_id=agent.id)
|
||||||
print(archival_memories)
|
|
||||||
assert len(archival_memories) == 0
|
assert len(archival_memories) == 0
|
||||||
|
|
||||||
# load a file into a source (non-blocking job)
|
# load a file into a source (non-blocking job)
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ import os
|
|||||||
import time
|
import time
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from httpx._transports import default
|
||||||
|
from numpy import source
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
@@ -17,7 +19,8 @@ from letta.orm import (
|
|||||||
Job,
|
Job,
|
||||||
Message,
|
Message,
|
||||||
Organization,
|
Organization,
|
||||||
Passage,
|
AgentPassage,
|
||||||
|
SourcePassage,
|
||||||
SandboxConfig,
|
SandboxConfig,
|
||||||
SandboxEnvironmentVariable,
|
SandboxEnvironmentVariable,
|
||||||
Source,
|
Source,
|
||||||
@@ -82,7 +85,8 @@ def clear_tables(server: SyncServer):
|
|||||||
"""Fixture to clear the organization table before each test."""
|
"""Fixture to clear the organization table before each test."""
|
||||||
with server.organization_manager.session_maker() as session:
|
with server.organization_manager.session_maker() as session:
|
||||||
session.execute(delete(Message))
|
session.execute(delete(Message))
|
||||||
session.execute(delete(Passage))
|
session.execute(delete(AgentPassage))
|
||||||
|
session.execute(delete(SourcePassage))
|
||||||
session.execute(delete(Job))
|
session.execute(delete(Job))
|
||||||
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||||
session.execute(delete(BlocksAgents))
|
session.execute(delete(BlocksAgents))
|
||||||
@@ -189,39 +193,79 @@ def print_tool(server: SyncServer, default_user, default_organization):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def hello_world_passage_fixture(server: SyncServer, default_user, default_file, sarah_agent):
|
def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
|
||||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
"""Fixture to create an agent passage."""
|
||||||
# Set up passage
|
passage = server.passage_manager.create_passage(
|
||||||
dummy_embedding = [0.0] * 2
|
PydanticPassage(
|
||||||
message = PydanticPassage(
|
text="Hello, I am an agent passage",
|
||||||
organization_id=default_user.organization_id,
|
agent_id=sarah_agent.id,
|
||||||
agent_id=sarah_agent.id,
|
organization_id=default_user.organization_id,
|
||||||
file_id=default_file.id,
|
embedding=[0.1],
|
||||||
text="Hello, world!",
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
embedding=dummy_embedding,
|
metadata_={"type": "test"}
|
||||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
),
|
||||||
|
actor=default_user
|
||||||
)
|
)
|
||||||
|
yield passage
|
||||||
msg = server.passage_manager.create_passage(message, actor=default_user)
|
|
||||||
yield msg
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent) -> list[PydanticPassage]:
|
def source_passage_fixture(server: SyncServer, default_user, default_file, default_source):
|
||||||
"""Helper function to create test passages for all tests"""
|
"""Fixture to create a source passage."""
|
||||||
dummy_embedding = [0] * 2
|
passage = server.passage_manager.create_passage(
|
||||||
passages = [
|
|
||||||
PydanticPassage(
|
PydanticPassage(
|
||||||
organization_id=default_user.organization_id,
|
text="Hello, I am a source passage",
|
||||||
agent_id=sarah_agent.id,
|
source_id=default_source.id,
|
||||||
file_id=default_file.id,
|
file_id=default_file.id,
|
||||||
text=f"Test passage {i}",
|
organization_id=default_user.organization_id,
|
||||||
embedding=dummy_embedding,
|
embedding=[0.1],
|
||||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
metadata_={"type": "test"}
|
||||||
|
),
|
||||||
|
actor=default_user
|
||||||
|
)
|
||||||
|
yield passage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent, default_source):
|
||||||
|
"""Helper function to create test passages for all tests."""
|
||||||
|
# Create agent passages
|
||||||
|
passages = []
|
||||||
|
for i in range(5):
|
||||||
|
passage = server.passage_manager.create_passage(
|
||||||
|
PydanticPassage(
|
||||||
|
text=f"Agent passage {i}",
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
organization_id=default_user.organization_id,
|
||||||
|
embedding=[0.1],
|
||||||
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
metadata_={"type": "test"}
|
||||||
|
),
|
||||||
|
actor=default_user
|
||||||
)
|
)
|
||||||
for i in range(4)
|
passages.append(passage)
|
||||||
]
|
if USING_SQLITE:
|
||||||
server.passage_manager.create_many_passages(passages, actor=default_user)
|
time.sleep(CREATE_DELAY_SQLITE)
|
||||||
|
|
||||||
|
# Create source passages
|
||||||
|
for i in range(5):
|
||||||
|
passage = server.passage_manager.create_passage(
|
||||||
|
PydanticPassage(
|
||||||
|
text=f"Source passage {i}",
|
||||||
|
source_id=default_source.id,
|
||||||
|
file_id=default_file.id,
|
||||||
|
organization_id=default_user.organization_id,
|
||||||
|
embedding=[0.1],
|
||||||
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
metadata_={"type": "test"}
|
||||||
|
),
|
||||||
|
actor=default_user
|
||||||
|
)
|
||||||
|
passages.append(passage)
|
||||||
|
if USING_SQLITE:
|
||||||
|
time.sleep(CREATE_DELAY_SQLITE)
|
||||||
|
|
||||||
return passages
|
return passages
|
||||||
|
|
||||||
|
|
||||||
@@ -389,6 +433,49 @@ def server():
|
|||||||
return server
|
return server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||||
|
"""Setup fixture for agent passages tests"""
|
||||||
|
agent_id = sarah_agent.id
|
||||||
|
actor = default_user
|
||||||
|
|
||||||
|
server.agent_manager.attach_source(agent_id=agent_id, source_id=default_source.id, actor=actor)
|
||||||
|
|
||||||
|
# Create some source passages
|
||||||
|
source_passages = []
|
||||||
|
for i in range(3):
|
||||||
|
passage = server.passage_manager.create_passage(
|
||||||
|
PydanticPassage(
|
||||||
|
organization_id=actor.organization_id,
|
||||||
|
source_id=default_source.id,
|
||||||
|
text=f"Source passage {i}",
|
||||||
|
embedding=[0.1], # Default OpenAI embedding size
|
||||||
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
),
|
||||||
|
actor=actor
|
||||||
|
)
|
||||||
|
source_passages.append(passage)
|
||||||
|
|
||||||
|
# Create some agent passages
|
||||||
|
agent_passages = []
|
||||||
|
for i in range(2):
|
||||||
|
passage = server.passage_manager.create_passage(
|
||||||
|
PydanticPassage(
|
||||||
|
organization_id=actor.organization_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
text=f"Agent passage {i}",
|
||||||
|
embedding=[0.1], # Default OpenAI embedding size
|
||||||
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
),
|
||||||
|
actor=actor
|
||||||
|
)
|
||||||
|
agent_passages.append(passage)
|
||||||
|
|
||||||
|
yield agent_passages, source_passages
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
server.source_manager.delete_source(default_source.id, actor=actor)
|
||||||
|
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
# AgentManager Tests - Basic
|
# AgentManager Tests - Basic
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
@@ -849,6 +936,199 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
|
|||||||
assert block.label == default_block.label
|
assert block.label == default_block.label
|
||||||
|
|
||||||
|
|
||||||
|
# ======================================================================================================================
|
||||||
|
# Agent Manager - Passages Tests
|
||||||
|
# ======================================================================================================================
|
||||||
|
|
||||||
|
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
|
||||||
|
"""Test basic listing functionality of agent passages"""
|
||||||
|
|
||||||
|
all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
|
||||||
|
assert len(all_passages) == 5 # 3 source + 2 agent passages
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
|
||||||
|
"""Test ordering of agent passages"""
|
||||||
|
|
||||||
|
# Test ascending order
|
||||||
|
asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
|
||||||
|
assert len(asc_passages) == 5
|
||||||
|
for i in range(1, len(asc_passages)):
|
||||||
|
assert asc_passages[i-1].created_at <= asc_passages[i].created_at
|
||||||
|
|
||||||
|
# Test descending order
|
||||||
|
desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
|
||||||
|
assert len(desc_passages) == 5
|
||||||
|
for i in range(1, len(desc_passages)):
|
||||||
|
assert desc_passages[i-1].created_at >= desc_passages[i].created_at
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
|
||||||
|
"""Test pagination of agent passages"""
|
||||||
|
|
||||||
|
# Test limit
|
||||||
|
limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
|
||||||
|
assert len(limited_passages) == 3
|
||||||
|
|
||||||
|
# Test cursor-based pagination
|
||||||
|
first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
|
||||||
|
assert len(first_page) == 2
|
||||||
|
|
||||||
|
second_page = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
cursor=first_page[-1].id,
|
||||||
|
limit=2,
|
||||||
|
ascending=True
|
||||||
|
)
|
||||||
|
assert len(second_page) == 2
|
||||||
|
assert first_page[-1].id != second_page[0].id
|
||||||
|
assert first_page[-1].created_at <= second_page[0].created_at
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
|
||||||
|
"""Test text search functionality of agent passages"""
|
||||||
|
|
||||||
|
# Test text search for source passages
|
||||||
|
source_text_passages = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
query_text="Source passage"
|
||||||
|
)
|
||||||
|
assert len(source_text_passages) == 3
|
||||||
|
|
||||||
|
# Test text search for agent passages
|
||||||
|
agent_text_passages = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
query_text="Agent passage"
|
||||||
|
)
|
||||||
|
assert len(agent_text_passages) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
|
||||||
|
"""Test text search functionality of agent passages"""
|
||||||
|
|
||||||
|
# Test text search for agent passages
|
||||||
|
agent_text_passages = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
agent_only=True
|
||||||
|
)
|
||||||
|
assert len(agent_text_passages) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
|
||||||
|
"""Test filtering functionality of agent passages"""
|
||||||
|
|
||||||
|
# Test source filtering
|
||||||
|
source_filtered = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
source_id=default_source.id
|
||||||
|
)
|
||||||
|
assert len(source_filtered) == 3
|
||||||
|
|
||||||
|
# Test date filtering
|
||||||
|
now = datetime.utcnow()
|
||||||
|
future_date = now + timedelta(days=1)
|
||||||
|
past_date = now - timedelta(days=1)
|
||||||
|
|
||||||
|
date_filtered = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
start_date=past_date,
|
||||||
|
end_date=future_date
|
||||||
|
)
|
||||||
|
assert len(date_filtered) == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source):
|
||||||
|
"""Test vector search functionality of agent passages"""
|
||||||
|
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
||||||
|
|
||||||
|
# Create passages with known embeddings
|
||||||
|
passages = []
|
||||||
|
|
||||||
|
# Create passages with different embeddings
|
||||||
|
test_passages = [
|
||||||
|
"I like red",
|
||||||
|
"random text",
|
||||||
|
"blue shoes",
|
||||||
|
]
|
||||||
|
|
||||||
|
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||||
|
|
||||||
|
for i, text in enumerate(test_passages):
|
||||||
|
embedding = embed_model.get_text_embedding(text)
|
||||||
|
if i % 2 == 0:
|
||||||
|
passage = PydanticPassage(
|
||||||
|
text=text,
|
||||||
|
organization_id=default_user.organization_id,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
embedding=embedding
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
passage = PydanticPassage(
|
||||||
|
text=text,
|
||||||
|
organization_id=default_user.organization_id,
|
||||||
|
source_id=default_source.id,
|
||||||
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
embedding=embedding
|
||||||
|
)
|
||||||
|
created_passage = server.passage_manager.create_passage(passage, default_user)
|
||||||
|
passages.append(created_passage)
|
||||||
|
|
||||||
|
# Query vector similar to "red" embedding
|
||||||
|
query_key = "What's my favorite color?"
|
||||||
|
|
||||||
|
# Test vector search with all passages
|
||||||
|
results = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
query_text=query_key,
|
||||||
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
embed_query=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify results are ordered by similarity
|
||||||
|
assert len(results) == 3
|
||||||
|
assert results[0].text == "I like red"
|
||||||
|
assert "random" in results[1].text or "random" in results[2].text
|
||||||
|
assert "blue" in results[1].text or "blue" in results[2].text
|
||||||
|
|
||||||
|
# Test vector search with agent_only=True
|
||||||
|
agent_only_results = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
agent_id=sarah_agent.id,
|
||||||
|
query_text=query_key,
|
||||||
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
|
embed_query=True,
|
||||||
|
agent_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify agent-only results
|
||||||
|
assert len(agent_only_results) == 2
|
||||||
|
assert agent_only_results[0].text == "I like red"
|
||||||
|
assert agent_only_results[1].text == "blue shoes"
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
|
||||||
|
"""Test listing passages from a source without specifying an agent."""
|
||||||
|
|
||||||
|
# List passages by source_id without agent_id
|
||||||
|
source_passages = server.agent_manager.list_passages(
|
||||||
|
actor=default_user,
|
||||||
|
source_id=default_source.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify we get only source passages (3 from agent_passages_setup)
|
||||||
|
assert len(source_passages) == 3
|
||||||
|
assert all(p.source_id == default_source.id for p in source_passages)
|
||||||
|
assert all(p.agent_id is None for p in source_passages)
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
# Organization Manager Tests
|
# Organization Manager Tests
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
@@ -900,266 +1180,86 @@ def test_list_organizations_pagination(server: SyncServer):
|
|||||||
# Passage Manager Tests
|
# Passage Manager Tests
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
|
|
||||||
|
def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
|
||||||
def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user):
|
"""Test creating a passage using agent_passage_fixture fixture"""
|
||||||
"""Test creating a passage using hello_world_passage_fixture fixture"""
|
assert agent_passage_fixture.id is not None
|
||||||
assert hello_world_passage_fixture.id is not None
|
assert agent_passage_fixture.text == "Hello, I am an agent passage"
|
||||||
assert hello_world_passage_fixture.text == "Hello, world!"
|
|
||||||
|
|
||||||
# Verify we can retrieve it
|
# Verify we can retrieve it
|
||||||
retrieved = server.passage_manager.get_passage_by_id(
|
retrieved = server.passage_manager.get_passage_by_id(
|
||||||
hello_world_passage_fixture.id,
|
agent_passage_fixture.id,
|
||||||
actor=default_user,
|
actor=default_user,
|
||||||
)
|
)
|
||||||
assert retrieved is not None
|
assert retrieved is not None
|
||||||
assert retrieved.id == hello_world_passage_fixture.id
|
assert retrieved.id == agent_passage_fixture.id
|
||||||
assert retrieved.text == hello_world_passage_fixture.text
|
assert retrieved.text == agent_passage_fixture.text
|
||||||
|
|
||||||
|
|
||||||
def test_passage_get_by_id(server: SyncServer, hello_world_passage_fixture, default_user):
|
def test_passage_create_source(server: SyncServer, source_passage_fixture, default_user):
|
||||||
"""Test retrieving a passage by ID"""
|
"""Test creating a source passage."""
|
||||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
assert source_passage_fixture is not None
|
||||||
|
assert source_passage_fixture.text == "Hello, I am a source passage"
|
||||||
|
|
||||||
|
# Verify we can retrieve it
|
||||||
|
retrieved = server.passage_manager.get_passage_by_id(
|
||||||
|
source_passage_fixture.id,
|
||||||
|
actor=default_user,
|
||||||
|
)
|
||||||
assert retrieved is not None
|
assert retrieved is not None
|
||||||
assert retrieved.id == hello_world_passage_fixture.id
|
assert retrieved.id == source_passage_fixture.id
|
||||||
assert retrieved.text == hello_world_passage_fixture.text
|
assert retrieved.text == source_passage_fixture.text
|
||||||
|
|
||||||
|
|
||||||
def test_passage_update(server: SyncServer, hello_world_passage_fixture, default_user):
|
def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
|
||||||
"""Test updating a passage"""
|
"""Test creating an agent passage."""
|
||||||
new_text = "Updated text"
|
assert agent_passage_fixture is not None
|
||||||
hello_world_passage_fixture.text = new_text
|
assert agent_passage_fixture.text == "Hello, I am an agent passage"
|
||||||
updated = server.passage_manager.update_passage_by_id(hello_world_passage_fixture.id, hello_world_passage_fixture, actor=default_user)
|
|
||||||
assert updated is not None
|
# Try to create an invalid passage (with both agent_id and source_id)
|
||||||
assert updated.text == new_text
|
with pytest.raises(AssertionError):
|
||||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
server.passage_manager.create_passage(
|
||||||
assert retrieved.text == new_text
|
|
||||||
|
|
||||||
|
|
||||||
def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user):
|
|
||||||
"""Test deleting a passage"""
|
|
||||||
server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
|
||||||
with pytest.raises(NoResultFound):
|
|
||||||
server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
|
||||||
|
|
||||||
|
|
||||||
def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
|
||||||
"""Test counting passages with filters"""
|
|
||||||
base_passage = hello_world_passage_fixture
|
|
||||||
|
|
||||||
# Test total count
|
|
||||||
total = server.passage_manager.size(actor=default_user)
|
|
||||||
assert total == 5 # base passage + 4 test passages
|
|
||||||
# TODO: change login passage to be a system not user passage
|
|
||||||
|
|
||||||
# Test count with agent filter
|
|
||||||
agent_count = server.passage_manager.size(actor=default_user, agent_id=base_passage.agent_id)
|
|
||||||
assert agent_count == 5
|
|
||||||
|
|
||||||
# Test count with role filter
|
|
||||||
role_count = server.passage_manager.size(actor=default_user)
|
|
||||||
assert role_count == 5
|
|
||||||
|
|
||||||
# Test count with non-existent filter
|
|
||||||
empty_count = server.passage_manager.size(actor=default_user, agent_id="non-existent")
|
|
||||||
assert empty_count == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_passage_listing_basic(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
|
||||||
"""Test basic passage listing with limit"""
|
|
||||||
results = server.passage_manager.list_passages(actor=default_user, limit=3)
|
|
||||||
assert len(results) == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
|
||||||
"""Test cursor-based pagination functionality"""
|
|
||||||
|
|
||||||
# Make sure there are 5 passages
|
|
||||||
assert server.passage_manager.size(actor=default_user) == 5
|
|
||||||
|
|
||||||
# Get first page
|
|
||||||
first_page = server.passage_manager.list_passages(actor=default_user, limit=3)
|
|
||||||
assert len(first_page) == 3
|
|
||||||
|
|
||||||
last_id_on_first_page = first_page[-1].id
|
|
||||||
|
|
||||||
# Get second page
|
|
||||||
second_page = server.passage_manager.list_passages(actor=default_user, cursor=last_id_on_first_page, limit=3)
|
|
||||||
assert len(second_page) == 2 # Should have 2 remaining passages
|
|
||||||
assert all(r1.id != r2.id for r1 in first_page for r2 in second_page)
|
|
||||||
|
|
||||||
|
|
||||||
def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
|
|
||||||
"""Test filtering passages by agent ID"""
|
|
||||||
agent_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, limit=10)
|
|
||||||
assert len(agent_results) == 5 # base passage + 4 test passages
|
|
||||||
assert all(msg.agent_id == hello_world_passage_fixture.agent_id for msg in agent_results)
|
|
||||||
|
|
||||||
|
|
||||||
def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
|
|
||||||
"""Test searching passages by text content"""
|
|
||||||
search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10)
|
|
||||||
assert len(search_results) == 4
|
|
||||||
assert all("Test passage" in msg.text for msg in search_results)
|
|
||||||
|
|
||||||
# Test no results
|
|
||||||
search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10)
|
|
||||||
assert len(search_results) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_passage_fixture, default_user, default_file, sarah_agent):
|
|
||||||
"""Test filtering passages by date range with various scenarios"""
|
|
||||||
# Set up test data with known dates
|
|
||||||
base_time = datetime.utcnow()
|
|
||||||
|
|
||||||
# Create passages at different times
|
|
||||||
passages = []
|
|
||||||
time_offsets = [
|
|
||||||
timedelta(days=-2), # 2 days ago
|
|
||||||
timedelta(days=-1), # Yesterday
|
|
||||||
timedelta(hours=-2), # 2 hours ago
|
|
||||||
timedelta(minutes=-30), # 30 minutes ago
|
|
||||||
timedelta(minutes=-1), # 1 minute ago
|
|
||||||
timedelta(minutes=0), # Now
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, offset in enumerate(time_offsets):
|
|
||||||
timestamp = base_time + offset
|
|
||||||
passage = server.passage_manager.create_passage(
|
|
||||||
PydanticPassage(
|
PydanticPassage(
|
||||||
|
text="Invalid passage",
|
||||||
|
agent_id="123",
|
||||||
|
source_id="456",
|
||||||
organization_id=default_user.organization_id,
|
organization_id=default_user.organization_id,
|
||||||
agent_id=sarah_agent.id,
|
embedding=[0.1] * 1024,
|
||||||
file_id=default_file.id,
|
|
||||||
text=f"Test passage {i}",
|
|
||||||
embedding=[0.1, 0.2, 0.3],
|
|
||||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||||
created_at=timestamp,
|
|
||||||
),
|
),
|
||||||
actor=default_user,
|
actor=default_user
|
||||||
)
|
|
||||||
passages.append(passage)
|
|
||||||
|
|
||||||
# Test cases
|
|
||||||
test_cases = [
|
|
||||||
{
|
|
||||||
"name": "Recent passages (last hour)",
|
|
||||||
"start_date": base_time - timedelta(hours=1),
|
|
||||||
"end_date": base_time + timedelta(minutes=1),
|
|
||||||
"expected_count": 1 + 3, # Should include base + -30min, -1min, and now
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Yesterday's passages",
|
|
||||||
"start_date": base_time - timedelta(days=1, hours=12),
|
|
||||||
"end_date": base_time - timedelta(hours=12),
|
|
||||||
"expected_count": 1, # Should only include yesterday's passage
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Future time range",
|
|
||||||
"start_date": base_time + timedelta(days=1),
|
|
||||||
"end_date": base_time + timedelta(days=2),
|
|
||||||
"expected_count": 0, # Should find no passages
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "All time",
|
|
||||||
"start_date": base_time - timedelta(days=3),
|
|
||||||
"end_date": base_time + timedelta(days=1),
|
|
||||||
"expected_count": 1 + len(passages), # Should find all passages
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Exact timestamp match",
|
|
||||||
"start_date": passages[0].created_at - timedelta(microseconds=1),
|
|
||||||
"end_date": passages[0].created_at + timedelta(microseconds=1),
|
|
||||||
"expected_count": 1, # Should find exactly one passage
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Small time window",
|
|
||||||
"start_date": base_time - timedelta(seconds=30),
|
|
||||||
"end_date": base_time + timedelta(seconds=30),
|
|
||||||
"expected_count": 1 + 1, # date + "now"
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Run test cases
|
|
||||||
for case in test_cases:
|
|
||||||
results = server.passage_manager.list_passages(
|
|
||||||
agent_id=sarah_agent.id, actor=default_user, start_date=case["start_date"], end_date=case["end_date"], limit=10
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify count
|
|
||||||
assert (
|
|
||||||
len(results) == case["expected_count"]
|
|
||||||
), f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}"
|
|
||||||
|
|
||||||
# Test edge cases
|
def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user):
|
||||||
|
"""Test retrieving a passage by ID"""
|
||||||
|
retrieved = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, actor=default_user)
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.id == agent_passage_fixture.id
|
||||||
|
assert retrieved.text == agent_passage_fixture.text
|
||||||
|
|
||||||
# Test with start_date but no end_date
|
retrieved = server.passage_manager.get_passage_by_id(source_passage_fixture.id, actor=default_user)
|
||||||
results_start_only = server.passage_manager.list_passages(
|
assert retrieved is not None
|
||||||
agent_id=sarah_agent.id, actor=default_user, start_date=base_time - timedelta(minutes=2), end_date=None, limit=10
|
assert retrieved.id == source_passage_fixture.id
|
||||||
)
|
assert retrieved.text == source_passage_fixture.text
|
||||||
assert len(results_start_only) >= 2, "Should find passages after start_date"
|
|
||||||
|
|
||||||
# Test with end_date but no start_date
|
|
||||||
results_end_only = server.passage_manager.list_passages(
|
|
||||||
agent_id=sarah_agent.id, actor=default_user, start_date=None, end_date=base_time - timedelta(days=1), limit=10
|
|
||||||
)
|
|
||||||
assert len(results_end_only) >= 1, "Should find passages before end_date"
|
|
||||||
|
|
||||||
# Test limit enforcement
|
|
||||||
limited_results = server.passage_manager.list_passages(
|
|
||||||
agent_id=sarah_agent.id,
|
|
||||||
actor=default_user,
|
|
||||||
start_date=base_time - timedelta(days=3),
|
|
||||||
end_date=base_time + timedelta(days=1),
|
|
||||||
limit=3,
|
|
||||||
)
|
|
||||||
assert len(limited_results) <= 3, "Should respect the limit parameter"
|
|
||||||
|
|
||||||
|
|
||||||
def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent):
|
def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
|
||||||
"""Test vector search functionality for passages."""
|
"""Test that passages are deleted when their parent (agent or source) is deleted."""
|
||||||
passage_manager = server.passage_manager
|
# Verify passages exist
|
||||||
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user)
|
||||||
|
source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
|
||||||
# Create passages with known embeddings
|
assert agent_passage is not None
|
||||||
passages = []
|
assert source_passage is not None
|
||||||
|
|
||||||
# Create passages with different embeddings
|
# Delete agent and verify its passages are deleted
|
||||||
test_passages = [
|
server.agent_manager.delete_agent(sarah_agent.id, default_user)
|
||||||
"I like red",
|
agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
|
||||||
"random text",
|
assert len(agentic_passages) == 0
|
||||||
"blue shoes",
|
|
||||||
]
|
# Delete source and verify its passages are deleted
|
||||||
|
server.source_manager.delete_source(default_source.id, default_user)
|
||||||
for text in test_passages:
|
with pytest.raises(NoResultFound):
|
||||||
embedding = embed_model.get_text_embedding(text)
|
server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
|
||||||
passage = PydanticPassage(
|
|
||||||
text=text,
|
|
||||||
organization_id=default_user.organization_id,
|
|
||||||
agent_id=sarah_agent.id,
|
|
||||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
|
||||||
embedding=embedding,
|
|
||||||
)
|
|
||||||
created_passage = passage_manager.create_passage(passage, default_user)
|
|
||||||
passages.append(created_passage)
|
|
||||||
assert passage_manager.size(actor=default_user) == len(passages)
|
|
||||||
|
|
||||||
# Query vector similar to "cats" embedding
|
|
||||||
query_key = "What's my favorite color?"
|
|
||||||
|
|
||||||
# List passages with vector search
|
|
||||||
results = passage_manager.list_passages(
|
|
||||||
actor=default_user,
|
|
||||||
agent_id=sarah_agent.id,
|
|
||||||
query_text=query_key,
|
|
||||||
limit=3,
|
|
||||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
|
||||||
embed_query=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify results are ordered by similarity
|
|
||||||
assert len(results) == 3
|
|
||||||
assert results[0].text == "I like red"
|
|
||||||
assert results[1].text == "random text" # For some reason the embedding model doesn't like "blue shoes"
|
|
||||||
assert results[2].text == "blue shoes"
|
|
||||||
|
|
||||||
|
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
@@ -1220,6 +1320,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
|
|||||||
assert print_tool.organization_id == default_organization.id
|
assert print_tool.organization_id == default_organization.id
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
|
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
|
||||||
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
|
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
|
||||||
data = print_tool.model_dump(exclude=["id"])
|
data = print_tool.model_dump(exclude=["id"])
|
||||||
@@ -1787,6 +1888,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
|
|||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
# Source Manager Tests - Files
|
# Source Manager Tests - Files
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
|
|
||||||
def test_get_file_by_id(server: SyncServer, default_user, default_source):
|
def test_get_file_by_id(server: SyncServer, default_user, default_source):
|
||||||
"""Test retrieving a file by ID."""
|
"""Test retrieving a file by ID."""
|
||||||
file_metadata = PydanticFileMetadata(
|
file_metadata = PydanticFileMetadata(
|
||||||
@@ -1857,6 +1959,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
|
|||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
# SandboxConfigManager Tests - Sandbox Configs
|
# SandboxConfigManager Tests - Sandbox Configs
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
|
|
||||||
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
|
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
|
||||||
sandbox_config_create = SandboxConfigCreate(
|
sandbox_config_create = SandboxConfigCreate(
|
||||||
config=E2BSandboxConfig(),
|
config=E2BSandboxConfig(),
|
||||||
@@ -1935,6 +2038,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
|
|||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
# SandboxConfigManager Tests - Environment Variables
|
# SandboxConfigManager Tests - Environment Variables
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
|
|
||||||
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
|
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
|
||||||
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
|
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
|
||||||
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
|
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
|
||||||
@@ -2007,7 +2111,6 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
|
|||||||
# JobManager Tests
|
# JobManager Tests
|
||||||
# ======================================================================================================================
|
# ======================================================================================================================
|
||||||
|
|
||||||
|
|
||||||
def test_create_job(server: SyncServer, default_user):
|
def test_create_job(server: SyncServer, default_user):
|
||||||
"""Test creating a job."""
|
"""Test creating a job."""
|
||||||
job_data = PydanticJob(
|
job_data = PydanticJob(
|
||||||
|
|||||||
@@ -390,12 +390,16 @@ def test_user_message_memory(server, user_id, agent_id):
|
|||||||
|
|
||||||
@pytest.mark.order(3)
|
@pytest.mark.order(3)
|
||||||
def test_load_data(server, user_id, agent_id):
|
def test_load_data(server, user_id, agent_id):
|
||||||
|
user = server.user_manager.get_user_or_default(user_id=user_id)
|
||||||
|
|
||||||
# create source
|
# create source
|
||||||
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
|
passages_before = server.agent_manager.list_passages(
|
||||||
|
actor=user, agent_id=agent_id, cursor=None, limit=10000
|
||||||
|
)
|
||||||
assert len(passages_before) == 0
|
assert len(passages_before) == 0
|
||||||
|
|
||||||
source = server.source_manager.create_source(
|
source = server.source_manager.create_source(
|
||||||
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=server.default_user
|
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# load data
|
# load data
|
||||||
@@ -409,15 +413,11 @@ def test_load_data(server, user_id, agent_id):
|
|||||||
connector = DummyDataConnector(archival_memories)
|
connector = DummyDataConnector(archival_memories)
|
||||||
server.load_data(user_id, connector, source.name)
|
server.load_data(user_id, connector, source.name)
|
||||||
|
|
||||||
# @pytest.mark.order(3)
|
|
||||||
# def test_attach_source_to_agent(server, user_id, agent_id):
|
|
||||||
# check archival memory size
|
|
||||||
|
|
||||||
# attach source
|
# attach source
|
||||||
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
|
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
|
||||||
|
|
||||||
# check archival memory size
|
# check archival memory size
|
||||||
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
|
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
|
||||||
assert len(passages_after) == 5
|
assert len(passages_after) == 5
|
||||||
|
|
||||||
|
|
||||||
@@ -465,7 +465,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
|||||||
user = server.user_manager.get_user_by_id(user_id=user_id)
|
user = server.user_manager.get_user_by_id(user_id=user_id)
|
||||||
|
|
||||||
# List latest 2 passages
|
# List latest 2 passages
|
||||||
passages_1 = server.passage_manager.list_passages(
|
passages_1 = server.agent_manager.list_passages(
|
||||||
actor=user,
|
actor=user,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
ascending=False,
|
ascending=False,
|
||||||
@@ -475,7 +475,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
|||||||
|
|
||||||
# List next 3 passages (earliest 3)
|
# List next 3 passages (earliest 3)
|
||||||
cursor1 = passages_1[-1].id
|
cursor1 = passages_1[-1].id
|
||||||
passages_2 = server.passage_manager.list_passages(
|
passages_2 = server.agent_manager.list_passages(
|
||||||
actor=user,
|
actor=user,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
ascending=False,
|
ascending=False,
|
||||||
@@ -484,24 +484,28 @@ def test_get_archival_memory(server, user_id, agent_id):
|
|||||||
|
|
||||||
# List all 5
|
# List all 5
|
||||||
cursor2 = passages_1[0].created_at
|
cursor2 = passages_1[0].created_at
|
||||||
passages_3 = server.passage_manager.list_passages(
|
passages_3 = server.agent_manager.list_passages(
|
||||||
actor=user,
|
actor=user,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
ascending=False,
|
ascending=False,
|
||||||
end_date=cursor2,
|
end_date=cursor2,
|
||||||
limit=1000,
|
limit=1000,
|
||||||
)
|
)
|
||||||
# assert passages_1[0].text == "Cinderella wore a blue dress"
|
|
||||||
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
|
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
|
||||||
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||||
|
|
||||||
|
latest = passages_1[0]
|
||||||
|
earliest = passages_2[-1]
|
||||||
|
|
||||||
# test archival memory
|
# test archival memory
|
||||||
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1)
|
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
|
||||||
assert len(passage_1) == 1
|
assert len(passage_1) == 1
|
||||||
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000)
|
assert passage_1[0].text == "alpha"
|
||||||
|
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
|
||||||
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||||
|
assert all("alpha" not in passage.text for passage in passage_2)
|
||||||
# test safe empty return
|
# test safe empty return
|
||||||
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000)
|
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
|
||||||
assert len(passage_none) == 0
|
assert len(passage_none) == 0
|
||||||
|
|
||||||
|
|
||||||
@@ -955,6 +959,14 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
|
|||||||
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
|
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
|
||||||
actor = server.user_manager.get_user_or_default(user_id)
|
actor = server.user_manager.get_user_or_default(user_id)
|
||||||
|
|
||||||
|
existing_sources = server.source_manager.list_sources(actor=actor)
|
||||||
|
if len(existing_sources) > 0:
|
||||||
|
for source in existing_sources:
|
||||||
|
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
|
||||||
|
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||||
|
assert initial_passage_count == 0
|
||||||
|
|
||||||
|
|
||||||
# Create a source
|
# Create a source
|
||||||
source = server.source_manager.create_source(
|
source = server.source_manager.create_source(
|
||||||
PydanticSource(
|
PydanticSource(
|
||||||
@@ -973,10 +985,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
|||||||
# Attach source to agent first
|
# Attach source to agent first
|
||||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
|
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
|
||||||
|
|
||||||
# Get initial passage count
|
|
||||||
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
|
||||||
assert initial_passage_count == 0
|
|
||||||
|
|
||||||
# Create a job for loading the first file
|
# Create a job for loading the first file
|
||||||
job = server.job_manager.create_job(
|
job = server.job_manager.create_job(
|
||||||
PydanticJob(
|
PydanticJob(
|
||||||
@@ -1001,7 +1009,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
|||||||
assert job.metadata_["num_documents"] == 1
|
assert job.metadata_["num_documents"] == 1
|
||||||
|
|
||||||
# Verify passages were added
|
# Verify passages were added
|
||||||
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||||
assert first_file_passage_count > initial_passage_count
|
assert first_file_passage_count > initial_passage_count
|
||||||
|
|
||||||
# Create a second test file with different content
|
# Create a second test file with different content
|
||||||
@@ -1032,14 +1040,13 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
|||||||
assert job2.metadata_["num_documents"] == 1
|
assert job2.metadata_["num_documents"] == 1
|
||||||
|
|
||||||
# Verify passages were appended (not replaced)
|
# Verify passages were appended (not replaced)
|
||||||
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||||
assert final_passage_count > first_file_passage_count
|
assert final_passage_count > first_file_passage_count
|
||||||
|
|
||||||
# Verify both old and new content is searchable
|
# Verify both old and new content is searchable
|
||||||
passages = server.passage_manager.list_passages(
|
passages = server.agent_manager.list_passages(
|
||||||
actor=actor,
|
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
source_id=source.id,
|
actor=actor,
|
||||||
query_text="what does Timber like to eat",
|
query_text="what does Timber like to eat",
|
||||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||||
embed_query=True,
|
embed_query=True,
|
||||||
@@ -1048,35 +1055,27 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
|||||||
assert any("chicken" in passage.text.lower() for passage in passages)
|
assert any("chicken" in passage.text.lower() for passage in passages)
|
||||||
assert any("Anna".lower() in passage.text.lower() for passage in passages)
|
assert any("Anna".lower() in passage.text.lower() for passage in passages)
|
||||||
|
|
||||||
# TODO: Add this test back in after separation of `Passage tables` (LET-449)
|
# Initially should have no passages
|
||||||
# # Load second agent
|
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
|
||||||
# agent2 = server.load_agent(agent_id=other_agent_id)
|
assert initial_agent2_passages == 0
|
||||||
|
|
||||||
# # Initially should have no passages
|
# Attach source to second agent
|
||||||
# initial_agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
|
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
|
||||||
# assert initial_agent2_passages == 0
|
|
||||||
|
|
||||||
# # Attach source to second agent
|
# Verify second agent has same number of passages as first agent
|
||||||
# agent2.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
|
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
|
||||||
|
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
|
||||||
|
assert agent2_passages == agent1_passages
|
||||||
|
|
||||||
# # Verify second agent has same number of passages as first agent
|
# Verify second agent can query the same content
|
||||||
# agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
|
passages2 = server.agent_manager.list_passages(
|
||||||
# agent1_passages = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
|
actor=actor,
|
||||||
# assert agent2_passages == agent1_passages
|
agent_id=other_agent_id,
|
||||||
|
source_id=source.id,
|
||||||
# # Verify second agent can query the same content
|
query_text="what does Timber like to eat",
|
||||||
# passages2 = server.passage_manager.list_passages(
|
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||||
# actor=user,
|
embed_query=True,
|
||||||
# agent_id=other_agent_id,
|
)
|
||||||
# source_id=source.id,
|
assert len(passages2) == len(passages)
|
||||||
# query_text="what does Timber like to eat",
|
assert any("chicken" in passage.text.lower() for passage in passages2)
|
||||||
# embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
|
||||||
# embed_query=True,
|
|
||||||
# limit=10,
|
|
||||||
# )
|
|
||||||
# assert len(passages2) == len(passages)
|
|
||||||
# assert any("chicken" in passage.text.lower() for passage in passages2)
|
|
||||||
# assert any("sleep" in passage.text.lower() for passage in passages2)
|
|
||||||
|
|
||||||
# # Cleanup
|
|
||||||
# server.delete_agent(user_id=user_id, agent_id=agent2_state.id)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user