import json
import os
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock

import pytest
from _pytest.python_api import approx

# Import shared fixtures and constants from conftest
from conftest import (
    DEFAULT_EMBEDDING_CONFIG,
)

from letta.orm.errors import NoResultFound
from letta.schemas.enums import (
    TagMatchMode,
)
from letta.schemas.passage import Passage as PydanticPassage
from letta.server.server import SyncServer

# ======================================================================================================================
# Agent Manager - Passages Tests
# ======================================================================================================================


@pytest.mark.asyncio
async def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
    """Test basic listing functionality of agent passages"""

    all_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id)
    assert len(all_passages) == 5  # 3 source + 2 agent passages

    source_passages = await server.agent_manager.query_source_passages_async(actor=default_user, agent_id=sarah_agent.id)
    assert len(source_passages) == 3  # only the 3 source passages


@pytest.mark.asyncio
async def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
    """Test ordering of agent passages"""

    # Test ascending order: every passage is created no later than its successor
    asc_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, ascending=True)
    assert len(asc_passages) == 5
    for earlier, later in zip(asc_passages, asc_passages[1:]):
        assert earlier.created_at <= later.created_at

    # Test descending order: every passage is created no earlier than its successor
    desc_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, ascending=False)
    assert len(desc_passages) == 5
    for earlier, later in zip(desc_passages, desc_passages[1:]):
        assert earlier.created_at >= later.created_at


@pytest.mark.asyncio
async def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
    """Test pagination of agent passages"""

    # Test limit
    limited_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=3)
    assert len(limited_passages) == 3

    # Test cursor-based pagination
    first_page = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
    assert len(first_page) == 2

    second_page = await server.agent_manager.list_passages_async(
        actor=default_user, agent_id=sarah_agent.id, after=first_page[-1].id, limit=2, ascending=True
    )
    assert len(second_page) == 2
    assert first_page[-1].id != second_page[0].id
    assert first_page[-1].created_at <= second_page[0].created_at

    # The "middle" window lies strictly between the first item of page 1 and the
    # last item of page 2, i.e. it is the last item of page 1 plus the first
    # item of page 2:
    #
    #      [1]     [2]
    #     * *  |  * *
    #
    #        [mid]
    #     * | * * | *
    middle_page = await server.agent_manager.list_passages_async(
        actor=default_user, agent_id=sarah_agent.id, before=second_page[-1].id, after=first_page[0].id, ascending=True
    )
    assert len(middle_page) == 2
    assert middle_page[0].id == first_page[-1].id
    assert middle_page[1].id == second_page[0].id

    # Same window, descending: the two items come back in reverse order
    middle_page_desc = await server.agent_manager.list_passages_async(
        actor=default_user, agent_id=sarah_agent.id, before=second_page[-1].id, after=first_page[0].id, ascending=False
    )
    assert len(middle_page_desc) == 2
    assert middle_page_desc[0].id == second_page[0].id
    assert middle_page_desc[1].id == first_page[-1].id


@pytest.mark.asyncio
async def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
    """Test text search functionality of agent passages"""

    # Test text search for source passages
    source_text_passages = await server.agent_manager.list_passages_async(
        actor=default_user, agent_id=sarah_agent.id, query_text="Source passage"
    )
    assert len(source_text_passages) == 3

    # Test text search for agent passages
    agent_text_passages = await server.agent_manager.list_passages_async(
        actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage"
    )
    assert len(agent_text_passages) == 2


@pytest.mark.asyncio
async def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup, disable_turbopuffer):
    """Test that agent_only=True restricts the listing to agent passages"""

    # With agent_only=True only the 2 agent passages are expected (no source passages)
    agent_text_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
    assert len(agent_text_passages) == 2


@pytest.mark.asyncio
async def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup, disable_turbopuffer):
    """Test filtering functionality of agent passages"""

    # Test source filtering
    source_filtered = await server.agent_manager.list_passages_async(
        actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id
    )
    assert len(source_filtered) == 3

    # Test date filtering: a +/- 1 day window around "now" should contain all 5 passages
    now = datetime.now(timezone.utc)
    future_date = now + timedelta(days=1)
    past_date = now - timedelta(days=1)

    date_filtered = await server.agent_manager.list_passages_async(
        actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date
    )
    assert len(date_filtered) == 5


@pytest.mark.asyncio
async def test_agent_query_passages_time_only(server, default_user, default_archive, disable_turbopuffer):
    """Test querying passages with date filters and no query text."""
    now = datetime.now(timezone.utc)
    older_date = now - timedelta(days=2)
    newer_date = now - timedelta(hours=2)

    # One passage outside the query window (2 days old) ...
    older_passage = await server.passage_manager.create_agent_passage_async(
        PydanticPassage(
            organization_id=default_user.organization_id,
            archive_id=default_archive.id,
            text="Older passage",
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            created_at=older_date,
        ),
        actor=default_user,
    )

    # ... and one inside it (2 hours old)
    newer_passage = await server.passage_manager.create_agent_passage_async(
        PydanticPassage(
            organization_id=default_user.organization_id,
            archive_id=default_archive.id,
            text="Newer passage",
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            created_at=newer_date,
        ),
        actor=default_user,
    )

    results = await server.agent_manager.query_agent_passages_async(
        actor=default_user,
        archive_id=default_archive.id,
        start_date=now - timedelta(days=1),
        end_date=now + timedelta(minutes=1),
    )

    assert len(results) == 1
    # Each result is a 3-tuple whose first element is the passage
    passage, _, _ = results[0]
    assert passage.id == newer_passage.id
    assert passage.id != older_passage.id
    assert passage.created_at >= now - timedelta(days=1)
    assert passage.created_at <= now + timedelta(minutes=1)


@pytest.fixture
def mock_embeddings():
    """Load mock embeddings from JSON file"""
    fixture_path = os.path.join(os.path.dirname(__file__), "data", "test_embeddings.json")
    # Explicit encoding so the fixture loads identically regardless of the platform default
    with open(fixture_path, "r", encoding="utf-8") as f:
        return json.load(f)


@pytest.fixture
def mock_embed_model(mock_embeddings):
    """Mock embedding model that returns predefined embeddings"""
    mock_model = Mock()
    # Texts missing from the fixture file fall back to a 1536-element zero vector
    mock_model.get_text_embedding = lambda text: mock_embeddings.get(text, [0.0] * 1536)
    return mock_model


@pytest.mark.asyncio
async def test_agent_list_passages_vector_search(
    server, default_user, sarah_agent, default_source, default_file, mock_embed_model, disable_turbopuffer
):
    """Test vector search functionality of agent passages"""
    embed_model = mock_embed_model

    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Create passages with different embeddings
    test_passages = [
        "I like red",
        "random text",
        "blue shoes",
    ]

    await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)

    # Even indices become agent passages, odd indices become source passages
    for i, text in enumerate(test_passages):
        embedding = embed_model.get_text_embedding(text)
        if i % 2 == 0:
            # Create agent passage
            passage = PydanticPassage(
                text=text,
                organization_id=default_user.organization_id,
                archive_id=archive.id,
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                embedding=embedding,
            )
            await server.passage_manager.create_agent_passage_async(passage, default_user)
        else:
            # Create source passage
            passage = PydanticPassage(
                text=text,
                organization_id=default_user.organization_id,
                source_id=default_source.id,
                file_id=default_file.id,
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                embedding=embedding,
            )
            await server.passage_manager.create_source_passage_async(passage, default_file, default_user)

    # Query vector similar to "red" embedding
    query_key = "What's my favorite color?"

    # Test vector search with all passages
    results = await server.agent_manager.list_passages_async(
        actor=default_user,
        agent_id=sarah_agent.id,
        query_text=query_key,
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        embed_query=True,
    )

    # Verify results are ordered by similarity
    assert len(results) == 3
    assert results[0].text == "I like red"
    assert "random" in results[1].text or "random" in results[2].text
    assert "blue" in results[1].text or "blue" in results[2].text

    # Test vector search with agent_only=True
    agent_only_results = await server.agent_manager.list_passages_async(
        actor=default_user,
        agent_id=sarah_agent.id,
        query_text=query_key,
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        embed_query=True,
        agent_only=True,
    )

    # Verify agent-only results
    assert len(agent_only_results) == 2
    assert agent_only_results[0].text == "I like red"
    assert agent_only_results[1].text == "blue shoes"


@pytest.mark.asyncio
async def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
    """Test listing passages from a source without specifying an agent."""

    # List passages by source_id without agent_id
    source_passages = await server.agent_manager.list_passages_async(
        actor=default_user,
        source_id=default_source.id,
    )

    # Verify we get only source passages (3 from agent_passages_setup)
    assert len(source_passages) == 3
    assert all(p.source_id == default_source.id for p in source_passages)
    assert all(p.archive_id is None for p in source_passages)


# ======================================================================================================================
# Passage Manager Tests
# ======================================================================================================================


@pytest.mark.asyncio
async def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
    """Test creating a passage using agent_passage_fixture fixture"""
    assert agent_passage_fixture.id is not None
    assert agent_passage_fixture.text == "Hello, I am an agent passage"

    # Verify we can retrieve it
    retrieved = await server.passage_manager.get_passage_by_id_async(
        agent_passage_fixture.id,
        actor=default_user,
    )
    assert retrieved is not None
    assert retrieved.id == agent_passage_fixture.id
    assert retrieved.text == agent_passage_fixture.text


@pytest.mark.asyncio
async def test_passage_create_source(server: SyncServer, source_passage_fixture, default_user):
    """Test creating a source passage."""
    assert source_passage_fixture is not None
    assert source_passage_fixture.text == "Hello, I am a source passage"

    # Verify we can retrieve it
    retrieved = await server.passage_manager.get_passage_by_id_async(
        source_passage_fixture.id,
        actor=default_user,
    )
    assert retrieved is not None
    assert retrieved.id == source_passage_fixture.id
    assert retrieved.text == source_passage_fixture.text


@pytest.mark.asyncio
async def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, default_user):
    """Test that creating a passage with both archive_id and source_id is rejected."""
    assert agent_passage_fixture is not None
    assert agent_passage_fixture.text == "Hello, I am an agent passage"

    # Try to create an invalid passage (with both archive_id and source_id)
    with pytest.raises(AssertionError):
        await server.passage_manager.create_passage_async(
            PydanticPassage(
                text="Invalid passage",
                archive_id="123",
                source_id="456",
                organization_id=default_user.organization_id,
                embedding=[0.1] * 1024,
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
            ),
            actor=default_user,
        )


@pytest.mark.asyncio
async def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user):
    """Test retrieving a passage by ID"""
    # Agent passage
    retrieved = await server.passage_manager.get_passage_by_id_async(agent_passage_fixture.id, actor=default_user)
    assert retrieved is not None
    assert retrieved.id == agent_passage_fixture.id
    assert retrieved.text == agent_passage_fixture.text

    # Source passage
    retrieved = await server.passage_manager.get_passage_by_id_async(source_passage_fixture.id, actor=default_user)
    assert retrieved is not None
    assert retrieved.id == source_passage_fixture.id
    assert retrieved.text == source_passage_fixture.text


@pytest.mark.asyncio
async def test_passage_cascade_deletion(
    server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent
):
    """Test that an agent's passages are no longer listed once the agent is deleted.

    Only the agent-deletion path is exercised here; source deletion is not.
    """
    # Verify passages exist
    agent_passage = await server.passage_manager.get_passage_by_id_async(agent_passage_fixture.id, default_user)
    source_passage = await server.passage_manager.get_passage_by_id_async(source_passage_fixture.id, default_user)
    assert agent_passage is not None
    assert source_passage is not None

    # Delete agent and verify its passages are deleted
    await server.agent_manager.delete_agent_async(sarah_agent.id, default_user)
    agentic_passages = await server.agent_manager.list_passages_async(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
    assert len(agentic_passages) == 0


@pytest.mark.asyncio
async def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
    """Test creating an agent passage using the new agent-specific method."""
    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    passage = await server.passage_manager.create_agent_passage_async(
        PydanticPassage(
            text="Test agent passage via specific method",
            archive_id=archive.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            metadata={"type": "test_specific"},
            tags=["python", "test", "agent"],
        ),
        actor=default_user,
    )

    assert passage.id is not None
    assert passage.text == "Test agent passage via specific method"
    assert passage.archive_id == archive.id
    assert passage.source_id is None
    # Compare sorted so the check does not depend on tag order
    assert sorted(passage.tags) == sorted(["python", "test", "agent"])


@pytest.mark.asyncio
async def test_create_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
    """Test creating a source passage using the new source-specific method."""
    passage = await server.passage_manager.create_source_passage_async(
        PydanticPassage(
            text="Test source passage via specific method",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            metadata={"type": "test_specific"},
            tags=["document", "test", "source"],
        ),
        file_metadata=default_file,
        actor=default_user,
    )

    assert passage.id is not None
    assert passage.text == "Test source passage via specific method"
    assert passage.source_id == default_source.id
    assert passage.archive_id is None
    # Compare sorted so the check does not depend on tag order
    assert sorted(passage.tags) == sorted(["document", "test", "source"])


@pytest.mark.asyncio
async def test_create_agent_passage_validation(server: SyncServer, default_user, default_source, sarah_agent):
    """Test that agent passage creation validates inputs correctly."""
    # Should fail if archive_id is missing
    with pytest.raises(ValueError, match="Agent passage must have archive_id"):
        await server.passage_manager.create_agent_passage_async(
            PydanticPassage(
                text="Invalid agent passage",
                organization_id=default_user.organization_id,
                embedding=[0.1],
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
            ),
            actor=default_user,
        )

    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Should fail if source_id is present
    with pytest.raises(ValueError, match="Agent passage cannot have source_id"):
        await server.passage_manager.create_agent_passage_async(
            PydanticPassage(
                text="Invalid agent passage",
                archive_id=archive.id,
                source_id=default_source.id,
                organization_id=default_user.organization_id,
                embedding=[0.1],
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
            ),
            actor=default_user,
        )


@pytest.mark.asyncio
async def test_create_source_passage_validation(server: SyncServer, default_user, default_file, default_source, sarah_agent):
    """Test that source passage creation validates inputs correctly."""
    # Should fail if source_id is missing
    with pytest.raises(ValueError, match="Source passage must have source_id"):
        await server.passage_manager.create_source_passage_async(
            PydanticPassage(
                text="Invalid source passage",
                organization_id=default_user.organization_id,
                embedding=[0.1],
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
            ),
            file_metadata=default_file,
            actor=default_user,
        )

    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Should fail if archive_id is present
    with pytest.raises(ValueError, match="Source passage cannot have archive_id"):
        await server.passage_manager.create_source_passage_async(
            PydanticPassage(
                text="Invalid source passage",
                source_id=default_source.id,
                archive_id=archive.id,
                organization_id=default_user.organization_id,
                embedding=[0.1],
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
            ),
            file_metadata=default_file,
            actor=default_user,
        )


@pytest.mark.asyncio
async def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sarah_agent):
    """Test retrieving an agent passage using the new agent-specific method."""
    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Create an agent passage
    passage = await server.passage_manager.create_agent_passage_async(
        PydanticPassage(
            text="Agent passage for retrieval test",
            archive_id=archive.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        actor=default_user,
    )

    # Retrieve it using the specific method
    retrieved = await server.passage_manager.get_agent_passage_by_id_async(passage.id, actor=default_user)
    assert retrieved is not None
    assert retrieved.id == passage.id
    assert retrieved.text == passage.text
    assert retrieved.archive_id == archive.id


@pytest.mark.asyncio
async def test_get_source_passage_by_id_specific(server: SyncServer, default_user, default_file, default_source):
    """Test retrieving a source passage using the new source-specific method."""
    # Create a source passage
    passage = await server.passage_manager.create_source_passage_async(
        PydanticPassage(
            text="Source passage for retrieval test",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        file_metadata=default_file,
        actor=default_user,
    )

    # Retrieve it using the specific method
    retrieved = await server.passage_manager.get_source_passage_by_id_async(passage.id, actor=default_user)
    assert retrieved is not None
    assert retrieved.id == passage.id
    assert retrieved.text == passage.text
    assert retrieved.source_id == default_source.id


@pytest.mark.asyncio
async def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sarah_agent, default_file, default_source):
    """Test that trying to get the wrong passage type with specific methods fails."""
    # Create an agent passage
    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    agent_passage = await server.passage_manager.create_agent_passage_async(
        PydanticPassage(
            text="Agent passage",
            archive_id=archive.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        actor=default_user,
    )

    # Create a source passage
    source_passage = await server.passage_manager.create_source_passage_async(
        PydanticPassage(
            text="Source passage",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        file_metadata=default_file,
        actor=default_user,
    )

    # Trying to get agent passage with source method should fail
    with pytest.raises(NoResultFound):
        await server.passage_manager.get_source_passage_by_id_async(agent_passage.id, actor=default_user)

    # Trying to get source passage with agent method should fail
    with pytest.raises(NoResultFound):
        await server.passage_manager.get_agent_passage_by_id_async(source_passage.id, actor=default_user)


@pytest.mark.asyncio
async def test_update_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
    """Test updating an agent passage using the new agent-specific method."""
    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Create an agent passage
    passage = await server.passage_manager.create_agent_passage_async(
        PydanticPassage(
            text="Original agent passage text",
            archive_id=archive.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        actor=default_user,
    )

    # Update it
    updated_passage = await server.passage_manager.update_agent_passage_by_id_async(
        passage.id,
        PydanticPassage(
            text="Updated agent passage text",
            archive_id=archive.id,
            organization_id=default_user.organization_id,
            embedding=[0.2],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        actor=default_user,
    )

    assert updated_passage.text == "Updated agent passage text"
    # approx: the stored embedding value is compared with float tolerance
    assert updated_passage.embedding[0] == approx(0.2)
    assert updated_passage.id == passage.id


@pytest.mark.asyncio
async def test_update_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
    """Test updating a source passage using the new source-specific method."""
    # Create a source passage
    passage = await server.passage_manager.create_source_passage_async(
        PydanticPassage(
            text="Original source passage text",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        file_metadata=default_file,
        actor=default_user,
    )

    # Update it
    updated_passage = await server.passage_manager.update_source_passage_by_id_async(
        passage.id,
        PydanticPassage(
            text="Updated source passage text",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=default_user.organization_id,
            embedding=[0.2],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        actor=default_user,
    )

    assert updated_passage.text == "Updated source passage text"
    # approx: the stored embedding value is compared with float tolerance
    assert updated_passage.embedding[0] == approx(0.2)
    assert updated_passage.id == passage.id


@pytest.mark.asyncio
async def test_delete_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
    """Test deleting an agent passage using the new agent-specific method."""
    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Create an agent passage
    passage = await server.passage_manager.create_agent_passage_async(
        PydanticPassage(
            text="Agent passage to delete",
            archive_id=archive.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        actor=default_user,
    )

    # Verify it exists
    retrieved = await server.passage_manager.get_agent_passage_by_id_async(passage.id, actor=default_user)
    assert retrieved is not None

    # Delete it
    result = await server.passage_manager.delete_agent_passage_by_id_async(passage.id, actor=default_user)
    assert result is True

    # Verify it's gone
    with pytest.raises(NoResultFound):
        await server.passage_manager.get_agent_passage_by_id_async(passage.id, actor=default_user)


@pytest.mark.asyncio
async def test_delete_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
    """Test deleting a source passage using the new source-specific method."""
    # Create a source passage
    passage = await server.passage_manager.create_source_passage_async(
        PydanticPassage(
            text="Source passage to delete",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        file_metadata=default_file,
        actor=default_user,
    )

    # Verify it exists
    retrieved = await server.passage_manager.get_source_passage_by_id_async(passage.id, actor=default_user)
    assert retrieved is not None

    # Delete it
    result = await server.passage_manager.delete_source_passage_by_id_async(passage.id, actor=default_user)
    assert result is True

    # Verify it's gone
    with pytest.raises(NoResultFound):
        await server.passage_manager.get_source_passage_by_id_async(passage.id, actor=default_user)


@pytest.mark.asyncio
async def test_create_many_agent_passages_async(server: SyncServer, default_user, sarah_agent):
    """Test creating multiple agent passages using the new batch method."""
    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Even items get a per-item tag, odd items share the "odd" tag
    passages = [
        PydanticPassage(
            text=f"Batch agent passage {i}",
            archive_id=archive.id,  # id taken from the archive object fetched above
            organization_id=default_user.organization_id,
            embedding=[0.1 * i],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            tags=["batch", f"item{i}"] if i % 2 == 0 else ["batch", "odd"],
        )
        for i in range(3)
    ]

    created_passages = await server.passage_manager.create_many_archival_passages_async(passages, actor=default_user)

    assert len(created_passages) == 3
    for i, passage in enumerate(created_passages):
        assert passage.text == f"Batch agent passage {i}"
        assert passage.archive_id == archive.id
        assert passage.source_id is None
        expected_tags = ["batch", f"item{i}"] if i % 2 == 0 else ["batch", "odd"]
        assert passage.tags == expected_tags


@pytest.mark.asyncio
async def test_create_many_source_passages_async(server: SyncServer, default_user, default_file, default_source):
    """Test creating multiple source passages using the new batch method."""
    passages = [
        PydanticPassage(
            text=f"Batch source passage {i}",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=default_user.organization_id,
            embedding=[0.1 * i],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        )
        for i in range(3)
    ]

    created_passages = await server.passage_manager.create_many_source_passages_async(
        passages, file_metadata=default_file, actor=default_user
    )

    assert len(created_passages) == 3
    for i, passage in enumerate(created_passages):
        assert passage.text == f"Batch source passage {i}"
        assert passage.source_id == default_source.id
        assert passage.archive_id is None


@pytest.mark.asyncio
async def test_agent_passage_size(server: SyncServer, default_user, sarah_agent):
    """Test counting agent passages using the new agent-specific size method."""
    # Measure relative to the starting count rather than assuming an empty archive
    initial_size = await server.passage_manager.agent_passage_size_async(actor=default_user, agent_id=sarah_agent.id)

    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Create some agent passages
    for i in range(3):
        await server.passage_manager.create_agent_passage_async(
            PydanticPassage(
                text=f"Agent passage {i} for size test",
                archive_id=archive.id,
                organization_id=default_user.organization_id,
                embedding=[0.1],
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
            ),
            actor=default_user,
        )

    final_size = await server.passage_manager.agent_passage_size_async(actor=default_user, agent_id=sarah_agent.id)
    assert final_size == initial_size + 3


@pytest.mark.asyncio
async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServer, default_user, sarah_agent):
    """Test comprehensive tag functionality for passages."""
    # Get or create default archive for the agent
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user)

    # Create passages with different tag combinations
    test_passages = [
        {"text": "Python programming tutorial", "tags": ["python", "tutorial", "programming"]},
        {"text": "Machine learning with Python", "tags": ["python", "ml", "ai"]},
        {"text": "JavaScript web development", "tags": ["javascript", "web", "frontend"]},
        {"text": "Python data science guide", "tags": ["python", "tutorial", "data"]},
        {"text": "No tags passage", "tags": None},
    ]

    created_passages = []
    for test_data in test_passages:
        passage = await server.passage_manager.create_agent_passage_async(
            PydanticPassage(
                text=test_data["text"],
                archive_id=archive.id,
                organization_id=default_user.organization_id,
                embedding=[0.1, 0.2, 0.3],
                embedding_config=DEFAULT_EMBEDDING_CONFIG,
                tags=test_data["tags"],
            ),
            actor=default_user,
        )
        created_passages.append(passage)

    # Test that tags are properly stored (deduplicated)
    for i, passage in enumerate(created_passages):
        expected_tags = test_passages[i]["tags"]
        if expected_tags:
            assert set(passage.tags) == set(expected_tags)
        else:
            assert passage.tags is None

    # Test querying with tag filtering; only runs when the agent manager exposes the query method
    if hasattr(server.agent_manager, "query_agent_passages_async"):
        # Test querying with python tag (three passages carry it; the assertion only requires two)
        python_results = await server.agent_manager.query_agent_passages_async(
            actor=default_user,
            agent_id=sarah_agent.id,
            tags=["python"],
            tag_match_mode=TagMatchMode.ANY,
        )

        python_texts = [p.text for p, _, _ in python_results]
        assert len([t for t in python_texts if "Python" in t]) >= 2

        # Test querying with multiple tags using ALL mode
        tutorial_python_results = await server.agent_manager.query_agent_passages_async(
            actor=default_user,
            agent_id=sarah_agent.id,
            tags=["python", "tutorial"],
            tag_match_mode=TagMatchMode.ALL,
        )

        tutorial_texts = [p.text for p, _, _ in tutorial_python_results]
        expected_matches = [t for t in tutorial_texts if "tutorial" in t and "Python" in t]
        assert len(expected_matches) >= 1


@pytest.mark.asyncio
async def test_comprehensive_tag_functionality(disable_turbopuffer, server: SyncServer, sarah_agent, default_user):
    """Comprehensive test for tag functionality including dual storage and junction table."""

    # Test 1: Create passages with tags and verify they're stored in both places
    passages_with_tags = []
    test_tags = {
        "passage1": ["important", "documentation", "python"],
        "passage2": ["important", "testing"],
        "passage3": ["documentation", "api"],
        "passage4": ["python", "testing", "api"],
        "passage5": [],  # Test empty tags
    }

    for i, (passage_key, tags) in enumerate(test_tags.items(), 1):
        text = f"Test passage {i} for comprehensive tag testing"
        created_passages = await server.passage_manager.insert_passage(
            agent_state=sarah_agent,
            text=text,
            actor=default_user,
            tags=tags if tags else None,
        )
        assert len(created_passages) == 1
        passage = created_passages[0]

        # Verify tags are stored in the JSON column (deduplicated)
        if tags:
            assert set(passage.tags) == set(tags)
        else:
            assert passage.tags is None
        passages_with_tags.append(passage)

    # Test 2: Verify unique tags for archive
    archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
        agent_state=sarah_agent,
        actor=default_user,
    )

    unique_tags = await server.passage_manager.get_unique_tags_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )

    # Should have all unique tags: "important", "documentation", "python", "testing", "api"
    expected_unique_tags = {"important", "documentation", "python", "testing", "api"}
    assert set(unique_tags) == expected_unique_tags
    assert len(unique_tags) == 5

    # Test 3: Verify tag counts
    tag_counts = await server.passage_manager.get_tag_counts_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )

    # Verify counts
    assert tag_counts["important"] == 2  # passage1 and passage2
    assert tag_counts["documentation"] == 2  # passage1 and passage3
    assert tag_counts["python"] == 2  # passage1 and passage4
    assert tag_counts["testing"] == 2  # passage2 and passage4
    assert tag_counts["api"] == 2  # passage3 and passage4

    # Test 4: Query passages with ANY tag matching
    any_results = await server.agent_manager.query_agent_passages_async(
        agent_id=sarah_agent.id,
        query_text="test",
        limit=10,
        tags=["important", "api"],
        tag_match_mode=TagMatchMode.ANY,
        actor=default_user,
    )

    # Should match passages with "important" OR "api" tags (passages 1, 2, 3, 4)
    [p.text for p, _, _ in any_results]
    assert len(any_results) >= 4

    # Test 5: Query passages with ALL tag matching
    all_results = await server.agent_manager.query_agent_passages_async(
        agent_id=sarah_agent.id,
        query_text="test",
        limit=10,
        tags=["python", "testing"],
        tag_match_mode=TagMatchMode.ALL,
        actor=default_user,
    )

    # Should only match passage4 which has both "python" AND "testing"
    all_passage_texts = [p.text for p, _, _ in all_results]
    assert any("Test passage 4" in text for text in all_passage_texts)

    # Test 6: Query with non-existent tags
    no_results = await server.agent_manager.query_agent_passages_async(
        agent_id=sarah_agent.id,
        query_text="test",
        limit=10,
        tags=["nonexistent", "missing"],
        tag_match_mode=TagMatchMode.ANY,
        actor=default_user,
    )

    # Should return no results
    assert len(no_results) == 0

    # Test 7: Verify tags CAN be updated (with junction table properly maintained)
    first_passage = passages_with_tags[0]
    new_tags = ["updated", "modified", "changed"]
    update_data = PydanticPassage(
        id=first_passage.id,
        text="Updated text",
        tags=new_tags,
        organization_id=first_passage.organization_id,
        archive_id=first_passage.archive_id,
        embedding=first_passage.embedding,
        embedding_config=first_passage.embedding_config,
    )

    # Update should work and tags should be updated
    updated = await server.passage_manager.update_agent_passage_by_id_async(
        passage_id=first_passage.id,
        passage=update_data,
        actor=default_user,
    )

    # Both text and tags should be updated
    assert updated.text == "Updated text"
    assert set(updated.tags) == set(new_tags)

    # Verify tags are properly updated in junction table
    updated_unique_tags = await server.passage_manager.get_unique_tags_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )

    # Should include new tags and not include old "important", "documentation", "python" from passage1
    # But still have tags from other passages
    assert "updated" in updated_unique_tags
    assert "modified" in updated_unique_tags
    assert "changed" in updated_unique_tags

    # Test 8: Delete a passage and verify cascade deletion of tags
    passage_to_delete = passages_with_tags[1]  # passage2 with ["important", "testing"]

    await server.passage_manager.delete_agent_passage_by_id_async(
        passage_id=passage_to_delete.id,
        actor=default_user,
    )

    # Get updated tag counts
    updated_tag_counts = await server.passage_manager.get_tag_counts_for_archive_async(
        archive_id=archive.id,
        actor=default_user,
    )

    # "important" no longer exists (was in passage1 which was updated and passage2 which was deleted)
    assert "important" not in updated_tag_counts
    # "testing" count should decrease from 2 to 1 (only in passage4 now)
    assert updated_tag_counts["testing"] == 1

    # Test 9: Batch create passages with tags
    batch_texts = [
        "Batch passage 1",
        "Batch passage 2",
        "Batch passage 3",
    ]
    batch_tags
= ["batch", "test", "multiple"] batch_passages = [] for text in batch_texts: passages = await server.passage_manager.insert_passage( agent_state=sarah_agent, text=text, actor=default_user, tags=batch_tags, ) batch_passages.extend(passages) # Verify all batch passages have the same tags for passage in batch_passages: assert set(passage.tags) == set(batch_tags) # Test 10: Verify tag counts include batch passages final_tag_counts = await server.passage_manager.get_tag_counts_for_archive_async( archive_id=archive.id, actor=default_user, ) assert final_tag_counts["batch"] == 3 assert final_tag_counts["test"] == 3 assert final_tag_counts["multiple"] == 3 # Test 11: Complex query with multiple tags and ALL matching complex_all_results = await server.agent_manager.query_agent_passages_async( agent_id=sarah_agent.id, query_text="batch", limit=10, tags=["batch", "test", "multiple"], tag_match_mode=TagMatchMode.ALL, actor=default_user, ) # Should match all 3 batch passages assert len(complex_all_results) >= 3 # Test 12: Empty tag list should return all passages all_passages = await server.agent_manager.query_agent_passages_async( agent_id=sarah_agent.id, query_text="passage", limit=50, tags=[], tag_match_mode=TagMatchMode.ANY, actor=default_user, ) # Should return passages based on text search only assert len(all_passages) > 0 @pytest.mark.asyncio async def test_tag_edge_cases(disable_turbopuffer, server: SyncServer, sarah_agent, default_user): """Test edge cases for tag functionality.""" # Test 1: Very long tag names long_tag = "a" * 500 # 500 character tag passages = await server.passage_manager.insert_passage( agent_state=sarah_agent, text="Testing long tag names", actor=default_user, tags=[long_tag, "normal_tag"], ) assert len(passages) == 1 assert long_tag in passages[0].tags # Test 2: Special characters in tags special_tags = [ "tag-with-dash", "tag_with_underscore", "tag.with.dots", "tag/with/slash", "tag:with:colon", "tag@with@at", "tag#with#hash", "tag with spaces", 
"CamelCaseTag", "数字标签", ] passages_special = await server.passage_manager.insert_passage( agent_state=sarah_agent, text="Testing special character tags", actor=default_user, tags=special_tags, ) assert len(passages_special) == 1 assert set(passages_special[0].tags) == set(special_tags) # Verify unique tags includes all special character tags archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( agent_state=sarah_agent, actor=default_user, ) unique_tags = await server.passage_manager.get_unique_tags_for_archive_async( archive_id=archive.id, actor=default_user, ) for tag in special_tags: assert tag in unique_tags # Test 3: Duplicate tags in input (should be deduplicated) duplicate_tags = ["tag1", "tag2", "tag1", "tag3", "tag2", "tag1"] passages_dup = await server.passage_manager.insert_passage( agent_state=sarah_agent, text="Testing duplicate tags", actor=default_user, tags=duplicate_tags, ) # Should only have unique tags (duplicates removed) assert len(passages_dup) == 1 assert set(passages_dup[0].tags) == {"tag1", "tag2", "tag3"} assert len(passages_dup[0].tags) == 3 # Should be deduplicated # Test 4: Case sensitivity in tags case_tags = ["Tag", "tag", "TAG", "tAg"] passages_case = await server.passage_manager.insert_passage( agent_state=sarah_agent, text="Testing case sensitive tags", actor=default_user, tags=case_tags, ) # All variations should be preserved (case-sensitive) assert len(passages_case) == 1 assert set(passages_case[0].tags) == set(case_tags) @pytest.mark.asyncio async def test_search_agent_archival_memory_async(disable_turbopuffer, server: SyncServer, default_user, sarah_agent): """Test the search_agent_archival_memory_async method that powers both the agent tool and API endpoint.""" # Get or create default archive for the agent archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(agent_state=sarah_agent, actor=default_user) # Create test passages with various content and tags test_data 
= [ { "text": "Python is a powerful programming language used for data science and web development.", "tags": ["python", "programming", "data-science", "web"], "created_at": datetime(2024, 1, 15, 10, 30, tzinfo=timezone.utc), }, { "text": "Machine learning algorithms can be implemented in Python using libraries like scikit-learn.", "tags": ["python", "machine-learning", "algorithms"], "created_at": datetime(2024, 1, 16, 14, 45, tzinfo=timezone.utc), }, { "text": "JavaScript is essential for frontend web development and modern web applications.", "tags": ["javascript", "frontend", "web"], "created_at": datetime(2024, 1, 17, 9, 15, tzinfo=timezone.utc), }, { "text": "Database design principles are important for building scalable applications.", "tags": ["database", "design", "scalability"], "created_at": datetime(2024, 1, 18, 16, 20, tzinfo=timezone.utc), }, { "text": "The weather today is sunny and warm, perfect for outdoor activities.", "tags": ["weather", "outdoor"], "created_at": datetime(2024, 1, 19, 11, 0, tzinfo=timezone.utc), }, ] # Create passages in the database created_passages = [] for data in test_data: passage = await server.passage_manager.create_agent_passage_async( PydanticPassage( text=data["text"], archive_id=archive.id, organization_id=default_user.organization_id, embedding=[0.1, 0.2, 0.3], # Mock embedding embedding_config=DEFAULT_EMBEDDING_CONFIG, tags=data["tags"], created_at=data["created_at"], ), actor=default_user, ) created_passages.append(passage) # Test 1: Basic search by query text results = await server.agent_manager.search_agent_archival_memory_async( agent_id=sarah_agent.id, actor=default_user, query="Python programming" ) assert len(results) > 0 # Check structure of results for result in results: assert "timestamp" in result assert "content" in result assert "tags" in result assert isinstance(result["tags"], list) # Test 2: Search with tag filtering - single tag results = await 
server.agent_manager.search_agent_archival_memory_async( agent_id=sarah_agent.id, actor=default_user, query="programming", tags=["python"] ) assert len(results) > 0 # All results should have "python" tag for result in results: assert "python" in result["tags"] # Test 3: Search with tag filtering - multiple tags with "any" mode results = await server.agent_manager.search_agent_archival_memory_async( agent_id=sarah_agent.id, actor=default_user, query="development", tags=["web", "database"], tag_match_mode="any" ) assert len(results) > 0 # All results should have at least one of the specified tags for result in results: assert any(tag in result["tags"] for tag in ["web", "database"]) # Test 4: Search with tag filtering - multiple tags with "all" mode results = await server.agent_manager.search_agent_archival_memory_async( agent_id=sarah_agent.id, actor=default_user, query="Python", tags=["python", "web"], tag_match_mode="all" ) # Should only return results that have BOTH tags for result in results: assert "python" in result["tags"] assert "web" in result["tags"] # Test 5: Search with top_k limit results = await server.agent_manager.search_agent_archival_memory_async( agent_id=sarah_agent.id, actor=default_user, query="programming", top_k=2 ) assert len(results) <= 2 # Test 6: Search with datetime filtering results = await server.agent_manager.search_agent_archival_memory_async( agent_id=sarah_agent.id, actor=default_user, query="programming", start_datetime="2024-01-16", end_datetime="2024-01-17" ) # Should only include passages created between those dates for result in results: # Parse timestamp to verify it's in range timestamp_str = result["timestamp"] # Basic validation that timestamp exists and has expected format assert "2024-01-16" in timestamp_str or "2024-01-17" in timestamp_str # Test 7: Search with ISO datetime format results = await server.agent_manager.search_agent_archival_memory_async( agent_id=sarah_agent.id, actor=default_user, query="algorithms", 
start_datetime="2024-01-16T14:00:00", end_datetime="2024-01-16T15:00:00", ) # Should include the machine learning passage created at 14:45 assert len(results) >= 0 # Might be 0 if no results, but shouldn't error # Test 8: Search with non-existent agent should raise error non_existent_agent_id = "agent-00000000-0000-4000-8000-000000000000" with pytest.raises(Exception): # Should raise NoResultFound or similar await server.agent_manager.search_agent_archival_memory_async(agent_id=non_existent_agent_id, actor=default_user, query="test") # Test 9: Search with invalid datetime format should raise ValueError with pytest.raises(ValueError, match="Invalid start_datetime format"): await server.agent_manager.search_agent_archival_memory_async( agent_id=sarah_agent.id, actor=default_user, query="test", start_datetime="invalid-date" ) # Test 10: Empty query should return empty results results = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query="") assert len(results) == 0 # Empty query should return 0 results # Test 11: Whitespace-only query should also return empty results results = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query=" \n\t ") assert len(results) == 0 # Whitespace-only query should return 0 results # Cleanup - delete the created passages for passage in created_passages: await server.passage_manager.delete_agent_passage_by_id_async(passage_id=passage.id, actor=default_user)