feat: Add ranks to archival memory search [LET-4193] (#4426)
* Add ranks to archival memory search * Fix test managers * Fix archival memory test
This commit is contained in:
@@ -494,7 +494,8 @@ class VoiceAgent(BaseAgent):
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
formatted_archival_results = [{"timestamp": str(result.created_at), "content": result.text} for result in archival_results]
|
||||
# Extract passages from tuples and format
|
||||
formatted_archival_results = [{"timestamp": str(passage.created_at), "content": passage.text} for passage, _, _ in archival_results]
|
||||
response = {
|
||||
"archival_search_results": formatted_archival_results,
|
||||
}
|
||||
|
||||
@@ -392,7 +392,7 @@ class TurbopufferClient:
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> List[Tuple[PydanticPassage, float]]:
|
||||
) -> List[Tuple[PydanticPassage, float, dict]]:
|
||||
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
|
||||
|
||||
Args:
|
||||
@@ -409,7 +409,7 @@ class TurbopufferClient:
|
||||
end_date: Optional datetime to filter passages created before this date
|
||||
|
||||
Returns:
|
||||
List of (passage, score) tuples
|
||||
List of (passage, score, metadata) tuples with relevance rankings
|
||||
"""
|
||||
# Check if we should fallback to timestamp-based retrieval
|
||||
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
|
||||
@@ -474,7 +474,7 @@ class TurbopufferClient:
|
||||
# for hybrid mode, we get a multi-query response
|
||||
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
|
||||
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
|
||||
# use RRF and return only (passage, score) for backwards compatibility
|
||||
# use RRF and include metadata with ranks
|
||||
results_with_metadata = self._reciprocal_rank_fusion(
|
||||
vector_results=[passage for passage, _ in vector_results],
|
||||
fts_results=[passage for passage, _ in fts_results],
|
||||
@@ -483,11 +483,21 @@ class TurbopufferClient:
|
||||
fts_weight=fts_weight,
|
||||
top_k=top_k,
|
||||
)
|
||||
return [(passage, rrf_score) for passage, rrf_score, metadata in results_with_metadata]
|
||||
# Return (passage, score, metadata) with ranks
|
||||
return results_with_metadata
|
||||
else:
|
||||
# for single queries (vector, fts, timestamp)
|
||||
# for single queries (vector, fts, timestamp) - add basic metadata
|
||||
is_fts = search_mode == "fts"
|
||||
return self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
|
||||
results = self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
|
||||
# Add simple metadata for single search modes
|
||||
results_with_metadata = []
|
||||
for idx, (passage, score) in enumerate(results):
|
||||
metadata = {
|
||||
"combined_score": score,
|
||||
f"{search_mode}_rank": idx + 1, # Add the rank for this search mode
|
||||
}
|
||||
results_with_metadata.append((passage, score, metadata))
|
||||
return results_with_metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query passages from Turbopuffer: {e}")
|
||||
|
||||
@@ -1013,7 +1013,7 @@ async def search_archival_memory(
|
||||
end_datetime = end_datetime.isoformat() if end_datetime else None
|
||||
|
||||
# Use the shared agent manager method
|
||||
formatted_results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
formatted_results = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
query=query,
|
||||
@@ -1027,7 +1027,7 @@ async def search_archival_memory(
|
||||
# Convert to proper response schema
|
||||
search_results = [ArchivalMemorySearchResult(**result) for result in formatted_results]
|
||||
|
||||
return ArchivalMemorySearchResponse(results=search_results, count=count)
|
||||
return ArchivalMemorySearchResponse(results=search_results, count=len(formatted_results))
|
||||
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")
|
||||
|
||||
@@ -1125,7 +1125,8 @@ class SyncServer(Server):
|
||||
ascending=ascending,
|
||||
limit=limit,
|
||||
)
|
||||
return records
|
||||
# Extract just the passages (SQL path returns empty metadata)
|
||||
return [passage for passage, _, _ in records]
|
||||
|
||||
async def insert_archival_memory_async(
|
||||
self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
|
||||
|
||||
@@ -2655,7 +2655,7 @@ class AgentManager:
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
tag_match_mode: Optional[TagMatchMode] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
) -> List[Tuple[PydanticPassage, float, dict]]:
|
||||
"""Lists all passages attached to an agent."""
|
||||
# Check if we should use Turbopuffer for vector search
|
||||
if embed_query and agent_id and query_text and embedding_config:
|
||||
@@ -2698,8 +2698,8 @@ class AgentManager:
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Return just the passages (without scores)
|
||||
return [passage for passage, _ in passages_with_scores]
|
||||
# Return full tuples with metadata
|
||||
return passages_with_scores
|
||||
else:
|
||||
return []
|
||||
|
||||
@@ -2750,9 +2750,11 @@ class AgentManager:
|
||||
if query_tags.intersection(passage_tags):
|
||||
filtered_passages.append(passage)
|
||||
|
||||
return filtered_passages
|
||||
# Return as tuples with empty metadata for SQL path
|
||||
return [(p, 0.0, {}) for p in filtered_passages]
|
||||
|
||||
return pydantic_passages
|
||||
# Return as tuples with empty metadata for SQL path
|
||||
return [(p, 0.0, {}) for p in pydantic_passages]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -2766,7 +2768,7 @@ class AgentManager:
|
||||
top_k: Optional[int] = None,
|
||||
start_datetime: Optional[str] = None,
|
||||
end_datetime: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
|
||||
|
||||
@@ -2783,11 +2785,11 @@ class AgentManager:
|
||||
end_datetime: Filter results before this datetime (ISO 8601 format)
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_results, count)
|
||||
List of formatted results with relevance metadata
|
||||
"""
|
||||
# Handle empty or whitespace-only queries
|
||||
if not query or not query.strip():
|
||||
return [], 0
|
||||
return []
|
||||
|
||||
# Get the agent to access timezone and embedding config
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
@@ -2839,7 +2841,7 @@ class AgentManager:
|
||||
|
||||
# Get results using existing passage query method
|
||||
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
all_results = await self.query_agent_passages_async(
|
||||
passages_with_metadata = await self.query_agent_passages_async(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
query_text=query,
|
||||
@@ -2852,11 +2854,11 @@ class AgentManager:
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Format results to include tags with friendly timestamps
|
||||
# Format results to include tags with friendly timestamps and relevance metadata
|
||||
formatted_results = []
|
||||
for result in all_results:
|
||||
for passage, score, metadata in passages_with_metadata:
|
||||
# Format timestamp in agent's timezone if available
|
||||
timestamp = result.created_at
|
||||
timestamp = passage.created_at
|
||||
if timestamp and agent_state.timezone:
|
||||
try:
|
||||
# Convert to agent's timezone
|
||||
@@ -2871,9 +2873,26 @@ class AgentManager:
|
||||
# Use ISO format if no timezone is set
|
||||
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
|
||||
|
||||
formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})
|
||||
result_dict = {"timestamp": formatted_timestamp, "content": passage.text, "tags": passage.tags or []}
|
||||
|
||||
return formatted_results, len(formatted_results)
|
||||
# Add relevance metadata if available
|
||||
if metadata:
|
||||
relevance_info = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"rrf_score": metadata.get("combined_score"),
|
||||
"vector_rank": metadata.get("vector_rank"),
|
||||
"fts_rank": metadata.get("fts_rank"),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
if relevance_info: # Only add if we have metadata
|
||||
result_dict["relevance"] = relevance_info
|
||||
|
||||
formatted_results.append(result_dict)
|
||||
|
||||
return formatted_results
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -302,7 +302,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
"""
|
||||
try:
|
||||
# Use the shared service method to get results
|
||||
formatted_results, count = await self.agent_manager.search_agent_archival_memory_async(
|
||||
formatted_results = await self.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=agent_state.id,
|
||||
actor=actor,
|
||||
query=query,
|
||||
@@ -313,7 +313,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
end_datetime=end_datetime,
|
||||
)
|
||||
|
||||
return formatted_results, count
|
||||
return formatted_results
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@@ -188,7 +188,7 @@ class TestTurbopufferIntegration:
|
||||
sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
|
||||
assert len(sql_passages) >= len(test_passages)
|
||||
for text in test_passages:
|
||||
assert any(p.text == text for p in sql_passages)
|
||||
assert any(p.text == text for p, _, _ in sql_passages)
|
||||
|
||||
# Test vector search which should use Turbopuffer
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
@@ -206,15 +206,15 @@ class TestTurbopufferIntegration:
|
||||
# Should find relevant passages via Turbopuffer vector search
|
||||
assert len(vector_results) > 0
|
||||
# The most relevant result should be about Turbopuffer
|
||||
assert any("Turbopuffer" in p.text or "vector" in p.text for p in vector_results)
|
||||
assert any("Turbopuffer" in p.text or "vector" in p.text for p, _, _ in vector_results)
|
||||
|
||||
# Test deletion - should delete from both
|
||||
passage_to_delete = sql_passages[0]
|
||||
passage_to_delete = sql_passages[0][0] # Extract passage from tuple
|
||||
await server.passage_manager.delete_agent_passages_async([passage_to_delete], default_user, strict_mode=True)
|
||||
|
||||
# Verify deleted from SQL
|
||||
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
|
||||
assert not any(p.id == passage_to_delete.id for p in remaining)
|
||||
assert not any(p.id == passage_to_delete.id for p, _, _ in remaining)
|
||||
|
||||
# Verify vector search no longer returns deleted passage
|
||||
vector_results_after_delete = await server.agent_manager.query_agent_passages_async(
|
||||
@@ -225,7 +225,7 @@ class TestTurbopufferIntegration:
|
||||
embed_query=True,
|
||||
limit=10,
|
||||
)
|
||||
assert not any(p.id == passage_to_delete.id for p in vector_results_after_delete)
|
||||
assert not any(p.id == passage_to_delete.id for p, _, _ in vector_results_after_delete)
|
||||
|
||||
finally:
|
||||
# TODO: Clean up archive when delete_archive method is available
|
||||
@@ -286,7 +286,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
# Should get all passages
|
||||
assert len(results) == 3 # All three passages
|
||||
for passage, score in results:
|
||||
for passage, score, metadata in results:
|
||||
assert passage.organization_id is not None
|
||||
|
||||
# Clean up
|
||||
@@ -321,7 +321,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
# List passages - should work from SQL
|
||||
sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
|
||||
assert any(p.text == text_content for p in sql_passages)
|
||||
assert any(p.text == text_content for p, _, _ in sql_passages)
|
||||
|
||||
# Vector search should use PostgreSQL pgvector
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
@@ -377,7 +377,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
assert 0 < len(vector_results) <= 3
|
||||
# all results should have scores
|
||||
assert all(isinstance(score, float) for _, score in vector_results)
|
||||
assert all(isinstance(score, float) for _, score, _ in vector_results)
|
||||
|
||||
# Test FTS-only search
|
||||
fts_results = await client.query_passages(
|
||||
@@ -385,9 +385,9 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
assert 0 < len(fts_results) <= 3
|
||||
# should find passages mentioning Turbopuffer
|
||||
assert any("Turbopuffer" in passage.text for passage, _ in fts_results)
|
||||
assert any("Turbopuffer" in passage.text for passage, _, _ in fts_results)
|
||||
# all results should have scores
|
||||
assert all(isinstance(score, float) for _, score in fts_results)
|
||||
assert all(isinstance(score, float) for _, score, _ in fts_results)
|
||||
|
||||
# Test hybrid search
|
||||
hybrid_results = await client.query_passages(
|
||||
@@ -401,11 +401,11 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
assert 0 < len(hybrid_results) <= 3
|
||||
# hybrid should combine both vector and text relevance
|
||||
assert any("Turbopuffer" in passage.text or "vector" in passage.text for passage, _ in hybrid_results)
|
||||
assert any("Turbopuffer" in passage.text or "vector" in passage.text for passage, _, _ in hybrid_results)
|
||||
# all results should have scores
|
||||
assert all(isinstance(score, float) for _, score in hybrid_results)
|
||||
assert all(isinstance(score, float) for _, score, _ in hybrid_results)
|
||||
# results should be sorted by score (highest first)
|
||||
scores = [score for _, score in hybrid_results]
|
||||
scores = [score for _, score, _ in hybrid_results]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
# Test with different weights
|
||||
@@ -420,7 +420,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
assert 0 < len(vector_heavy_results) <= 3
|
||||
# all results should have scores
|
||||
assert all(isinstance(score, float) for _, score in vector_heavy_results)
|
||||
assert all(isinstance(score, float) for _, score, _ in vector_heavy_results)
|
||||
|
||||
# Test error handling - missing text for hybrid mode (embedding provided but text missing)
|
||||
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
|
||||
@@ -434,7 +434,7 @@ class TestTurbopufferIntegration:
|
||||
timestamp_results = await client.query_passages(archive_id=archive_id, search_mode="timestamp", top_k=3)
|
||||
assert len(timestamp_results) <= 3
|
||||
# Should return passages ordered by timestamp (most recent first)
|
||||
assert all(isinstance(passage, Passage) for passage, _ in timestamp_results)
|
||||
assert all(isinstance(passage, Passage) for passage, _, _ in timestamp_results)
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
@@ -500,7 +500,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should find 3 passages with python tag
|
||||
python_passages = [passage for passage, _ in python_any_results]
|
||||
python_passages = [passage for passage, _, _ in python_any_results]
|
||||
python_texts = [p.text for p in python_passages]
|
||||
assert len(python_passages) == 3
|
||||
assert "Python programming tutorial" in python_texts
|
||||
@@ -518,7 +518,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should find 2 passages that have both python AND tutorial tags
|
||||
tutorial_passages = [passage for passage, _ in python_tutorial_all_results]
|
||||
tutorial_passages = [passage for passage, _, _ in python_tutorial_all_results]
|
||||
tutorial_texts = [p.text for p in tutorial_passages]
|
||||
assert len(tutorial_passages) == 2
|
||||
assert "Python programming tutorial" in tutorial_texts
|
||||
@@ -535,7 +535,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should find 2 passages with javascript tag
|
||||
js_passages = [passage for passage, _ in js_fts_results]
|
||||
js_passages = [passage for passage, _, _ in js_fts_results]
|
||||
js_texts = [p.text for p in js_passages]
|
||||
assert len(js_passages) == 2
|
||||
assert "JavaScript web development" in js_texts
|
||||
@@ -555,7 +555,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should find python-tagged passages
|
||||
hybrid_passages = [passage for passage, _ in python_hybrid_results]
|
||||
hybrid_passages = [passage for passage, _, _ in python_hybrid_results]
|
||||
hybrid_texts = [p.text for p in hybrid_passages]
|
||||
assert len(hybrid_passages) == 3
|
||||
assert all("Python" in text for text in hybrid_texts)
|
||||
@@ -624,7 +624,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should only get today's and yesterday's passages
|
||||
passages = [p for p, _ in results]
|
||||
passages = [p for p, _, _ in results]
|
||||
texts = [p.text for p in passages]
|
||||
assert len(passages) == 2
|
||||
assert "Today's meeting notes" in texts[0] or "Today's meeting notes" in texts[1]
|
||||
@@ -643,7 +643,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should get all except last month's passage
|
||||
passages = [p for p, _ in results]
|
||||
passages = [p for p, _, _ in results]
|
||||
assert len(passages) == 3
|
||||
texts = [p.text for p in passages]
|
||||
assert "Last month's quarterly" not in str(texts)
|
||||
@@ -658,7 +658,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should get yesterday and older passages
|
||||
passages = [p for p, _ in results]
|
||||
passages = [p for p, _, _ in results]
|
||||
assert len(passages) >= 3 # yesterday, last week, last month
|
||||
texts = [p.text for p in passages]
|
||||
assert "Today's meeting notes" not in str(texts)
|
||||
@@ -673,7 +673,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should only find today's meeting notes
|
||||
passages = [p for p, _ in results]
|
||||
passages = [p for p, _, _ in results]
|
||||
if len(passages) > 0: # FTS might not match if text search doesn't find keywords
|
||||
texts = [p.text for p in passages]
|
||||
assert "Today's meeting notes" in texts[0]
|
||||
@@ -690,7 +690,7 @@ class TestTurbopufferIntegration:
|
||||
)
|
||||
|
||||
# Should find last week's sprint review
|
||||
passages = [p for p, _ in results]
|
||||
passages = [p for p, _, _ in results]
|
||||
if len(passages) > 0:
|
||||
texts = [p.text for p in passages]
|
||||
assert "Last week's sprint review" in texts[0]
|
||||
@@ -735,14 +735,14 @@ class TestTurbopufferParametrized:
|
||||
|
||||
# List passages should work in both modes
|
||||
listed = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
|
||||
assert any(p.text == test_text for p in listed)
|
||||
assert any(p.text == test_text for p, _, _ in listed)
|
||||
|
||||
# Delete should work in both modes
|
||||
await server.passage_manager.delete_agent_passages_async(passages, default_user, strict_mode=True)
|
||||
|
||||
# Verify deletion
|
||||
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
|
||||
assert not any(p.id == passages[0].id for p in remaining)
|
||||
assert not any(p.id == passages[0].id for p, _, _ in remaining)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temporal_filtering_in_both_modes(self, turbopuffer_mode, server, default_user, sarah_agent):
|
||||
@@ -774,8 +774,8 @@ class TestTurbopufferParametrized:
|
||||
|
||||
# Should find only the recent passage, not the old one
|
||||
assert len(results) >= 1
|
||||
assert any("Recent update from today" in p.text for p in results)
|
||||
assert not any("Old update from last week" in p.text for p in results)
|
||||
assert any("Recent update from today" in p.text for p, _, _ in results)
|
||||
assert not any("Old update from last week" in p.text for p, _, _ in results)
|
||||
|
||||
# Query with date range that includes only the old passage
|
||||
old_start = last_week - timedelta(days=1)
|
||||
@@ -787,8 +787,8 @@ class TestTurbopufferParametrized:
|
||||
|
||||
# Should find only the old passage
|
||||
assert len(old_results) >= 1
|
||||
assert any("Old update from last week" in p.text for p in old_results)
|
||||
assert not any("Recent update from today" in p.text for p in old_results)
|
||||
assert any("Old update from last week" in p.text for p, _, _ in old_results)
|
||||
assert not any("Recent update from today" in p.text for p, _, _ in old_results)
|
||||
|
||||
# Clean up
|
||||
await server.passage_manager.delete_agent_passages_async(recent_passage, default_user, strict_mode=True)
|
||||
|
||||
@@ -3658,7 +3658,7 @@ async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServe
|
||||
tag_match_mode=TagMatchMode.ANY,
|
||||
)
|
||||
|
||||
python_texts = [p.text for p in python_results]
|
||||
python_texts = [p.text for p, _, _ in python_results]
|
||||
assert len([t for t in python_texts if "Python" in t]) >= 2
|
||||
|
||||
# Test querying with multiple tags using ALL mode
|
||||
@@ -3669,7 +3669,7 @@ async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServe
|
||||
tag_match_mode=TagMatchMode.ALL,
|
||||
)
|
||||
|
||||
tutorial_texts = [p.text for p in tutorial_python_results]
|
||||
tutorial_texts = [p.text for p, _, _ in tutorial_python_results]
|
||||
expected_matches = [t for t in tutorial_texts if "tutorial" in t and "Python" in t]
|
||||
assert len(expected_matches) >= 1
|
||||
|
||||
@@ -3747,7 +3747,7 @@ async def test_comprehensive_tag_functionality(disable_turbopuffer, server: Sync
|
||||
)
|
||||
|
||||
# Should match passages with "important" OR "api" tags (passages 1, 2, 3, 4)
|
||||
[p.text for p in any_results]
|
||||
[p.text for p, _, _ in any_results]
|
||||
assert len(any_results) >= 4
|
||||
|
||||
# Test 5: Query passages with ALL tag matching
|
||||
@@ -3761,7 +3761,7 @@ async def test_comprehensive_tag_functionality(disable_turbopuffer, server: Sync
|
||||
)
|
||||
|
||||
# Should only match passage4 which has both "python" AND "testing"
|
||||
all_passage_texts = [p.text for p in all_results]
|
||||
all_passage_texts = [p.text for p, _, _ in all_results]
|
||||
assert any("Test passage 4" in text for text in all_passage_texts)
|
||||
|
||||
# Test 6: Query with non-existent tags
|
||||
@@ -4029,12 +4029,11 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
|
||||
created_passages.append(passage)
|
||||
|
||||
# Test 1: Basic search by query text
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="Python programming"
|
||||
)
|
||||
|
||||
assert count > 0
|
||||
assert len(results) == count
|
||||
assert len(results) > 0
|
||||
|
||||
# Check structure of results
|
||||
for result in results:
|
||||
@@ -4044,27 +4043,27 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
|
||||
assert isinstance(result["tags"], list)
|
||||
|
||||
# Test 2: Search with tag filtering - single tag
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="programming", tags=["python"]
|
||||
)
|
||||
|
||||
assert count > 0
|
||||
assert len(results) > 0
|
||||
# All results should have "python" tag
|
||||
for result in results:
|
||||
assert "python" in result["tags"]
|
||||
|
||||
# Test 3: Search with tag filtering - multiple tags with "any" mode
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="development", tags=["web", "database"], tag_match_mode="any"
|
||||
)
|
||||
|
||||
assert count > 0
|
||||
assert len(results) > 0
|
||||
# All results should have at least one of the specified tags
|
||||
for result in results:
|
||||
assert any(tag in result["tags"] for tag in ["web", "database"])
|
||||
|
||||
# Test 4: Search with tag filtering - multiple tags with "all" mode
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="Python", tags=["python", "web"], tag_match_mode="all"
|
||||
)
|
||||
|
||||
@@ -4074,15 +4073,14 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
|
||||
assert "web" in result["tags"]
|
||||
|
||||
# Test 5: Search with top_k limit
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="programming", top_k=2
|
||||
)
|
||||
|
||||
assert count <= 2
|
||||
assert len(results) <= 2
|
||||
|
||||
# Test 6: Search with datetime filtering
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query="programming", start_datetime="2024-01-16", end_datetime="2024-01-17"
|
||||
)
|
||||
|
||||
@@ -4094,7 +4092,7 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
|
||||
assert "2024-01-16" in timestamp_str or "2024-01-17" in timestamp_str
|
||||
|
||||
# Test 7: Search with ISO datetime format
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query="algorithms",
|
||||
@@ -4103,7 +4101,7 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
|
||||
)
|
||||
|
||||
# Should include the machine learning passage created at 14:45
|
||||
assert count >= 0 # Might be 0 if no results, but shouldn't error
|
||||
assert len(results) >= 0 # Might be 0 if no results, but shouldn't error
|
||||
|
||||
# Test 8: Search with non-existent agent should raise error
|
||||
non_existent_agent_id = "agent-00000000-0000-4000-8000-000000000000"
|
||||
@@ -4118,18 +4116,14 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
|
||||
)
|
||||
|
||||
# Test 10: Empty query should return empty results
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query="")
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query="")
|
||||
|
||||
assert count == 0 # Empty query should return 0 results
|
||||
assert len(results) == 0
|
||||
assert len(results) == 0 # Empty query should return 0 results
|
||||
|
||||
# Test 11: Whitespace-only query should also return empty results
|
||||
results, count = await server.agent_manager.search_agent_archival_memory_async(
|
||||
agent_id=sarah_agent.id, actor=default_user, query=" \n\t "
|
||||
)
|
||||
results = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query=" \n\t ")
|
||||
|
||||
assert count == 0 # Whitespace-only query should return 0 results
|
||||
assert len(results) == 0
|
||||
assert len(results) == 0 # Whitespace-only query should return 0 results
|
||||
|
||||
# Cleanup - delete the created passages
|
||||
for passage in created_passages:
|
||||
|
||||
Reference in New Issue
Block a user