feat: separate Passages tables (#2245)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -482,7 +482,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
# check agent archival memory size
|
||||
archival_memories = client.get_archival_memory(agent_id=agent.id)
|
||||
print(archival_memories)
|
||||
assert len(archival_memories) == 0
|
||||
|
||||
# load a file into a source (non-blocking job)
|
||||
|
||||
@@ -2,6 +2,8 @@ import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from httpx._transports import default
|
||||
from numpy import source
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -17,7 +19,8 @@ from letta.orm import (
|
||||
Job,
|
||||
Message,
|
||||
Organization,
|
||||
Passage,
|
||||
AgentPassage,
|
||||
SourcePassage,
|
||||
SandboxConfig,
|
||||
SandboxEnvironmentVariable,
|
||||
Source,
|
||||
@@ -82,7 +85,8 @@ def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Message))
|
||||
session.execute(delete(Passage))
|
||||
session.execute(delete(AgentPassage))
|
||||
session.execute(delete(SourcePassage))
|
||||
session.execute(delete(Job))
|
||||
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||
session.execute(delete(BlocksAgents))
|
||||
@@ -189,39 +193,79 @@ def print_tool(server: SyncServer, default_user, default_organization):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hello_world_passage_fixture(server: SyncServer, default_user, default_file, sarah_agent):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
# Set up passage
|
||||
dummy_embedding = [0.0] * 2
|
||||
message = PydanticPassage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
text="Hello, world!",
|
||||
embedding=dummy_embedding,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
"""Fixture to create an agent passage."""
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text="Hello, I am an agent passage",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
|
||||
msg = server.passage_manager.create_passage(message, actor=default_user)
|
||||
yield msg
|
||||
yield passage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent) -> list[PydanticPassage]:
|
||||
"""Helper function to create test passages for all tests"""
|
||||
dummy_embedding = [0] * 2
|
||||
passages = [
|
||||
def source_passage_fixture(server: SyncServer, default_user, default_file, default_source):
|
||||
"""Fixture to create a source passage."""
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
text="Hello, I am a source passage",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
text=f"Test passage {i}",
|
||||
embedding=dummy_embedding,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
yield passage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent, default_source):
|
||||
"""Helper function to create test passages for all tests."""
|
||||
# Create agent passages
|
||||
passages = []
|
||||
for i in range(5):
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text=f"Agent passage {i}",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
server.passage_manager.create_many_passages(passages, actor=default_user)
|
||||
passages.append(passage)
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
|
||||
# Create source passages
|
||||
for i in range(5):
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text=f"Source passage {i}",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
passages.append(passage)
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
|
||||
return passages
|
||||
|
||||
|
||||
@@ -389,6 +433,49 @@ def server():
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
"""Setup fixture for agent passages tests"""
|
||||
agent_id = sarah_agent.id
|
||||
actor = default_user
|
||||
|
||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=default_source.id, actor=actor)
|
||||
|
||||
# Create some source passages
|
||||
source_passages = []
|
||||
for i in range(3):
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
source_id=default_source.id,
|
||||
text=f"Source passage {i}",
|
||||
embedding=[0.1], # Default OpenAI embedding size
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=actor
|
||||
)
|
||||
source_passages.append(passage)
|
||||
|
||||
# Create some agent passages
|
||||
agent_passages = []
|
||||
for i in range(2):
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
text=f"Agent passage {i}",
|
||||
embedding=[0.1], # Default OpenAI embedding size
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=actor
|
||||
)
|
||||
agent_passages.append(passage)
|
||||
|
||||
yield agent_passages, source_passages
|
||||
|
||||
# Cleanup
|
||||
server.source_manager.delete_source(default_source.id, actor=actor)
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -849,6 +936,199 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
|
||||
assert block.label == default_block.label
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Agent Manager - Passages Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test basic listing functionality of agent passages"""
|
||||
|
||||
all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
|
||||
assert len(all_passages) == 5 # 3 source + 2 agent passages
|
||||
|
||||
|
||||
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test ordering of agent passages"""
|
||||
|
||||
# Test ascending order
|
||||
asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
|
||||
assert len(asc_passages) == 5
|
||||
for i in range(1, len(asc_passages)):
|
||||
assert asc_passages[i-1].created_at <= asc_passages[i].created_at
|
||||
|
||||
# Test descending order
|
||||
desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
|
||||
assert len(desc_passages) == 5
|
||||
for i in range(1, len(desc_passages)):
|
||||
assert desc_passages[i-1].created_at >= desc_passages[i].created_at
|
||||
|
||||
|
||||
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test pagination of agent passages"""
|
||||
|
||||
# Test limit
|
||||
limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
|
||||
assert len(limited_passages) == 3
|
||||
|
||||
# Test cursor-based pagination
|
||||
first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
|
||||
assert len(first_page) == 2
|
||||
|
||||
second_page = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
cursor=first_page[-1].id,
|
||||
limit=2,
|
||||
ascending=True
|
||||
)
|
||||
assert len(second_page) == 2
|
||||
assert first_page[-1].id != second_page[0].id
|
||||
assert first_page[-1].created_at <= second_page[0].created_at
|
||||
|
||||
|
||||
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test text search functionality of agent passages"""
|
||||
|
||||
# Test text search for source passages
|
||||
source_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="Source passage"
|
||||
)
|
||||
assert len(source_text_passages) == 3
|
||||
|
||||
# Test text search for agent passages
|
||||
agent_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="Agent passage"
|
||||
)
|
||||
assert len(agent_text_passages) == 2
|
||||
|
||||
|
||||
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test text search functionality of agent passages"""
|
||||
|
||||
# Test text search for agent passages
|
||||
agent_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
agent_only=True
|
||||
)
|
||||
assert len(agent_text_passages) == 2
|
||||
|
||||
|
||||
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
|
||||
"""Test filtering functionality of agent passages"""
|
||||
|
||||
# Test source filtering
|
||||
source_filtered = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
source_id=default_source.id
|
||||
)
|
||||
assert len(source_filtered) == 3
|
||||
|
||||
# Test date filtering
|
||||
now = datetime.utcnow()
|
||||
future_date = now + timedelta(days=1)
|
||||
past_date = now - timedelta(days=1)
|
||||
|
||||
date_filtered = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
start_date=past_date,
|
||||
end_date=future_date
|
||||
)
|
||||
assert len(date_filtered) == 5
|
||||
|
||||
|
||||
def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source):
|
||||
"""Test vector search functionality of agent passages"""
|
||||
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
||||
|
||||
# Create passages with known embeddings
|
||||
passages = []
|
||||
|
||||
# Create passages with different embeddings
|
||||
test_passages = [
|
||||
"I like red",
|
||||
"random text",
|
||||
"blue shoes",
|
||||
]
|
||||
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
|
||||
for i, text in enumerate(test_passages):
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
if i % 2 == 0:
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding
|
||||
)
|
||||
else:
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
source_id=default_source.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding
|
||||
)
|
||||
created_passage = server.passage_manager.create_passage(passage, default_user)
|
||||
passages.append(created_passage)
|
||||
|
||||
# Query vector similar to "red" embedding
|
||||
query_key = "What's my favorite color?"
|
||||
|
||||
# Test vector search with all passages
|
||||
results = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text=query_key,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
)
|
||||
|
||||
# Verify results are ordered by similarity
|
||||
assert len(results) == 3
|
||||
assert results[0].text == "I like red"
|
||||
assert "random" in results[1].text or "random" in results[2].text
|
||||
assert "blue" in results[1].text or "blue" in results[2].text
|
||||
|
||||
# Test vector search with agent_only=True
|
||||
agent_only_results = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text=query_key,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
agent_only=True
|
||||
)
|
||||
|
||||
# Verify agent-only results
|
||||
assert len(agent_only_results) == 2
|
||||
assert agent_only_results[0].text == "I like red"
|
||||
assert agent_only_results[1].text == "blue shoes"
|
||||
|
||||
|
||||
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
|
||||
"""Test listing passages from a source without specifying an agent."""
|
||||
|
||||
# List passages by source_id without agent_id
|
||||
source_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
source_id=default_source.id,
|
||||
)
|
||||
|
||||
# Verify we get only source passages (3 from agent_passages_setup)
|
||||
assert len(source_passages) == 3
|
||||
assert all(p.source_id == default_source.id for p in source_passages)
|
||||
assert all(p.agent_id is None for p in source_passages)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Organization Manager Tests
|
||||
# ======================================================================================================================
|
||||
@@ -900,266 +1180,86 @@ def test_list_organizations_pagination(server: SyncServer):
|
||||
# Passage Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test creating a passage using hello_world_passage_fixture fixture"""
|
||||
assert hello_world_passage_fixture.id is not None
|
||||
assert hello_world_passage_fixture.text == "Hello, world!"
|
||||
def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
|
||||
"""Test creating a passage using agent_passage_fixture fixture"""
|
||||
assert agent_passage_fixture.id is not None
|
||||
assert agent_passage_fixture.text == "Hello, I am an agent passage"
|
||||
|
||||
# Verify we can retrieve it
|
||||
retrieved = server.passage_manager.get_passage_by_id(
|
||||
hello_world_passage_fixture.id,
|
||||
agent_passage_fixture.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == hello_world_passage_fixture.id
|
||||
assert retrieved.text == hello_world_passage_fixture.text
|
||||
assert retrieved.id == agent_passage_fixture.id
|
||||
assert retrieved.text == agent_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_get_by_id(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test retrieving a passage by ID"""
|
||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
def test_passage_create_source(server: SyncServer, source_passage_fixture, default_user):
|
||||
"""Test creating a source passage."""
|
||||
assert source_passage_fixture is not None
|
||||
assert source_passage_fixture.text == "Hello, I am a source passage"
|
||||
|
||||
# Verify we can retrieve it
|
||||
retrieved = server.passage_manager.get_passage_by_id(
|
||||
source_passage_fixture.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == hello_world_passage_fixture.id
|
||||
assert retrieved.text == hello_world_passage_fixture.text
|
||||
assert retrieved.id == source_passage_fixture.id
|
||||
assert retrieved.text == source_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_update(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test updating a passage"""
|
||||
new_text = "Updated text"
|
||||
hello_world_passage_fixture.text = new_text
|
||||
updated = server.passage_manager.update_passage_by_id(hello_world_passage_fixture.id, hello_world_passage_fixture, actor=default_user)
|
||||
assert updated is not None
|
||||
assert updated.text == new_text
|
||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
assert retrieved.text == new_text
|
||||
|
||||
|
||||
def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test deleting a passage"""
|
||||
server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
with pytest.raises(NoResultFound):
|
||||
server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
|
||||
|
||||
def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test counting passages with filters"""
|
||||
base_passage = hello_world_passage_fixture
|
||||
|
||||
# Test total count
|
||||
total = server.passage_manager.size(actor=default_user)
|
||||
assert total == 5 # base passage + 4 test passages
|
||||
# TODO: change login passage to be a system not user passage
|
||||
|
||||
# Test count with agent filter
|
||||
agent_count = server.passage_manager.size(actor=default_user, agent_id=base_passage.agent_id)
|
||||
assert agent_count == 5
|
||||
|
||||
# Test count with role filter
|
||||
role_count = server.passage_manager.size(actor=default_user)
|
||||
assert role_count == 5
|
||||
|
||||
# Test count with non-existent filter
|
||||
empty_count = server.passage_manager.size(actor=default_user, agent_id="non-existent")
|
||||
assert empty_count == 0
|
||||
|
||||
|
||||
def test_passage_listing_basic(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test basic passage listing with limit"""
|
||||
results = server.passage_manager.list_passages(actor=default_user, limit=3)
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test cursor-based pagination functionality"""
|
||||
|
||||
# Make sure there are 5 passages
|
||||
assert server.passage_manager.size(actor=default_user) == 5
|
||||
|
||||
# Get first page
|
||||
first_page = server.passage_manager.list_passages(actor=default_user, limit=3)
|
||||
assert len(first_page) == 3
|
||||
|
||||
last_id_on_first_page = first_page[-1].id
|
||||
|
||||
# Get second page
|
||||
second_page = server.passage_manager.list_passages(actor=default_user, cursor=last_id_on_first_page, limit=3)
|
||||
assert len(second_page) == 2 # Should have 2 remaining passages
|
||||
assert all(r1.id != r2.id for r1 in first_page for r2 in second_page)
|
||||
|
||||
|
||||
def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
|
||||
"""Test filtering passages by agent ID"""
|
||||
agent_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, limit=10)
|
||||
assert len(agent_results) == 5 # base passage + 4 test passages
|
||||
assert all(msg.agent_id == hello_world_passage_fixture.agent_id for msg in agent_results)
|
||||
|
||||
|
||||
def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
|
||||
"""Test searching passages by text content"""
|
||||
search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10)
|
||||
assert len(search_results) == 4
|
||||
assert all("Test passage" in msg.text for msg in search_results)
|
||||
|
||||
# Test no results
|
||||
search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10)
|
||||
assert len(search_results) == 0
|
||||
|
||||
|
||||
def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_passage_fixture, default_user, default_file, sarah_agent):
|
||||
"""Test filtering passages by date range with various scenarios"""
|
||||
# Set up test data with known dates
|
||||
base_time = datetime.utcnow()
|
||||
|
||||
# Create passages at different times
|
||||
passages = []
|
||||
time_offsets = [
|
||||
timedelta(days=-2), # 2 days ago
|
||||
timedelta(days=-1), # Yesterday
|
||||
timedelta(hours=-2), # 2 hours ago
|
||||
timedelta(minutes=-30), # 30 minutes ago
|
||||
timedelta(minutes=-1), # 1 minute ago
|
||||
timedelta(minutes=0), # Now
|
||||
]
|
||||
|
||||
for i, offset in enumerate(time_offsets):
|
||||
timestamp = base_time + offset
|
||||
passage = server.passage_manager.create_passage(
|
||||
def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
|
||||
"""Test creating an agent passage."""
|
||||
assert agent_passage_fixture is not None
|
||||
assert agent_passage_fixture.text == "Hello, I am an agent passage"
|
||||
|
||||
# Try to create an invalid passage (with both agent_id and source_id)
|
||||
with pytest.raises(AssertionError):
|
||||
server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text="Invalid passage",
|
||||
agent_id="123",
|
||||
source_id="456",
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
text=f"Test passage {i}",
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
embedding=[0.1] * 1024,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
created_at=timestamp,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
# Test cases
|
||||
test_cases = [
|
||||
{
|
||||
"name": "Recent passages (last hour)",
|
||||
"start_date": base_time - timedelta(hours=1),
|
||||
"end_date": base_time + timedelta(minutes=1),
|
||||
"expected_count": 1 + 3, # Should include base + -30min, -1min, and now
|
||||
},
|
||||
{
|
||||
"name": "Yesterday's passages",
|
||||
"start_date": base_time - timedelta(days=1, hours=12),
|
||||
"end_date": base_time - timedelta(hours=12),
|
||||
"expected_count": 1, # Should only include yesterday's passage
|
||||
},
|
||||
{
|
||||
"name": "Future time range",
|
||||
"start_date": base_time + timedelta(days=1),
|
||||
"end_date": base_time + timedelta(days=2),
|
||||
"expected_count": 0, # Should find no passages
|
||||
},
|
||||
{
|
||||
"name": "All time",
|
||||
"start_date": base_time - timedelta(days=3),
|
||||
"end_date": base_time + timedelta(days=1),
|
||||
"expected_count": 1 + len(passages), # Should find all passages
|
||||
},
|
||||
{
|
||||
"name": "Exact timestamp match",
|
||||
"start_date": passages[0].created_at - timedelta(microseconds=1),
|
||||
"end_date": passages[0].created_at + timedelta(microseconds=1),
|
||||
"expected_count": 1, # Should find exactly one passage
|
||||
},
|
||||
{
|
||||
"name": "Small time window",
|
||||
"start_date": base_time - timedelta(seconds=30),
|
||||
"end_date": base_time + timedelta(seconds=30),
|
||||
"expected_count": 1 + 1, # date + "now"
|
||||
},
|
||||
]
|
||||
|
||||
# Run test cases
|
||||
for case in test_cases:
|
||||
results = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id, actor=default_user, start_date=case["start_date"], end_date=case["end_date"], limit=10
|
||||
actor=default_user
|
||||
)
|
||||
|
||||
# Verify count
|
||||
assert (
|
||||
len(results) == case["expected_count"]
|
||||
), f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}"
|
||||
|
||||
# Test edge cases
|
||||
def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user):
|
||||
"""Test retrieving a passage by ID"""
|
||||
retrieved = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == agent_passage_fixture.id
|
||||
assert retrieved.text == agent_passage_fixture.text
|
||||
|
||||
# Test with start_date but no end_date
|
||||
results_start_only = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id, actor=default_user, start_date=base_time - timedelta(minutes=2), end_date=None, limit=10
|
||||
)
|
||||
assert len(results_start_only) >= 2, "Should find passages after start_date"
|
||||
|
||||
# Test with end_date but no start_date
|
||||
results_end_only = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id, actor=default_user, start_date=None, end_date=base_time - timedelta(days=1), limit=10
|
||||
)
|
||||
assert len(results_end_only) >= 1, "Should find passages before end_date"
|
||||
|
||||
# Test limit enforcement
|
||||
limited_results = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
start_date=base_time - timedelta(days=3),
|
||||
end_date=base_time + timedelta(days=1),
|
||||
limit=3,
|
||||
)
|
||||
assert len(limited_results) <= 3, "Should respect the limit parameter"
|
||||
retrieved = server.passage_manager.get_passage_by_id(source_passage_fixture.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == source_passage_fixture.id
|
||||
assert retrieved.text == source_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent):
|
||||
"""Test vector search functionality for passages."""
|
||||
passage_manager = server.passage_manager
|
||||
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
||||
|
||||
# Create passages with known embeddings
|
||||
passages = []
|
||||
|
||||
# Create passages with different embeddings
|
||||
test_passages = [
|
||||
"I like red",
|
||||
"random text",
|
||||
"blue shoes",
|
||||
]
|
||||
|
||||
for text in test_passages:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding,
|
||||
)
|
||||
created_passage = passage_manager.create_passage(passage, default_user)
|
||||
passages.append(created_passage)
|
||||
assert passage_manager.size(actor=default_user) == len(passages)
|
||||
|
||||
# Query vector similar to "cats" embedding
|
||||
query_key = "What's my favorite color?"
|
||||
|
||||
# List passages with vector search
|
||||
results = passage_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text=query_key,
|
||||
limit=3,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
)
|
||||
|
||||
# Verify results are ordered by similarity
|
||||
assert len(results) == 3
|
||||
assert results[0].text == "I like red"
|
||||
assert results[1].text == "random text" # For some reason the embedding model doesn't like "blue shoes"
|
||||
assert results[2].text == "blue shoes"
|
||||
def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
|
||||
"""Test that passages are deleted when their parent (agent or source) is deleted."""
|
||||
# Verify passages exist
|
||||
agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user)
|
||||
source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
|
||||
assert agent_passage is not None
|
||||
assert source_passage is not None
|
||||
|
||||
# Delete agent and verify its passages are deleted
|
||||
server.agent_manager.delete_agent(sarah_agent.id, default_user)
|
||||
agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
|
||||
assert len(agentic_passages) == 0
|
||||
|
||||
# Delete source and verify its passages are deleted
|
||||
server.source_manager.delete_source(default_source.id, default_user)
|
||||
with pytest.raises(NoResultFound):
|
||||
server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -1220,6 +1320,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
|
||||
assert print_tool.organization_id == default_organization.id
|
||||
|
||||
|
||||
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
|
||||
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
|
||||
data = print_tool.model_dump(exclude=["id"])
|
||||
@@ -1787,6 +1888,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
|
||||
# ======================================================================================================================
|
||||
# Source Manager Tests - Files
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_get_file_by_id(server: SyncServer, default_user, default_source):
|
||||
"""Test retrieving a file by ID."""
|
||||
file_metadata = PydanticFileMetadata(
|
||||
@@ -1857,6 +1959,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
|
||||
# ======================================================================================================================
|
||||
# SandboxConfigManager Tests - Sandbox Configs
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
|
||||
sandbox_config_create = SandboxConfigCreate(
|
||||
config=E2BSandboxConfig(),
|
||||
@@ -1935,6 +2038,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
|
||||
# ======================================================================================================================
|
||||
# SandboxConfigManager Tests - Environment Variables
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
|
||||
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
|
||||
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
|
||||
@@ -2007,7 +2111,6 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
|
||||
# JobManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_job(server: SyncServer, default_user):
|
||||
"""Test creating a job."""
|
||||
job_data = PydanticJob(
|
||||
|
||||
@@ -390,12 +390,16 @@ def test_user_message_memory(server, user_id, agent_id):
|
||||
|
||||
@pytest.mark.order(3)
|
||||
def test_load_data(server, user_id, agent_id):
|
||||
user = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# create source
|
||||
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
|
||||
passages_before = server.agent_manager.list_passages(
|
||||
actor=user, agent_id=agent_id, cursor=None, limit=10000
|
||||
)
|
||||
assert len(passages_before) == 0
|
||||
|
||||
source = server.source_manager.create_source(
|
||||
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=server.default_user
|
||||
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
|
||||
)
|
||||
|
||||
# load data
|
||||
@@ -409,15 +413,11 @@ def test_load_data(server, user_id, agent_id):
|
||||
connector = DummyDataConnector(archival_memories)
|
||||
server.load_data(user_id, connector, source.name)
|
||||
|
||||
# @pytest.mark.order(3)
|
||||
# def test_attach_source_to_agent(server, user_id, agent_id):
|
||||
# check archival memory size
|
||||
|
||||
# attach source
|
||||
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
|
||||
|
||||
# check archival memory size
|
||||
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
|
||||
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
|
||||
assert len(passages_after) == 5
|
||||
|
||||
|
||||
@@ -465,7 +465,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
user = server.user_manager.get_user_by_id(user_id=user_id)
|
||||
|
||||
# List latest 2 passages
|
||||
passages_1 = server.passage_manager.list_passages(
|
||||
passages_1 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
@@ -475,7 +475,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
|
||||
# List next 3 passages (earliest 3)
|
||||
cursor1 = passages_1[-1].id
|
||||
passages_2 = server.passage_manager.list_passages(
|
||||
passages_2 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
@@ -484,24 +484,28 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
|
||||
# List all 5
|
||||
cursor2 = passages_1[0].created_at
|
||||
passages_3 = server.passage_manager.list_passages(
|
||||
passages_3 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
end_date=cursor2,
|
||||
limit=1000,
|
||||
)
|
||||
# assert passages_1[0].text == "Cinderella wore a blue dress"
|
||||
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
|
||||
latest = passages_1[0]
|
||||
earliest = passages_2[-1]
|
||||
|
||||
# test archival memory
|
||||
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1)
|
||||
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
|
||||
assert len(passage_1) == 1
|
||||
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000)
|
||||
assert passage_1[0].text == "alpha"
|
||||
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
|
||||
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert all("alpha" not in passage.text for passage in passage_2)
|
||||
# test safe empty return
|
||||
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000)
|
||||
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
@@ -955,6 +959,14 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
|
||||
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
existing_sources = server.source_manager.list_sources(actor=actor)
|
||||
if len(existing_sources) > 0:
|
||||
for source in existing_sources:
|
||||
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
|
||||
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
assert initial_passage_count == 0
|
||||
|
||||
|
||||
# Create a source
|
||||
source = server.source_manager.create_source(
|
||||
PydanticSource(
|
||||
@@ -973,10 +985,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
# Attach source to agent first
|
||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
|
||||
|
||||
# Get initial passage count
|
||||
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
assert initial_passage_count == 0
|
||||
|
||||
# Create a job for loading the first file
|
||||
job = server.job_manager.create_job(
|
||||
PydanticJob(
|
||||
@@ -1001,7 +1009,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert job.metadata_["num_documents"] == 1
|
||||
|
||||
# Verify passages were added
|
||||
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
assert first_file_passage_count > initial_passage_count
|
||||
|
||||
# Create a second test file with different content
|
||||
@@ -1032,14 +1040,13 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert job2.metadata_["num_documents"] == 1
|
||||
|
||||
# Verify passages were appended (not replaced)
|
||||
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
|
||||
assert final_passage_count > first_file_passage_count
|
||||
|
||||
# Verify both old and new content is searchable
|
||||
passages = server.passage_manager.list_passages(
|
||||
actor=actor,
|
||||
passages = server.agent_manager.list_passages(
|
||||
agent_id=agent_id,
|
||||
source_id=source.id,
|
||||
actor=actor,
|
||||
query_text="what does Timber like to eat",
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
embed_query=True,
|
||||
@@ -1048,35 +1055,27 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
assert any("chicken" in passage.text.lower() for passage in passages)
|
||||
assert any("Anna".lower() in passage.text.lower() for passage in passages)
|
||||
|
||||
# TODO: Add this test back in after separation of `Passage tables` (LET-449)
|
||||
# # Load second agent
|
||||
# agent2 = server.load_agent(agent_id=other_agent_id)
|
||||
# Initially should have no passages
|
||||
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
|
||||
assert initial_agent2_passages == 0
|
||||
|
||||
# # Initially should have no passages
|
||||
# initial_agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
|
||||
# assert initial_agent2_passages == 0
|
||||
# Attach source to second agent
|
||||
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
|
||||
|
||||
# # Attach source to second agent
|
||||
# agent2.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
|
||||
# Verify second agent has same number of passages as first agent
|
||||
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
|
||||
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
|
||||
assert agent2_passages == agent1_passages
|
||||
|
||||
# # Verify second agent has same number of passages as first agent
|
||||
# agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
|
||||
# agent1_passages = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
|
||||
# assert agent2_passages == agent1_passages
|
||||
|
||||
# # Verify second agent can query the same content
|
||||
# passages2 = server.passage_manager.list_passages(
|
||||
# actor=user,
|
||||
# agent_id=other_agent_id,
|
||||
# source_id=source.id,
|
||||
# query_text="what does Timber like to eat",
|
||||
# embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
# embed_query=True,
|
||||
# limit=10,
|
||||
# )
|
||||
# assert len(passages2) == len(passages)
|
||||
# assert any("chicken" in passage.text.lower() for passage in passages2)
|
||||
# assert any("sleep" in passage.text.lower() for passage in passages2)
|
||||
|
||||
# # Cleanup
|
||||
# server.delete_agent(user_id=user_id, agent_id=agent2_state.id)
|
||||
# Verify second agent can query the same content
|
||||
passages2 = server.agent_manager.list_passages(
|
||||
actor=actor,
|
||||
agent_id=other_agent_id,
|
||||
source_id=source.id,
|
||||
query_text="what does Timber like to eat",
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
embed_query=True,
|
||||
)
|
||||
assert len(passages2) == len(passages)
|
||||
assert any("chicken" in passage.text.lower() for passage in passages2)
|
||||
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
|
||||
|
||||
Reference in New Issue
Block a user