feat: Add search messages endpoint [LET-4144] (#4434)
* Add search messages endpoint * Run fern autogen and fix tests
This commit is contained in:
@@ -233,7 +233,7 @@ class TestTurbopufferIntegration:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_turbopuffer_metadata_attributes(self, enable_turbopuffer):
|
||||
async def test_turbopuffer_metadata_attributes(self, default_user, enable_turbopuffer):
|
||||
"""Test that Turbopuffer properly stores and retrieves metadata attributes"""
|
||||
|
||||
# Only run if we have a real API key
|
||||
@@ -273,17 +273,16 @@ class TestTurbopufferIntegration:
|
||||
result = await client.insert_archival_memories(
|
||||
archive_id=archive_id,
|
||||
text_chunks=[d["text"] for d in test_data],
|
||||
embeddings=[d["vector"] for d in test_data],
|
||||
passage_ids=[d["id"] for d in test_data],
|
||||
organization_id="org-123", # Default org
|
||||
actor=default_user,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
# Query all passages (no tag filtering)
|
||||
query_vector = [0.15] * 1536
|
||||
results = await client.query_passages(archive_id=archive_id, query_embedding=query_vector, top_k=10)
|
||||
results = await client.query_passages(archive_id=archive_id, actor=default_user, top_k=10)
|
||||
|
||||
# Should get all passages
|
||||
assert len(results) == 3 # All three passages
|
||||
@@ -339,7 +338,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
|
||||
async def test_hybrid_search_with_real_tpuf(self, enable_turbopuffer):
|
||||
async def test_hybrid_search_with_real_tpuf(self, default_user, enable_turbopuffer):
|
||||
"""Test hybrid search functionality combining vector and full-text search"""
|
||||
|
||||
import uuid
|
||||
@@ -366,13 +365,14 @@ class TestTurbopufferIntegration:
|
||||
|
||||
# Insert passages
|
||||
await client.insert_archival_memories(
|
||||
archive_id=archive_id, text_chunks=texts, embeddings=embeddings, passage_ids=passage_ids, organization_id=org_id
|
||||
archive_id=archive_id, text_chunks=texts, passage_ids=passage_ids, organization_id=org_id, actor=default_user
|
||||
)
|
||||
|
||||
# Test vector-only search
|
||||
vector_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 6.0, 11.0], # similar to second passage embedding
|
||||
actor=default_user,
|
||||
query_text="python programming tutorial",
|
||||
search_mode="vector",
|
||||
top_k=3,
|
||||
)
|
||||
@@ -382,7 +382,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
# Test FTS-only search
|
||||
fts_results = await client.query_passages(
|
||||
archive_id=archive_id, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
|
||||
archive_id=archive_id, actor=default_user, query_text="Turbopuffer vector database", search_mode="fts", top_k=3
|
||||
)
|
||||
assert 0 < len(fts_results) <= 3
|
||||
# should find passages mentioning Turbopuffer
|
||||
@@ -393,7 +393,7 @@ class TestTurbopufferIntegration:
|
||||
# Test hybrid search
|
||||
hybrid_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[2.0, 7.0, 12.0],
|
||||
actor=default_user,
|
||||
query_text="vector search Turbopuffer",
|
||||
search_mode="hybrid",
|
||||
top_k=3,
|
||||
@@ -412,7 +412,7 @@ class TestTurbopufferIntegration:
|
||||
# Test with different weights
|
||||
vector_heavy_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[0.0, 5.0, 10.0], # very similar to first passage
|
||||
actor=default_user,
|
||||
query_text="quick brown fox", # matches second passage
|
||||
search_mode="hybrid",
|
||||
top_k=3,
|
||||
@@ -423,16 +423,13 @@ class TestTurbopufferIntegration:
|
||||
# all results should have scores
|
||||
assert all(isinstance(score, float) for _, score, _ in vector_heavy_results)
|
||||
|
||||
# Test error handling - missing text for hybrid mode (embedding provided but text missing)
|
||||
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
|
||||
await client.query_passages(archive_id=archive_id, query_embedding=[1.0, 2.0, 3.0], search_mode="hybrid", top_k=3)
|
||||
|
||||
# Test error handling - missing embedding for hybrid mode (text provided but embedding missing)
|
||||
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
|
||||
await client.query_passages(archive_id=archive_id, query_text="test", search_mode="hybrid", top_k=3)
|
||||
# Test with different search modes
|
||||
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="vector", top_k=3)
|
||||
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="fts", top_k=3)
|
||||
await client.query_passages(archive_id=archive_id, actor=default_user, query_text="test", search_mode="hybrid", top_k=3)
|
||||
|
||||
# Test explicit timestamp mode
|
||||
timestamp_results = await client.query_passages(archive_id=archive_id, search_mode="timestamp", top_k=3)
|
||||
timestamp_results = await client.query_passages(archive_id=archive_id, actor=default_user, search_mode="timestamp", top_k=3)
|
||||
assert len(timestamp_results) <= 3
|
||||
# Should return passages ordered by timestamp (most recent first)
|
||||
assert all(isinstance(passage, Passage) for passage, _, _ in timestamp_results)
|
||||
@@ -446,7 +443,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured for testing")
|
||||
async def test_tag_filtering_with_real_tpuf(self, enable_turbopuffer):
|
||||
async def test_tag_filtering_with_real_tpuf(self, default_user, enable_turbopuffer):
|
||||
"""Test tag filtering functionality with AND and OR logic"""
|
||||
|
||||
import uuid
|
||||
@@ -479,13 +476,13 @@ class TestTurbopufferIntegration:
|
||||
passage_ids = [f"passage-{str(uuid.uuid4())}" for _ in texts]
|
||||
|
||||
# Insert passages with tags
|
||||
for i, (text, tags, embedding, passage_id) in enumerate(zip(texts, tag_sets, embeddings, passage_ids)):
|
||||
for i, (text, tags, passage_id) in enumerate(zip(texts, tag_sets, passage_ids)):
|
||||
await client.insert_archival_memories(
|
||||
archive_id=archive_id,
|
||||
text_chunks=[text],
|
||||
embeddings=[embedding],
|
||||
passage_ids=[passage_id],
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
tags=tags,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -493,7 +490,8 @@ class TestTurbopufferIntegration:
|
||||
# Test tag filtering with "any" mode (should find passages with any of the specified tags)
|
||||
python_any_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 6.0, 11.0],
|
||||
actor=default_user,
|
||||
query_text="python programming",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
tags=["python"],
|
||||
@@ -511,7 +509,8 @@ class TestTurbopufferIntegration:
|
||||
# Test tag filtering with "all" mode
|
||||
python_tutorial_all_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 6.0, 11.0],
|
||||
actor=default_user,
|
||||
query_text="python tutorial",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
tags=["python", "tutorial"],
|
||||
@@ -528,6 +527,7 @@ class TestTurbopufferIntegration:
|
||||
# Test tag filtering with FTS mode
|
||||
js_fts_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
actor=default_user,
|
||||
query_text="javascript",
|
||||
search_mode="fts",
|
||||
top_k=10,
|
||||
@@ -545,7 +545,7 @@ class TestTurbopufferIntegration:
|
||||
# Test hybrid search with tags
|
||||
python_hybrid_results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[2.0, 7.0, 12.0],
|
||||
actor=default_user,
|
||||
query_text="python programming",
|
||||
search_mode="hybrid",
|
||||
top_k=10,
|
||||
@@ -569,7 +569,7 @@ class TestTurbopufferIntegration:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_filtering_with_real_tpuf(self, enable_turbopuffer):
|
||||
async def test_temporal_filtering_with_real_tpuf(self, default_user, enable_turbopuffer):
|
||||
"""Test temporal filtering with date ranges"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
@@ -601,15 +601,14 @@ class TestTurbopufferIntegration:
|
||||
# We need to generate embeddings for the passages
|
||||
# For testing, we'll use simple dummy embeddings
|
||||
for text, timestamp in test_passages:
|
||||
dummy_embedding = [1.0, 2.0, 3.0] # Simple test embedding
|
||||
passage_id = f"passage-{uuid.uuid4()}"
|
||||
|
||||
await client.insert_archival_memories(
|
||||
archive_id=archive_id,
|
||||
text_chunks=[text],
|
||||
embeddings=[dummy_embedding],
|
||||
passage_ids=[passage_id],
|
||||
organization_id="test-org",
|
||||
actor=default_user,
|
||||
created_at=timestamp,
|
||||
)
|
||||
|
||||
@@ -617,7 +616,8 @@ class TestTurbopufferIntegration:
|
||||
three_days_ago = now - timedelta(days=3)
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="meeting notes",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
start_date=three_days_ago,
|
||||
@@ -637,7 +637,8 @@ class TestTurbopufferIntegration:
|
||||
two_weeks_ago = now - timedelta(days=14)
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="meeting notes",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
start_date=two_weeks_ago,
|
||||
@@ -652,7 +653,8 @@ class TestTurbopufferIntegration:
|
||||
# Test 3: Query with only end_date (everything before yesterday)
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="meeting notes",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
end_date=yesterday + timedelta(hours=12), # Middle of yesterday
|
||||
@@ -667,6 +669,7 @@ class TestTurbopufferIntegration:
|
||||
# Test 4: Test with FTS mode and date filtering
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
actor=default_user,
|
||||
query_text="meeting notes project",
|
||||
search_mode="fts",
|
||||
top_k=10,
|
||||
@@ -682,7 +685,7 @@ class TestTurbopufferIntegration:
|
||||
# Test 5: Test with hybrid mode and date filtering
|
||||
results = await client.query_passages(
|
||||
archive_id=archive_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="sprint review",
|
||||
search_mode="hybrid",
|
||||
top_k=10,
|
||||
@@ -934,11 +937,9 @@ class TestTurbopufferMessagesIntegration:
|
||||
),
|
||||
]
|
||||
|
||||
# Create messages without embedding_config
|
||||
created = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=messages,
|
||||
actor=default_user,
|
||||
embedding_config=None, # No config provided
|
||||
)
|
||||
|
||||
assert len(created) == 2
|
||||
@@ -1057,7 +1058,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_dual_write_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test actual message embedding and storage in Turbopuffer"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -1087,9 +1088,9 @@ class TestTurbopufferMessagesIntegration:
|
||||
success = await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
@@ -1097,11 +1098,8 @@ class TestTurbopufferMessagesIntegration:
|
||||
assert success == True
|
||||
|
||||
# Verify we can query the messages
|
||||
results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, actor=default_user
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
@@ -1121,7 +1119,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_vector_search_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test vector search on messages in Turbopuffer"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -1145,29 +1143,23 @@ class TestTurbopufferMessagesIntegration:
|
||||
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
|
||||
|
||||
# Create embeddings that reflect content similarity
|
||||
embeddings = [
|
||||
[1.0, 0.0, 0.0], # Python programming
|
||||
[0.0, 1.0, 0.0], # JavaScript web
|
||||
[0.8, 0.0, 0.2], # ML with Python (similar to first)
|
||||
]
|
||||
|
||||
# Insert messages
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
|
||||
# Search for Python-related messages using vector search
|
||||
query_embedding = [0.9, 0.0, 0.1] # Similar to Python messages
|
||||
results = await client.query_messages(
|
||||
results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
query_embedding=query_embedding,
|
||||
actor=default_user,
|
||||
query_text="Python programming",
|
||||
search_mode="vector",
|
||||
top_k=2,
|
||||
)
|
||||
@@ -1187,7 +1179,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_hybrid_search_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test hybrid search combining vector and FTS for messages"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -1211,30 +1203,22 @@ class TestTurbopufferMessagesIntegration:
|
||||
roles = [MessageRole.assistant] * len(message_texts)
|
||||
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
|
||||
|
||||
# Embeddings
|
||||
embeddings = [
|
||||
[0.1, 0.9, 0.0], # fox text
|
||||
[0.9, 0.1, 0.0], # ML algorithms
|
||||
[0.5, 0.5, 0.0], # Quick Python
|
||||
[0.8, 0.2, 0.0], # Deep learning
|
||||
]
|
||||
|
||||
# Insert messages
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
|
||||
# Hybrid search - vector similar to ML but text contains "quick"
|
||||
results = await client.query_messages(
|
||||
# Hybrid search - text search for "quick"
|
||||
results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
query_embedding=[0.7, 0.3, 0.0], # Similar to ML messages
|
||||
actor=default_user,
|
||||
query_text="quick", # Text search for "quick"
|
||||
search_mode="hybrid",
|
||||
top_k=3,
|
||||
@@ -1257,7 +1241,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_role_filtering_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test filtering messages by role"""
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
@@ -1283,26 +1267,21 @@ class TestTurbopufferMessagesIntegration:
|
||||
roles = [role for _, role in message_data]
|
||||
message_ids = [str(uuid.uuid4()) for _ in message_texts]
|
||||
created_ats = [datetime.now(timezone.utc) for _ in message_texts]
|
||||
embeddings = [[float(i), float(i + 1), float(i + 2)] for i in range(len(message_texts))]
|
||||
|
||||
# Insert messages
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
|
||||
# Query only user messages
|
||||
user_results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
roles=[MessageRole.user],
|
||||
user_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, roles=[MessageRole.user], actor=default_user
|
||||
)
|
||||
|
||||
assert len(user_results) == 2
|
||||
@@ -1311,12 +1290,13 @@ class TestTurbopufferMessagesIntegration:
|
||||
assert msg["text"] in ["I need help with Python", "Can you explain this?"]
|
||||
|
||||
# Query assistant and system messages
|
||||
non_user_results = await client.query_messages(
|
||||
non_user_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
roles=[MessageRole.assistant, MessageRole.system],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(non_user_results) == 3
|
||||
@@ -1395,7 +1375,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
@@ -1409,7 +1388,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Python",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(python_results) > 0
|
||||
assert any(msg.id == message_id for msg, metadata in python_results)
|
||||
@@ -1419,7 +1397,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about JavaScript development"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
@@ -1432,7 +1409,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Python",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# Should either find no results or results that don't include our message
|
||||
assert not any(msg.id == message_id for msg, metadata in python_results_after)
|
||||
@@ -1444,7 +1420,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="JavaScript",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(js_results) > 0
|
||||
assert any(msg.id == message_id for msg, metadata in js_results)
|
||||
@@ -1497,7 +1472,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
agent_a_messages.extend(msgs)
|
||||
@@ -1514,7 +1488,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
agent_b_messages.extend(msgs)
|
||||
@@ -1526,7 +1499,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Agent A",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_a_search) == 5
|
||||
|
||||
@@ -1536,7 +1508,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Agent B",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_b_search) == 3
|
||||
|
||||
@@ -1559,7 +1530,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Agent A",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_a_final) == 2
|
||||
# Verify the remaining messages are the correct ones
|
||||
@@ -1574,7 +1544,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Agent B",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_b_final) == 0
|
||||
|
||||
@@ -1583,84 +1552,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
await server.agent_manager.delete_agent_async(agent_a.id, default_user)
|
||||
await server.agent_manager.delete_agent_async(agent_b.id, default_user)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_crud_operations_without_embedding_config(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
"""Test that CRUD operations handle missing embedding_config gracefully"""
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create message WITH embedding_config
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Message with searchable content about databases")],
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
message_id = messages[0].id
|
||||
|
||||
# Verify message is searchable initially
|
||||
initial_search = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(initial_search) > 0
|
||||
assert any(msg.id == message_id for msg, metadata in initial_search)
|
||||
|
||||
# Update message WITHOUT embedding_config - should update postgres but not turbopuffer
|
||||
updated_message = await server.message_manager.update_message_by_id_async(
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about algorithms"),
|
||||
actor=default_user,
|
||||
embedding_config=None, # No config provided
|
||||
)
|
||||
|
||||
# Verify postgres was updated
|
||||
assert updated_message.id == message_id
|
||||
updated_text = server.message_manager._extract_message_text(updated_message)
|
||||
assert "algorithms" in updated_text
|
||||
assert "databases" not in updated_text
|
||||
|
||||
# Original search term should STILL find the message (turbopuffer wasn't updated)
|
||||
still_searchable = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(still_searchable) > 0
|
||||
assert any(msg.id == message_id for msg, metadata in still_searchable)
|
||||
|
||||
# New content should NOT be searchable (wasn't re-indexed)
|
||||
not_searchable = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="algorithms",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# Should either find no results or results that don't include our message
|
||||
assert not any(msg.id == message_id for msg, metadata in not_searchable)
|
||||
|
||||
# Clean up
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_turbopuffer_failure_does_not_break_postgres(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
@@ -1681,7 +1572,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
@@ -1702,7 +1592,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated despite turbopuffer failure"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Don't fail on turbopuffer errors - that's what we're testing!
|
||||
)
|
||||
|
||||
@@ -1722,7 +1611,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=None, # Create without embedding to avoid mock issues
|
||||
)
|
||||
message_to_delete_id = messages2[0].id
|
||||
|
||||
@@ -1741,7 +1629,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False)
|
||||
|
||||
async def wait_for_embedding(
|
||||
self, agent_id: str, message_id: str, organization_id: str, max_wait: float = 10.0, poll_interval: float = 0.5
|
||||
self, agent_id: str, message_id: str, organization_id: str, actor, max_wait: float = 10.0, poll_interval: float = 0.5
|
||||
) -> bool:
|
||||
"""Poll Turbopuffer directly to check if a message has been embedded.
|
||||
|
||||
@@ -1765,9 +1653,10 @@ class TestTurbopufferMessagesIntegration:
|
||||
while asyncio.get_event_loop().time() - start_time < max_wait:
|
||||
try:
|
||||
# Query Turbopuffer directly using timestamp mode to get all messages
|
||||
results = await client.query_messages(
|
||||
results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=organization_id,
|
||||
actor=actor,
|
||||
search_mode="timestamp",
|
||||
top_k=100, # Get more messages to ensure we find it
|
||||
)
|
||||
@@ -1800,7 +1689,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Background mode
|
||||
)
|
||||
|
||||
@@ -1814,7 +1702,12 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
# Poll for embedding completion by querying Turbopuffer directly
|
||||
embedded = await self.wait_for_embedding(
|
||||
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
|
||||
agent_id=sarah_agent.id,
|
||||
message_id=message_id,
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
max_wait=10.0,
|
||||
poll_interval=0.5,
|
||||
)
|
||||
assert embedded, "Message was not embedded in Turbopuffer within timeout"
|
||||
|
||||
@@ -1825,7 +1718,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="Python programming",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(search_results) > 0
|
||||
assert any(msg.id == message_id for msg, _ in search_results)
|
||||
@@ -1851,7 +1743,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True, # Ensure initial embedding
|
||||
)
|
||||
|
||||
@@ -1865,7 +1756,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert any(msg.id == message_id for msg, _ in initial_results)
|
||||
|
||||
@@ -1874,7 +1764,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about machine learning"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Background mode
|
||||
)
|
||||
|
||||
@@ -1890,7 +1779,12 @@ class TestTurbopufferMessagesIntegration:
|
||||
# Poll for the update to be reflected in Turbopuffer
|
||||
# We check by searching for the new content
|
||||
embedded = await self.wait_for_embedding(
|
||||
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
|
||||
agent_id=sarah_agent.id,
|
||||
message_id=message_id,
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
max_wait=10.0,
|
||||
poll_interval=0.5,
|
||||
)
|
||||
assert embedded, "Updated message was not re-embedded within timeout"
|
||||
|
||||
@@ -1901,7 +1795,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="machine learning",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert any(msg.id == message_id for msg, _ in new_results)
|
||||
|
||||
@@ -1914,7 +1807,6 @@ class TestTurbopufferMessagesIntegration:
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# The message shouldn't match the old search term anymore
|
||||
if len(old_results) > 0:
|
||||
@@ -1929,7 +1821,7 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):
|
||||
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding, default_user):
|
||||
"""Test filtering messages by date range"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -1959,21 +1851,17 @@ class TestTurbopufferMessagesIntegration:
|
||||
await client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=[text],
|
||||
embeddings=[[1.0, 2.0, 3.0]],
|
||||
message_ids=[str(uuid.uuid4())],
|
||||
organization_id=org_id,
|
||||
actor=default_user,
|
||||
roles=[MessageRole.assistant],
|
||||
created_ats=[timestamp],
|
||||
)
|
||||
|
||||
# Query messages from the last 3 days
|
||||
three_days_ago = now - timedelta(days=3)
|
||||
recent_results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
start_date=three_days_ago,
|
||||
recent_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id, organization_id=org_id, search_mode="timestamp", top_k=10, start_date=three_days_ago, actor=default_user
|
||||
)
|
||||
|
||||
# Should get today's and yesterday's messages
|
||||
@@ -1984,13 +1872,14 @@ class TestTurbopufferMessagesIntegration:
|
||||
|
||||
# Query messages between 2 weeks ago and 1 week ago
|
||||
two_weeks_ago = now - timedelta(days=14)
|
||||
week_results = await client.query_messages(
|
||||
week_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
start_date=two_weeks_ago,
|
||||
end_date=last_week + timedelta(days=1), # Include last week's message
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should get only last week's message
|
||||
@@ -1998,10 +1887,11 @@ class TestTurbopufferMessagesIntegration:
|
||||
assert week_results[0][0]["text"] == "Last week's message"
|
||||
|
||||
# Query with vector search and date filtering
|
||||
filtered_vector_results = await client.query_messages(
|
||||
filtered_vector_results = await client.query_messages_by_agent_id(
|
||||
agent_id=agent_id,
|
||||
organization_id=org_id,
|
||||
query_embedding=[1.0, 2.0, 3.0],
|
||||
actor=default_user,
|
||||
query_text="message",
|
||||
search_mode="vector",
|
||||
top_k=10,
|
||||
start_date=three_days_ago,
|
||||
@@ -2101,7 +1991,7 @@ class TestNamespaceTracking:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_project_id_filtering(self, server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding):
|
||||
"""Test that project_id filtering works correctly in query_messages"""
|
||||
"""Test that project_id filtering works correctly in query_messages_by_agent_id"""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
|
||||
# Create two project IDs
|
||||
@@ -2124,24 +2014,15 @@ class TestNamespaceTracking:
|
||||
# Insert messages with their respective project IDs
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# Generate embeddings
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=sarah_agent.embedding_config.embedding_endpoint_type,
|
||||
actor=default_user,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(
|
||||
[message_a.content[0].text, message_b.content[0].text], sarah_agent.embedding_config
|
||||
)
|
||||
# Embeddings will be generated automatically by the client
|
||||
|
||||
# Insert message A with project_a_id
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
message_texts=[message_a.content[0].text],
|
||||
embeddings=[embeddings[0]],
|
||||
message_ids=[message_a.id],
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
roles=[message_a.role],
|
||||
created_ats=[message_a.created_at],
|
||||
project_id=project_a_id,
|
||||
@@ -2151,9 +2032,9 @@ class TestNamespaceTracking:
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
message_texts=[message_b.content[0].text],
|
||||
embeddings=[embeddings[1]],
|
||||
message_ids=[message_b.id],
|
||||
organization_id=default_user.organization_id,
|
||||
actor=default_user,
|
||||
roles=[message_b.role],
|
||||
created_ats=[message_b.created_at],
|
||||
project_id=project_b_id,
|
||||
@@ -2162,12 +2043,13 @@ class TestNamespaceTracking:
|
||||
# Poll for message A with project_a_id filter
|
||||
max_retries = 10
|
||||
for i in range(max_retries):
|
||||
results_a = await tpuf_client.query_messages(
|
||||
results_a = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp", # Simple timestamp retrieval
|
||||
top_k=10,
|
||||
project_id=project_a_id,
|
||||
actor=default_user,
|
||||
)
|
||||
if len(results_a) == 1 and results_a[0][0]["id"] == message_a.id:
|
||||
break
|
||||
@@ -2179,12 +2061,13 @@ class TestNamespaceTracking:
|
||||
|
||||
# Poll for message B with project_b_id filter
|
||||
for i in range(max_retries):
|
||||
results_b = await tpuf_client.query_messages(
|
||||
results_b = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
project_id=project_b_id,
|
||||
actor=default_user,
|
||||
)
|
||||
if len(results_b) == 1 and results_b[0][0]["id"] == message_b.id:
|
||||
break
|
||||
@@ -2195,12 +2078,13 @@ class TestNamespaceTracking:
|
||||
assert "JavaScript" in results_b[0][0]["text"]
|
||||
|
||||
# Query without project filter - should find both
|
||||
results_all = await tpuf_client.query_messages(
|
||||
results_all = await tpuf_client.query_messages_by_agent_id(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
project_id=None, # No filter
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert len(results_all) >= 2 # May have other messages from setup
|
||||
|
||||
Reference in New Issue
Block a user