feat: separate Passages tables (#2245)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2024-12-16 15:24:20 -08:00
committed by GitHub
parent 10e610bb95
commit e2d916148e
19 changed files with 1026 additions and 546 deletions

View File

@@ -482,7 +482,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# check agent archival memory size
archival_memories = client.get_archival_memory(agent_id=agent.id)
print(archival_memories)
assert len(archival_memories) == 0
# load a file into a source (non-blocking job)

View File

@@ -2,6 +2,8 @@ import os
import time
from datetime import datetime, timedelta
from httpx._transports import default
from numpy import source
import pytest
from sqlalchemy import delete
from sqlalchemy.exc import IntegrityError
@@ -17,7 +19,8 @@ from letta.orm import (
Job,
Message,
Organization,
Passage,
AgentPassage,
SourcePassage,
SandboxConfig,
SandboxEnvironmentVariable,
Source,
@@ -82,7 +85,8 @@ def clear_tables(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Message))
session.execute(delete(Passage))
session.execute(delete(AgentPassage))
session.execute(delete(SourcePassage))
session.execute(delete(Job))
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
session.execute(delete(BlocksAgents))
@@ -189,39 +193,79 @@ def print_tool(server: SyncServer, default_user, default_organization):
@pytest.fixture
def hello_world_passage_fixture(server: SyncServer, default_user, default_file, sarah_agent):
"""Fixture to create a tool with default settings and clean up after the test."""
# Set up passage
dummy_embedding = [0.0] * 2
message = PydanticPassage(
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
file_id=default_file.id,
text="Hello, world!",
embedding=dummy_embedding,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
"""Fixture to create an agent passage."""
passage = server.passage_manager.create_passage(
PydanticPassage(
text="Hello, I am an agent passage",
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
),
actor=default_user
)
msg = server.passage_manager.create_passage(message, actor=default_user)
yield msg
yield passage
@pytest.fixture
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent) -> list[PydanticPassage]:
"""Helper function to create test passages for all tests"""
dummy_embedding = [0] * 2
passages = [
def source_passage_fixture(server: SyncServer, default_user, default_file, default_source):
"""Fixture to create a source passage."""
passage = server.passage_manager.create_passage(
PydanticPassage(
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
text="Hello, I am a source passage",
source_id=default_source.id,
file_id=default_file.id,
text=f"Test passage {i}",
embedding=dummy_embedding,
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
),
actor=default_user
)
yield passage
@pytest.fixture
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent, default_source):
"""Helper function to create test passages for all tests."""
# Create agent passages
passages = []
for i in range(5):
passage = server.passage_manager.create_passage(
PydanticPassage(
text=f"Agent passage {i}",
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
),
actor=default_user
)
for i in range(4)
]
server.passage_manager.create_many_passages(passages, actor=default_user)
passages.append(passage)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
# Create source passages
for i in range(5):
passage = server.passage_manager.create_passage(
PydanticPassage(
text=f"Source passage {i}",
source_id=default_source.id,
file_id=default_file.id,
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
),
actor=default_user
)
passages.append(passage)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
return passages
@@ -389,6 +433,49 @@ def server():
return server
@pytest.fixture
def agent_passages_setup(server, default_source, default_user, sarah_agent):
    """Attach the default source to ``sarah_agent`` and seed both with passages.

    Creates three passages owned by ``default_source`` and two owned by
    ``sarah_agent``, yields them as ``(agent_passages, source_passages)``,
    and deletes the source once the test has finished.
    """
    actor = default_user
    agent_id = sarah_agent.id
    server.agent_manager.attach_source(agent_id=agent_id, source_id=default_source.id, actor=actor)

    def _create(text, **owner):
        # ``owner`` carries exactly one of ``source_id`` / ``agent_id``.
        # The embedding is a single-element placeholder vector.
        return server.passage_manager.create_passage(
            PydanticPassage(
                organization_id=actor.organization_id,
                text=text,
                embedding=[0.1],
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                **owner,
            ),
            actor=actor,
        )

    # Source passages are created first, then agent passages.
    source_passages = [_create(f"Source passage {i}", source_id=default_source.id) for i in range(3)]
    agent_passages = [_create(f"Agent passage {i}", agent_id=agent_id) for i in range(2)]

    yield agent_passages, source_passages

    # Teardown: remove the source that was attached above.
    server.source_manager.delete_source(default_source.id, actor=actor)
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -849,6 +936,199 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
assert block.label == default_block.label
# ======================================================================================================================
# Agent Manager - Passages Tests
# ======================================================================================================================
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
    """Listing without filters returns every passage visible to the agent."""
    listed = server.agent_manager.list_passages(agent_id=sarah_agent.id, actor=default_user)
    # The setup fixture seeds 3 source passages and 2 agent passages.
    assert len(listed) == 5
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
    """Check that ``ascending`` controls the ``created_at`` order of the results."""
    # Oldest first: each passage must not be newer than its successor.
    oldest_first = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
    assert len(oldest_first) == 5
    for earlier, later in zip(oldest_first, oldest_first[1:]):
        assert earlier.created_at <= later.created_at

    # Newest first: each passage must not be older than its successor.
    newest_first = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
    assert len(newest_first) == 5
    for later, earlier in zip(newest_first, newest_first[1:]):
        assert later.created_at >= earlier.created_at
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
    """Exercise ``limit`` and cursor-based paging over an agent's passages."""
    list_passages = server.agent_manager.list_passages
    agent_id = sarah_agent.id

    # A plain limit caps the number of returned passages.
    capped = list_passages(actor=default_user, agent_id=agent_id, limit=3)
    assert len(capped) == 3

    # Page 1, oldest first.
    page_one = list_passages(actor=default_user, agent_id=agent_id, limit=2, ascending=True)
    assert len(page_one) == 2

    # Page 2 resumes from the id of the last passage on page 1.
    boundary = page_one[-1]
    page_two = list_passages(
        actor=default_user,
        agent_id=agent_id,
        cursor=boundary.id,
        limit=2,
        ascending=True,
    )
    assert len(page_two) == 2
    assert boundary.id != page_two[0].id
    assert boundary.created_at <= page_two[0].created_at
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
    """Check that ``query_text`` narrows the listing to passages with matching text."""
    agent_id = sarah_agent.id

    # "Source passage" should match the 3 passages seeded on the source.
    source_hits = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=agent_id,
        query_text="Source passage",
    )
    assert len(source_hits) == 3

    # "Agent passage" should match the 2 passages seeded on the agent.
    agent_hits = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=agent_id,
        query_text="Agent passage",
    )
    assert len(agent_hits) == 2
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
    """Test that ``agent_only=True`` excludes passages from attached sources."""
    # The setup fixture seeds 3 source passages and 2 agent passages; with
    # agent_only=True only the 2 agent passages are expected back.
    agent_only_passages = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        agent_only=True,
    )
    assert len(agent_only_passages) == 2
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
    """Check the ``source_id`` filter and the ``start_date``/``end_date`` window."""
    agent_id = sarah_agent.id

    # Restricting to the attached source leaves only its 3 passages.
    from_source = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=agent_id,
        source_id=default_source.id,
    )
    assert len(from_source) == 3

    # A window from yesterday to tomorrow must contain all 5 seeded passages.
    reference = datetime.utcnow()
    one_day = timedelta(days=1)
    window_end = reference + one_day
    window_start = reference - one_day
    within_window = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=agent_id,
        start_date=window_start,
        end_date=window_end,
    )
    assert len(within_window) == 5
def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source):
    """Test embedding-based search over an agent's own and attached-source passages.

    Three passages are stored with embeddings computed by the default
    embedding model. Texts at even positions are owned by the agent; the text
    at the odd position is owned by the attached source.
    """
    embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
    test_passages = [
        "I like red",
        "random text",
        "blue shoes",
    ]

    server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)

    for i, text in enumerate(test_passages):
        # Alternate ownership so the search covers both passage kinds.
        owner = {"agent_id": sarah_agent.id} if i % 2 == 0 else {"source_id": default_source.id}
        passage = PydanticPassage(
            text=text,
            organization_id=default_user.organization_id,
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            embedding=embed_model.get_text_embedding(text),
            **owner,
        )
        server.passage_manager.create_passage(passage, default_user)

    # The query is expected to rank the "red" passage first.
    query_key = "What's my favorite color?"

    # Vector search across agent and source passages.
    results = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        query_text=query_key,
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        embed_query=True,
    )

    # Only the top hit is pinned; the relative order of the other two is not asserted.
    assert len(results) == 3
    assert results[0].text == "I like red"
    assert "random" in results[1].text or "random" in results[2].text
    assert "blue" in results[1].text or "blue" in results[2].text

    # Vector search restricted to the agent's own passages.
    agent_only_results = server.agent_manager.list_passages(
        actor=default_user,
        agent_id=sarah_agent.id,
        query_text=query_key,
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        embed_query=True,
        agent_only=True,
    )

    # The source-owned "random text" passage must be absent.
    assert len(agent_only_results) == 2
    assert agent_only_results[0].text == "I like red"
    assert agent_only_results[1].text == "blue shoes"
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
    """Listing by ``source_id`` alone (no ``agent_id``) returns only that source's passages."""
    listed = server.agent_manager.list_passages(
        source_id=default_source.id,
        actor=default_user,
    )

    # The setup fixture seeds exactly 3 passages on the source.
    assert len(listed) == 3
    # Every result belongs to the source and carries no agent_id.
    assert all(passage.source_id == default_source.id for passage in listed)
    assert all(passage.agent_id is None for passage in listed)
# ======================================================================================================================
# Organization Manager Tests
# ======================================================================================================================
@@ -900,266 +1180,86 @@ def test_list_organizations_pagination(server: SyncServer):
# Passage Manager Tests
# ======================================================================================================================
def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user):
"""Test creating a passage using hello_world_passage_fixture fixture"""
assert hello_world_passage_fixture.id is not None
assert hello_world_passage_fixture.text == "Hello, world!"
def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
"""Test creating a passage using agent_passage_fixture fixture"""
assert agent_passage_fixture.id is not None
assert agent_passage_fixture.text == "Hello, I am an agent passage"
# Verify we can retrieve it
retrieved = server.passage_manager.get_passage_by_id(
hello_world_passage_fixture.id,
agent_passage_fixture.id,
actor=default_user,
)
assert retrieved is not None
assert retrieved.id == hello_world_passage_fixture.id
assert retrieved.text == hello_world_passage_fixture.text
assert retrieved.id == agent_passage_fixture.id
assert retrieved.text == agent_passage_fixture.text
def test_passage_get_by_id(server: SyncServer, hello_world_passage_fixture, default_user):
"""Test retrieving a passage by ID"""
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
def test_passage_create_source(server: SyncServer, source_passage_fixture, default_user):
"""Test creating a source passage."""
assert source_passage_fixture is not None
assert source_passage_fixture.text == "Hello, I am a source passage"
# Verify we can retrieve it
retrieved = server.passage_manager.get_passage_by_id(
source_passage_fixture.id,
actor=default_user,
)
assert retrieved is not None
assert retrieved.id == hello_world_passage_fixture.id
assert retrieved.text == hello_world_passage_fixture.text
assert retrieved.id == source_passage_fixture.id
assert retrieved.text == source_passage_fixture.text
def test_passage_update(server: SyncServer, hello_world_passage_fixture, default_user):
"""Test updating a passage"""
new_text = "Updated text"
hello_world_passage_fixture.text = new_text
updated = server.passage_manager.update_passage_by_id(hello_world_passage_fixture.id, hello_world_passage_fixture, actor=default_user)
assert updated is not None
assert updated.text == new_text
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
assert retrieved.text == new_text
def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user):
"""Test deleting a passage"""
server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
with pytest.raises(NoResultFound):
server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
"""Test counting passages with filters"""
base_passage = hello_world_passage_fixture
# Test total count
total = server.passage_manager.size(actor=default_user)
assert total == 5 # base passage + 4 test passages
# TODO: change login passage to be a system not user passage
# Test count with agent filter
agent_count = server.passage_manager.size(actor=default_user, agent_id=base_passage.agent_id)
assert agent_count == 5
# Test count with role filter
role_count = server.passage_manager.size(actor=default_user)
assert role_count == 5
# Test count with non-existent filter
empty_count = server.passage_manager.size(actor=default_user, agent_id="non-existent")
assert empty_count == 0
def test_passage_listing_basic(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
"""Test basic passage listing with limit"""
results = server.passage_manager.list_passages(actor=default_user, limit=3)
assert len(results) == 3
def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
"""Test cursor-based pagination functionality"""
# Make sure there are 5 passages
assert server.passage_manager.size(actor=default_user) == 5
# Get first page
first_page = server.passage_manager.list_passages(actor=default_user, limit=3)
assert len(first_page) == 3
last_id_on_first_page = first_page[-1].id
# Get second page
second_page = server.passage_manager.list_passages(actor=default_user, cursor=last_id_on_first_page, limit=3)
assert len(second_page) == 2 # Should have 2 remaining passages
assert all(r1.id != r2.id for r1 in first_page for r2 in second_page)
def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
"""Test filtering passages by agent ID"""
agent_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, limit=10)
assert len(agent_results) == 5 # base passage + 4 test passages
assert all(msg.agent_id == hello_world_passage_fixture.agent_id for msg in agent_results)
def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
"""Test searching passages by text content"""
search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10)
assert len(search_results) == 4
assert all("Test passage" in msg.text for msg in search_results)
# Test no results
search_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10)
assert len(search_results) == 0
def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_passage_fixture, default_user, default_file, sarah_agent):
"""Test filtering passages by date range with various scenarios"""
# Set up test data with known dates
base_time = datetime.utcnow()
# Create passages at different times
passages = []
time_offsets = [
timedelta(days=-2), # 2 days ago
timedelta(days=-1), # Yesterday
timedelta(hours=-2), # 2 hours ago
timedelta(minutes=-30), # 30 minutes ago
timedelta(minutes=-1), # 1 minute ago
timedelta(minutes=0), # Now
]
for i, offset in enumerate(time_offsets):
timestamp = base_time + offset
passage = server.passage_manager.create_passage(
def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
"""Test creating an agent passage."""
assert agent_passage_fixture is not None
assert agent_passage_fixture.text == "Hello, I am an agent passage"
# Try to create an invalid passage (with both agent_id and source_id)
with pytest.raises(AssertionError):
server.passage_manager.create_passage(
PydanticPassage(
text="Invalid passage",
agent_id="123",
source_id="456",
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
file_id=default_file.id,
text=f"Test passage {i}",
embedding=[0.1, 0.2, 0.3],
embedding=[0.1] * 1024,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
created_at=timestamp,
),
actor=default_user,
)
passages.append(passage)
# Test cases
test_cases = [
{
"name": "Recent passages (last hour)",
"start_date": base_time - timedelta(hours=1),
"end_date": base_time + timedelta(minutes=1),
"expected_count": 1 + 3, # Should include base + -30min, -1min, and now
},
{
"name": "Yesterday's passages",
"start_date": base_time - timedelta(days=1, hours=12),
"end_date": base_time - timedelta(hours=12),
"expected_count": 1, # Should only include yesterday's passage
},
{
"name": "Future time range",
"start_date": base_time + timedelta(days=1),
"end_date": base_time + timedelta(days=2),
"expected_count": 0, # Should find no passages
},
{
"name": "All time",
"start_date": base_time - timedelta(days=3),
"end_date": base_time + timedelta(days=1),
"expected_count": 1 + len(passages), # Should find all passages
},
{
"name": "Exact timestamp match",
"start_date": passages[0].created_at - timedelta(microseconds=1),
"end_date": passages[0].created_at + timedelta(microseconds=1),
"expected_count": 1, # Should find exactly one passage
},
{
"name": "Small time window",
"start_date": base_time - timedelta(seconds=30),
"end_date": base_time + timedelta(seconds=30),
"expected_count": 1 + 1, # date + "now"
},
]
# Run test cases
for case in test_cases:
results = server.passage_manager.list_passages(
agent_id=sarah_agent.id, actor=default_user, start_date=case["start_date"], end_date=case["end_date"], limit=10
actor=default_user
)
# Verify count
assert (
len(results) == case["expected_count"]
), f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}"
# Test edge cases
def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user):
"""Test retrieving a passage by ID"""
retrieved = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, actor=default_user)
assert retrieved is not None
assert retrieved.id == agent_passage_fixture.id
assert retrieved.text == agent_passage_fixture.text
# Test with start_date but no end_date
results_start_only = server.passage_manager.list_passages(
agent_id=sarah_agent.id, actor=default_user, start_date=base_time - timedelta(minutes=2), end_date=None, limit=10
)
assert len(results_start_only) >= 2, "Should find passages after start_date"
# Test with end_date but no start_date
results_end_only = server.passage_manager.list_passages(
agent_id=sarah_agent.id, actor=default_user, start_date=None, end_date=base_time - timedelta(days=1), limit=10
)
assert len(results_end_only) >= 1, "Should find passages before end_date"
# Test limit enforcement
limited_results = server.passage_manager.list_passages(
agent_id=sarah_agent.id,
actor=default_user,
start_date=base_time - timedelta(days=3),
end_date=base_time + timedelta(days=1),
limit=3,
)
assert len(limited_results) <= 3, "Should respect the limit parameter"
retrieved = server.passage_manager.get_passage_by_id(source_passage_fixture.id, actor=default_user)
assert retrieved is not None
assert retrieved.id == source_passage_fixture.id
assert retrieved.text == source_passage_fixture.text
def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent):
"""Test vector search functionality for passages."""
passage_manager = server.passage_manager
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
# Create passages with known embeddings
passages = []
# Create passages with different embeddings
test_passages = [
"I like red",
"random text",
"blue shoes",
]
for text in test_passages:
embedding = embed_model.get_text_embedding(text)
passage = PydanticPassage(
text=text,
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding,
)
created_passage = passage_manager.create_passage(passage, default_user)
passages.append(created_passage)
assert passage_manager.size(actor=default_user) == len(passages)
# Query vector similar to "cats" embedding
query_key = "What's my favorite color?"
# List passages with vector search
results = passage_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
query_text=query_key,
limit=3,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embed_query=True,
)
# Verify results are ordered by similarity
assert len(results) == 3
assert results[0].text == "I like red"
assert results[1].text == "random text" # For some reason the embedding model doesn't like "blue shoes"
assert results[2].text == "blue shoes"
def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
    """Deleting an agent or a source must also remove the passages it owns."""
    get_passage = server.passage_manager.get_passage_by_id

    # Precondition: both passages can be fetched before anything is deleted.
    assert get_passage(agent_passage_fixture.id, default_user) is not None
    assert get_passage(source_passage_fixture.id, default_user) is not None

    # Removing the agent should leave no agent-owned passages behind.
    server.agent_manager.delete_agent(sarah_agent.id, default_user)
    remaining = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
    assert len(remaining) == 0

    # Removing the source should make its passage unfetchable.
    server.source_manager.delete_source(default_source.id, default_user)
    with pytest.raises(NoResultFound):
        get_passage(source_passage_fixture.id, default_user)
# ======================================================================================================================
@@ -1220,6 +1320,7 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
assert print_tool.organization_id == default_organization.id
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
data = print_tool.model_dump(exclude=["id"])
@@ -1787,6 +1888,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
# ======================================================================================================================
# Source Manager Tests - Files
# ======================================================================================================================
def test_get_file_by_id(server: SyncServer, default_user, default_source):
"""Test retrieving a file by ID."""
file_metadata = PydanticFileMetadata(
@@ -1857,6 +1959,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
# ======================================================================================================================
# SandboxConfigManager Tests - Sandbox Configs
# ======================================================================================================================
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
sandbox_config_create = SandboxConfigCreate(
config=E2BSandboxConfig(),
@@ -1935,6 +2038,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
# ======================================================================================================================
# SandboxConfigManager Tests - Environment Variables
# ======================================================================================================================
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
@@ -2007,7 +2111,6 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
# JobManager Tests
# ======================================================================================================================
def test_create_job(server: SyncServer, default_user):
"""Test creating a job."""
job_data = PydanticJob(

View File

@@ -390,12 +390,16 @@ def test_user_message_memory(server, user_id, agent_id):
@pytest.mark.order(3)
def test_load_data(server, user_id, agent_id):
user = server.user_manager.get_user_or_default(user_id=user_id)
# create source
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
passages_before = server.agent_manager.list_passages(
actor=user, agent_id=agent_id, cursor=None, limit=10000
)
assert len(passages_before) == 0
source = server.source_manager.create_source(
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=server.default_user
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
)
# load data
@@ -409,15 +413,11 @@ def test_load_data(server, user_id, agent_id):
connector = DummyDataConnector(archival_memories)
server.load_data(user_id, connector, source.name)
# @pytest.mark.order(3)
# def test_attach_source_to_agent(server, user_id, agent_id):
# check archival memory size
# attach source
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
# check archival memory size
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
assert len(passages_after) == 5
@@ -465,7 +465,7 @@ def test_get_archival_memory(server, user_id, agent_id):
user = server.user_manager.get_user_by_id(user_id=user_id)
# List latest 2 passages
passages_1 = server.passage_manager.list_passages(
passages_1 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
@@ -475,7 +475,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.passage_manager.list_passages(
passages_2 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
@@ -484,24 +484,28 @@ def test_get_archival_memory(server, user_id, agent_id):
# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.passage_manager.list_passages(
passages_3 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
limit=1000,
)
# assert passages_1[0].text == "Cinderella wore a blue dress"
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
latest = passages_1[0]
earliest = passages_2[-1]
# test archival memory
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1)
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000)
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000)
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0
@@ -955,6 +959,14 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
actor = server.user_manager.get_user_or_default(user_id)
existing_sources = server.source_manager.list_sources(actor=actor)
if len(existing_sources) > 0:
for source in existing_sources:
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert initial_passage_count == 0
# Create a source
source = server.source_manager.create_source(
PydanticSource(
@@ -973,10 +985,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
# Attach source to agent first
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
# Get initial passage count
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
assert initial_passage_count == 0
# Create a job for loading the first file
job = server.job_manager.create_job(
PydanticJob(
@@ -1001,7 +1009,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert job.metadata_["num_documents"] == 1
# Verify passages were added
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert first_file_passage_count > initial_passage_count
# Create a second test file with different content
@@ -1032,14 +1040,13 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert job2.metadata_["num_documents"] == 1
# Verify passages were appended (not replaced)
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert final_passage_count > first_file_passage_count
# Verify both old and new content is searchable
passages = server.passage_manager.list_passages(
actor=actor,
passages = server.agent_manager.list_passages(
agent_id=agent_id,
source_id=source.id,
actor=actor,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
@@ -1048,35 +1055,27 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert any("chicken" in passage.text.lower() for passage in passages)
assert any("Anna".lower() in passage.text.lower() for passage in passages)
# TODO: Add this test back in after separation of `Passage tables` (LET-449)
# # Load second agent
# agent2 = server.load_agent(agent_id=other_agent_id)
# Initially should have no passages
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
assert initial_agent2_passages == 0
# # Initially should have no passages
# initial_agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
# assert initial_agent2_passages == 0
# Attach source to second agent
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
# # Attach source to second agent
# agent2.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
# Verify second agent has same number of passages as first agent
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
assert agent2_passages == agent1_passages
# # Verify second agent has same number of passages as first agent
# agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
# agent1_passages = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
# assert agent2_passages == agent1_passages
# # Verify second agent can query the same content
# passages2 = server.passage_manager.list_passages(
# actor=user,
# agent_id=other_agent_id,
# source_id=source.id,
# query_text="what does Timber like to eat",
# embedding_config=EmbeddingConfig.default_config(provider="openai"),
# embed_query=True,
# limit=10,
# )
# assert len(passages2) == len(passages)
# assert any("chicken" in passage.text.lower() for passage in passages2)
# assert any("sleep" in passage.text.lower() for passage in passages2)
# # Cleanup
# server.delete_agent(user_id=user_id, agent_id=agent2_state.id)
# Verify second agent can query the same content
passages2 = server.agent_manager.list_passages(
actor=actor,
agent_id=other_agent_id,
source_id=source.id,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
)
assert len(passages2) == len(passages)
assert any("chicken" in passage.text.lower() for passage in passages2)
assert any("Anna".lower() in passage.text.lower() for passage in passages2)