feat: Add search messages endpoint [LET-4144] (#4434)

* Add search messages endpoint

* Run fern autogen and fix tests
This commit is contained in:
Matthew Zhou
2025-09-05 14:28:27 -07:00
committed by GitHub
parent f2485daef7
commit 2ef47d8002
13 changed files with 462 additions and 312 deletions

View File

@@ -233,7 +233,7 @@ class TestTurbopufferIntegration:
pass
@pytest.mark.asyncio
async def test_turbopuffer_metadata_attributes(self, enable_turbopuffer):
async def test_turbopuffer_metadata_attributes(self, default_user, enable_turbopuffer):
"""Test that Turbopuffer properly stores and retrieves metadata attributes"""
# Only run if we have a real API key
@@ -273,17 +273,16 @@ class TestTurbopufferIntegration:
result = await client.insert_archival_memories(
archive_id=archive_id,
text_chunks=[d["text"] for d in test_data],
embeddings=[d["vector"] for d in test_data],
passage_ids=[d["id"] for d in test_data],
organization_id="org-123", # Default org
actor=default_user,
created_at=datetime.now(timezone.utc),
)
assert len(result) == 3
# Query all passages (no tag filtering)
query_vector = [0.15] * 1536
results = await client.query_passages(archive_id=archive_id, query_embedding=query_vector, top_k=10)
results = await client.query_passages(archive_id=archive_id, actor=default_user, top_k=10)
# Should get all passages
assert len(results) == 3 # All three passages
@@ -339,7 +338,7 @@ class TestTurbopufferIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
async def test_hybrid_search_with_real_tpuf(self, enable_turbopuffer):
async def test_hybrid_search_with_real_tpuf(self, default_user, enable_turbopuffer):
"""Test hybrid search functionality combining vector and full-text search"""
import uuid
@@ -366,13 +365,14 @@ class TestTurbopufferIntegration:
# Insert passages
await client.insert_archival_memories(
archive_id=archive_id, text_chunks=texts, embeddings=embeddings, passage_ids=passage_ids, organization_id=org_id
archive_id=archive_id, text_chunks=texts, passage_ids=passage_ids, organization_id=org_id, actor=default_user
)
# Test vector-only search
vector_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 6.0, 11.0], # similar to second passage embedding
actor=default_user,
query_text="python programming tutorial",
search_mode="vector",
top_k=3,
)
@@ -382,7 +382,7 @@ class TestTurbopufferIntegration:
# Test FTS-only search
fts_results = await client.query_passages(
archive_id=archive_id, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
archive_id=archive_id, actor=default_user, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
)
assert 0 < len(fts_results) <= 3
# should find passages mentioning Turbopuffer
@@ -393,7 +393,7 @@ class TestTurbopufferIntegration:
# Test hybrid search
hybrid_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[2.0, 7.0, 12.0],
actor=default_user,
query_text="vector search Turbopuffer",
search_mode="hybrid",
top_k=3,
@@ -412,7 +412,7 @@ class TestTurbopufferIntegration:
# Test with different weights
vector_heavy_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[0.0, 5.0, 10.0], # very similar to first passage
actor=default_user,
query_text="quick brown fox", # matches second passage
search_mode="hybrid",
top_k=3,
@@ -423,16 +423,13 @@ class TestTurbopufferIntegration:
# all results should have scores
assert all(isinstance(score, float) for _, score, _ in vector_heavy_results)
# Test error handling - missing text for hybrid mode (embedding provided but text missing)
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
await client.query_passages(archive_id=archive_id, query_embedding=[1.0, 2.0, 3.0], search_mode="hybrid", top_k=3)
# Test error handling - missing embedding for hybrid mode (text provided but embedding missing)
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
await client.query_passages(archive_id=archive_id, query_text="test", search_mode="hybrid", top_k=3)
# Test with different search modes
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="vector", top_k=3)
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="fts", top_k=3)
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="hybrid", top_k=3)
# Test explicit timestamp mode
timestamp_results = await client.query_passages(archive_id=archive_id, search_mode="timestamp", top_k=3)
timestamp_results = await client.query_passages(archive_id=archive_id, actor=default_user, search_mode="timestamp", top_k=3)
assert len(timestamp_results) <= 3
# Should return passages ordered by timestamp (most recent first)
assert all(isinstance(passage, Passage) for passage, _, _ in timestamp_results)
@@ -446,7 +443,7 @@ class TestTurbopufferIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
async def test_tag_filtering_with_real_tpuf(self, enable_turbopuffer):
async def test_tag_filtering_with_real_tpuf(self, default_user, enable_turbopuffer):
"""Test tag filtering functionality with AND and OR logic"""
import uuid
@@ -479,13 +476,13 @@ class TestTurbopufferIntegration:
passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts]
# Insert passages with tags
for i, (text, tags, embedding, passage_id) in enumerate(zip(texts, tag_sets, embeddings, passage_ids)):
for i, (text, tags, passage_id) in enumerate(zip(texts, tag_sets, passage_ids)):
await client.insert_archival_memories(
archive_id=archive_id,
text_chunks=[text],
embeddings=[embedding],
passage_ids=[passage_id],
organization_id=org_id,
actor=default_user,
tags=tags,
created_at=datetime.now(timezone.utc),
)
@@ -493,7 +490,8 @@ class TestTurbopufferIntegration:
# Test tag filtering with "any" mode (should find passages with any of the specified tags)
python_any_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 6.0, 11.0],
actor=default_user,
query_text="python programming",
search_mode="vector",
top_k=10,
tags=["python"],
@@ -511,7 +509,8 @@ class TestTurbopufferIntegration:
# Test tag filtering with "all" mode
python_tutorial_all_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 6.0, 11.0],
actor=default_user,
query_text="python tutorial",
search_mode="vector",
top_k=10,
tags=["python", "tutorial"],
@@ -528,6 +527,7 @@ class TestTurbopufferIntegration:
# Test tag filtering with FTS mode
js_fts_results = await client.query_passages(
archive_id=archive_id,
actor=default_user,
query_text="javascript",
search_mode="fts",
top_k=10,
@@ -545,7 +545,7 @@ class TestTurbopufferIntegration:
# Test hybrid search with tags
python_hybrid_results = await client.query_passages(
archive_id=archive_id,
query_embedding=[2.0, 7.0, 12.0],
actor=default_user,
query_text="python programming",
search_mode="hybrid",
top_k=10,
@@ -569,7 +569,7 @@ class TestTurbopufferIntegration:
pass
@pytest.mark.asyncio
async def test_temporal_filtering_with_real_tpuf(self, enable_turbopuffer):
async def test_temporal_filtering_with_real_tpuf(self, default_user, enable_turbopuffer):
"""Test temporal filtering with date ranges"""
from datetime import datetime, timedelta, timezone
@@ -601,15 +601,14 @@ class TestTurbopufferIntegration:
# We need to generate embeddings for the passages
# For testing, we'll use simple dummy embeddings
for text, timestamp in test_passages:
dummy_embedding = [1.0, 2.0, 3.0] # Simple test embedding
passage_id = f"passage-{uuid.uuid4()}"
await client.insert_archival_memories(
archive_id=archive_id,
text_chunks=[text],
embeddings=[dummy_embedding],
passage_ids=[passage_id],
organization_id="test-org",
actor=default_user,
created_at=timestamp,
)
@@ -617,7 +616,8 @@ class TestTurbopufferIntegration:
three_days_ago = now - timedelta(days=3)
results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="meeting notes",
search_mode="vector",
top_k=10,
start_date=three_days_ago,
@@ -637,7 +637,8 @@ class TestTurbopufferIntegration:
two_weeks_ago = now - timedelta(days=14)
results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="meeting notes",
search_mode="vector",
top_k=10,
start_date=two_weeks_ago,
@@ -652,7 +653,8 @@ class TestTurbopufferIntegration:
# Test 3: Query with only end_date (everything before yesterday)
results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="meeting notes",
search_mode="vector",
top_k=10,
end_date=yesterday + timedelta(hours=12), # Middle of yesterday
@@ -667,6 +669,7 @@ class TestTurbopufferIntegration:
# Test 4: Test with FTS mode and date filtering
results = await client.query_passages(
archive_id=archive_id,
actor=default_user,
query_text="meeting notes project",
search_mode="fts",
top_k=10,
@@ -682,7 +685,7 @@ class TestTurbopufferIntegration:
# Test 5: Test with hybrid mode and date filtering
results = await client.query_passages(
archive_id=archive_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="sprint review",
search_mode="hybrid",
top_k=10,
@@ -934,11 +937,9 @@ class TestTurbopufferMessagesIntegration:
),
]
# Create messages without embedding_config
created = await server.message_manager.create_many_messages_async(
pydantic_msgs=messages,
actor=default_user,
embedding_config=None, # No config provided
)
assert len(created) == 2
@@ -1057,7 +1058,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding):
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test actual message embedding and storage in Turbopuffer"""
import uuid
from datetime import datetime, timezone
@@ -1087,9 +1088,9 @@ class TestTurbopufferMessagesIntegration:
success = await client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=org_id,
actor=default_user,
roles=roles,
created_ats=created_ats,
)
@@ -1097,11 +1098,8 @@ class TestTurbopufferMessagesIntegration:
assert success == True
# Verify we can query the messages
results = await client.query_messages(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
results = await client.query_messages_by_agent_id(
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, actor=default_user
)
assert len(results) == 3
@@ -1121,7 +1119,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding):
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test vector search on messages in Turbopuffer"""
import uuid
from datetime import datetime, timezone
@@ -1145,29 +1143,23 @@ class TestTurbopufferMessagesIntegration:
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
# Create embeddings that reflect content similarity
embeddings = [
[1.0, 0.0, 0.0], # Python programming
[0.0, 1.0, 0.0], # JavaScript web
[0.8, 0.0, 0.2], # ML with Python (similar to first)
]
# Insert messages
await client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=org_id,
actor=default_user,
roles=roles,
created_ats=created_ats,
)
# Search for Python-related messages using vector search
query_embedding = [0.9, 0.0, 0.1] # Similar to Python messages
results = await client.query_messages(
results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
query_embedding=query_embedding,
actor=default_user,
query_text="Python programming",
search_mode="vector",
top_k=2,
)
@@ -1187,7 +1179,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding):
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test hybrid search combining vector and FTS for messages"""
import uuid
from datetime import datetime, timezone
@@ -1211,30 +1203,22 @@ class TestTurbopufferMessagesIntegration:
roles = [MessageRole.assistant] * len(message_texts)
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
# Embeddings
embeddings = [
[0.1, 0.9, 0.0], # fox text
[0.9, 0.1, 0.0], # ML algorithms
[0.5, 0.5, 0.0], # Quick Python
[0.8, 0.2, 0.0], # Deep learning
]
# Insert messages
await client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=org_id,
actor=default_user,
roles=roles,
created_ats=created_ats,
)
# Hybrid search - vector similar to ML but text contains "quick"
results = await client.query_messages(
# Hybrid search - text search for "quick"
results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
query_embedding=[0.7, 0.3, 0.0], # Similar to ML messages
actor=default_user,
query_text="quick", # Text search for "quick"
search_mode="hybrid",
top_k=3,
@@ -1257,7 +1241,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding):
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test filtering messages by role"""
import uuid
from datetime import datetime, timezone
@@ -1283,26 +1267,21 @@ class TestTurbopufferMessagesIntegration:
roles = [role for _, role in message_data]
message_ids = [str(uuid.uuid4()) for _ in message_texts]
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]
# Insert messages
await client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=org_id,
actor=default_user,
roles=roles,
created_ats=created_ats,
)
# Query only user messages
user_results = await client.query_messages(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
roles=[MessageRole.user],
user_results = await client.query_messages_by_agent_id(
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, roles=[MessageRole.user], actor=default_user
)
assert len(user_results) == 2
@@ -1311,12 +1290,13 @@ class TestTurbopufferMessagesIntegration:
assert msg["text"] in ["I need help with Python", "Can you explain this?"]
# Query assistant and system messages
non_user_results = await client.query_messages(
non_user_results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
roles=[MessageRole.assistant, MessageRole.system],
actor=default_user,
)
assert len(non_user_results) == 3
@@ -1395,7 +1375,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
@@ -1409,7 +1388,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Python",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(python_results) > 0
assert any(msg.id == message_id for msg, metadata in python_results)
@@ -1419,7 +1397,6 @@ class TestTurbopufferMessagesIntegration:
message_id=message_id,
message_update=MessageUpdate(content="Updated content about JavaScript development"),
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
@@ -1432,7 +1409,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Python",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
# Should either find no results or results that don't include our message
assert not any(msg.id == message_id for msg, metadata in python_results_after)
@@ -1444,7 +1420,6 @@ class TestTurbopufferMessagesIntegration:
query_text="JavaScript",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(js_results) > 0
assert any(msg.id == message_id for msg, metadata in js_results)
@@ -1497,7 +1472,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
agent_a_messages.extend(msgs)
@@ -1514,7 +1488,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
agent_b_messages.extend(msgs)
@@ -1526,7 +1499,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Agent A",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(agent_a_search) == 5
@@ -1536,7 +1508,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Agent B",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(agent_b_search) == 3
@@ -1559,7 +1530,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Agent A",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(agent_a_final) == 2
# Verify the remaining messages are the correct ones
@@ -1574,7 +1544,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Agent B",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(agent_b_final) == 0
@@ -1583,84 +1552,6 @@ class TestTurbopufferMessagesIntegration:
await server.agent_manager.delete_agent_async(agent_a.id, default_user)
await server.agent_manager.delete_agent_async(agent_b.id, default_user)
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_crud_operations_without_embedding_config(self, server, default_user, sarah_agent, enable_message_embedding):
"""Test that CRUD operations handle missing embedding_config gracefully"""
from letta.schemas.message import MessageUpdate
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
# Create message WITH embedding_config
messages = await server.message_manager.create_many_messages_async(
pydantic_msgs=[
PydanticMessage(
role=MessageRole.user,
content=[TextContent(text="Message with searchable content about databases")],
agent_id=sarah_agent.id,
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True,
)
assert len(messages) == 1
message_id = messages[0].id
# Verify message is searchable initially
initial_search = await server.message_manager.search_messages_async(
agent_id=sarah_agent.id,
actor=default_user,
query_text="databases",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(initial_search) > 0
assert any(msg.id == message_id for msg, metadata in initial_search)
# Update message WITHOUT embedding_config - should update postgres but not turbopuffer
updated_message = await server.message_manager.update_message_by_id_async(
message_id=message_id,
message_update=MessageUpdate(content="Updated content about algorithms"),
actor=default_user,
embedding_config=None, # No config provided
)
# Verify postgres was updated
assert updated_message.id == message_id
updated_text = server.message_manager._extract_message_text(updated_message)
assert "algorithms" in updated_text
assert "databases" not in updated_text
# Original search term should STILL find the message (turbopuffer wasn't updated)
still_searchable = await server.message_manager.search_messages_async(
agent_id=sarah_agent.id,
actor=default_user,
query_text="databases",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(still_searchable) > 0
assert any(msg.id == message_id for msg, metadata in still_searchable)
# New content should NOT be searchable (wasn't re-indexed)
not_searchable = await server.message_manager.search_messages_async(
agent_id=sarah_agent.id,
actor=default_user,
query_text="algorithms",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
# Should either find no results or results that don't include our message
assert not any(msg.id == message_id for msg, metadata in not_searchable)
# Clean up
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_turbopuffer_failure_does_not_break_postgres(self, server, default_user, sarah_agent, enable_message_embedding):
@@ -1681,7 +1572,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
)
assert len(messages) == 1
@@ -1702,7 +1592,6 @@ class TestTurbopufferMessagesIntegration:
message_id=message_id,
message_update=MessageUpdate(content="Updated despite turbopuffer failure"),
actor=default_user,
embedding_config=embedding_config,
strict_mode=False, # Don't fail on turbopuffer errors - that's what we're testing!
)
@@ -1722,7 +1611,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=None, # Create without embedding to avoid mock issues
)
message_to_delete_id = messages2[0].id
@@ -1741,7 +1629,7 @@ class TestTurbopufferMessagesIntegration:
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False)
async def wait_for_embedding(
self, agent_id: str, message_id: str, organization_id: str, max_wait: float = 10.0, poll_interval: float = 0.5
self, agent_id: str, message_id: str, organization_id: str, actor, max_wait: float = 10.0, poll_interval: float = 0.5
) -> bool:
"""Poll Turbopuffer directly to check if a message has been embedded.
@@ -1765,9 +1653,10 @@ class TestTurbopufferMessagesIntegration:
while asyncio.get_event_loop().time() - start_time < max_wait:
try:
# Query Turbopuffer directly using timestamp mode to get all messages
results = await client.query_messages(
results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=organization_id,
actor=actor,
search_mode="timestamp",
top_k=100, # Get more messages to ensure we find it
)
@@ -1800,7 +1689,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=False, # Background mode
)
@@ -1814,7 +1702,12 @@ class TestTurbopufferMessagesIntegration:
# Poll for embedding completion by querying Turbopuffer directly
embedded = await self.wait_for_embedding(
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
agent_id=sarah_agent.id,
message_id=message_id,
organization_id=default_user.organization_id,
actor=default_user,
max_wait=10.0,
poll_interval=0.5,
)
assert embedded, "Message was not embedded in Turbopuffer within timeout"
@@ -1825,7 +1718,6 @@ class TestTurbopufferMessagesIntegration:
query_text="Python programming",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert len(search_results) > 0
assert any(msg.id == message_id for msg, _ in search_results)
@@ -1851,7 +1743,6 @@ class TestTurbopufferMessagesIntegration:
)
],
actor=default_user,
embedding_config=embedding_config,
strict_mode=True, # Ensure initial embedding
)
@@ -1865,7 +1756,6 @@ class TestTurbopufferMessagesIntegration:
query_text="databases",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert any(msg.id == message_id for msg, _ in initial_results)
@@ -1874,7 +1764,6 @@ class TestTurbopufferMessagesIntegration:
message_id=message_id,
message_update=MessageUpdate(content="Updated content about machine learning"),
actor=default_user,
embedding_config=embedding_config,
strict_mode=False, # Background mode
)
@@ -1890,7 +1779,12 @@ class TestTurbopufferMessagesIntegration:
# Poll for the update to be reflected in Turbopuffer
# We check by searching for the new content
embedded = await self.wait_for_embedding(
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
agent_id=sarah_agent.id,
message_id=message_id,
organization_id=default_user.organization_id,
actor=default_user,
max_wait=10.0,
poll_interval=0.5,
)
assert embedded, "Updated message was not re-embedded within timeout"
@@ -1901,7 +1795,6 @@ class TestTurbopufferMessagesIntegration:
query_text="machine learning",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
assert any(msg.id == message_id for msg, _ in new_results)
@@ -1914,7 +1807,6 @@ class TestTurbopufferMessagesIntegration:
query_text="databases",
search_mode="fts",
limit=10,
embedding_config=embedding_config,
)
# The message shouldn't match the old search term anymore
if len(old_results) > 0:
@@ -1929,7 +1821,7 @@ class TestTurbopufferMessagesIntegration:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding, default_user):
"""Test filtering messages by date range"""
import uuid
from datetime import datetime, timedelta, timezone
@@ -1959,21 +1851,17 @@ class TestTurbopufferMessagesIntegration:
await client.insert_messages(
agent_id=agent_id,
message_texts=[text],
embeddings=[[1.0, 2.0, 3.0]],
message_ids=[str(uuid.uuid4())],
organization_id=org_id,
actor=default_user,
roles=[MessageRole.assistant],
created_ats=[timestamp],
)
# Query messages from the last 3 days
three_days_ago = now - timedelta(days=3)
recent_results = await client.query_messages(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
start_date=three_days_ago,
recent_results = await client.query_messages_by_agent_id(
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, start_date=three_days_ago, actor=default_user
)
# Should get today's and yesterday's messages
@@ -1984,13 +1872,14 @@ class TestTurbopufferMessagesIntegration:
# Query messages between 2 weeks ago and 1 week ago
two_weeks_ago = now - timedelta(days=14)
week_results = await client.query_messages(
week_results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
search_mode="timestamp",
top_k=10,
start_date=two_weeks_ago,
end_date=last_week + timedelta(days=1), # Include last week's message
actor=default_user,
)
# Should get only last week's message
@@ -1998,10 +1887,11 @@ class TestTurbopufferMessagesIntegration:
assert week_results[0][0]["text"] == "Last week's message"
# Query with vector search and date filtering
filtered_vector_results = await client.query_messages(
filtered_vector_results = await client.query_messages_by_agent_id(
agent_id=agent_id,
organization_id=org_id,
query_embedding=[1.0, 2.0, 3.0],
actor=default_user,
query_text="message",
search_mode="vector",
top_k=10,
start_date=three_days_ago,
@@ -2101,7 +1991,7 @@ class TestNamespaceTracking:
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_project_id_filtering(self, server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding):
"""Test that project_id filtering works correctly in query_messages"""
"""Test that project_id filtering works correctly in query_messages_by_agent_id"""
from letta.schemas.letta_message_content import TextContent
# Create two project IDs
@@ -2124,24 +2014,15 @@ class TestNamespaceTracking:
# Insert messages with their respective project IDs
tpuf_client = TurbopufferClient()
# Generate embeddings
from letta.llm_api.llm_client import LLMClient
embedding_client = LLMClient.create(
provider_type=sarah_agent.embedding_config.embedding_endpoint_type,
actor=default_user,
)
embeddings = await embedding_client.request_embeddings(
[message_a.content[0].text, message_b.content[0].text], sarah_agent.embedding_config
)
# Embeddings will be generated automatically by the client
# Insert message A with project_a_id
await tpuf_client.insert_messages(
agent_id=sarah_agent.id,
message_texts=[message_a.content[0].text],
embeddings=[embeddings[0]],
message_ids=[message_a.id],
organization_id=default_user.organization_id,
actor=default_user,
roles=[message_a.role],
created_ats=[message_a.created_at],
project_id=project_a_id,
@@ -2151,9 +2032,9 @@ class TestNamespaceTracking:
await tpuf_client.insert_messages(
agent_id=sarah_agent.id,
message_texts=[message_b.content[0].text],
embeddings=[embeddings[1]],
message_ids=[message_b.id],
organization_id=default_user.organization_id,
actor=default_user,
roles=[message_b.role],
created_ats=[message_b.created_at],
project_id=project_b_id,
@@ -2162,12 +2043,13 @@ class TestNamespaceTracking:
# Poll for message A with project_a_id filter
max_retries = 10
for i in range(max_retries):
results_a = await tpuf_client.query_messages(
results_a = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp", # Simple timestamp retrieval
top_k=10,
project_id=project_a_id,
actor=default_user,
)
if len(results_a) == 1 and results_a[0][0]["id"] == message_a.id:
break
@@ -2179,12 +2061,13 @@ class TestNamespaceTracking:
# Poll for message B with project_b_id filter
for i in range(max_retries):
results_b = await tpuf_client.query_messages(
results_b = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
project_id=project_b_id,
actor=default_user,
)
if len(results_b) == 1 and results_b[0][0]["id"] == message_b.id:
break
@@ -2195,12 +2078,13 @@ class TestNamespaceTracking:
assert "JavaScript" in results_b[0][0]["text"]
# Query without project filter - should find both
results_all = await tpuf_client.query_messages(
results_all = await tpuf_client.query_messages_by_agent_id(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
project_id=None, # No filter
actor=default_user,
)
assert len(results_all) >= 2 # May have other messages from setup