feat: Add ranks to archival memory search [LET-4193] (#4426)

* Add ranks to archival memory search

* Fix test managers

* Fix archival memory test
This commit is contained in:
Matthew Zhou
2025-09-04 13:35:54 -07:00
committed by GitHub
parent ef225d3e49
commit def95050e2
8 changed files with 106 additions and 81 deletions

View File

@@ -494,7 +494,8 @@ class VoiceAgent(BaseAgent):
start_date=start_date,
end_date=end_date,
)
formatted_archival_results = [{"timestamp": str(result.created_at), "content": result.text} for result in archival_results]
# Extract passages from tuples and format
formatted_archival_results = [{"timestamp": str(passage.created_at), "content": passage.text} for passage, _, _ in archival_results]
response = {
"archival_search_results": formatted_archival_results,
}

View File

@@ -392,7 +392,7 @@ class TurbopufferClient:
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> List[Tuple[PydanticPassage, float]]:
) -> List[Tuple[PydanticPassage, float, dict]]:
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
Args:
@@ -409,7 +409,7 @@ class TurbopufferClient:
end_date: Optional datetime to filter passages created before this date
Returns:
List of (passage, score) tuples
List of (passage, score, metadata) tuples with relevance rankings
"""
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None and search_mode not in ["timestamp"]:
@@ -474,7 +474,7 @@ class TurbopufferClient:
# for hybrid mode, we get a multi-query response
vector_results = self._process_single_query_results(result.results[0], archive_id, tags)
fts_results = self._process_single_query_results(result.results[1], archive_id, tags, is_fts=True)
# use RRF and return only (passage, score) for backwards compatibility
# use RRF and include metadata with ranks
results_with_metadata = self._reciprocal_rank_fusion(
vector_results=[passage for passage, _ in vector_results],
fts_results=[passage for passage, _ in fts_results],
@@ -483,11 +483,21 @@ class TurbopufferClient:
fts_weight=fts_weight,
top_k=top_k,
)
return [(passage, rrf_score) for passage, rrf_score, metadata in results_with_metadata]
# Return (passage, score, metadata) with ranks
return results_with_metadata
else:
# for single queries (vector, fts, timestamp)
# for single queries (vector, fts, timestamp) - add basic metadata
is_fts = search_mode == "fts"
return self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
results = self._process_single_query_results(result, archive_id, tags, is_fts=is_fts)
# Add simple metadata for single search modes
results_with_metadata = []
for idx, (passage, score) in enumerate(results):
metadata = {
"combined_score": score,
f"{search_mode}_rank": idx + 1, # Add the rank for this search mode
}
results_with_metadata.append((passage, score, metadata))
return results_with_metadata
except Exception as e:
logger.error(f"Failed to query passages from Turbopuffer: {e}")

View File

@@ -1013,7 +1013,7 @@ async def search_archival_memory(
end_datetime = end_datetime.isoformat() if end_datetime else None
# Use the shared agent manager method
formatted_results, count = await server.agent_manager.search_agent_archival_memory_async(
formatted_results = await server.agent_manager.search_agent_archival_memory_async(
agent_id=agent_id,
actor=actor,
query=query,
@@ -1027,7 +1027,7 @@ async def search_archival_memory(
# Convert to proper response schema
search_results = [ArchivalMemorySearchResult(**result) for result in formatted_results]
return ArchivalMemorySearchResponse(results=search_results, count=count)
return ArchivalMemorySearchResponse(results=search_results, count=len(formatted_results))
except NoResultFound as e:
raise HTTPException(status_code=404, detail=f"Agent with id={agent_id} not found for user_id={actor.id}.")

View File

@@ -1125,7 +1125,8 @@ class SyncServer(Server):
ascending=ascending,
limit=limit,
)
return records
# Extract just the passages (SQL path returns empty metadata)
return [passage for passage, _, _ in records]
async def insert_archival_memory_async(
self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]

View File

@@ -2655,7 +2655,7 @@ class AgentManager:
embedding_config: Optional[EmbeddingConfig] = None,
tags: Optional[List[str]] = None,
tag_match_mode: Optional[TagMatchMode] = None,
) -> List[PydanticPassage]:
) -> List[Tuple[PydanticPassage, float, dict]]:
"""Lists all passages attached to an agent."""
# Check if we should use Turbopuffer for vector search
if embed_query and agent_id and query_text and embedding_config:
@@ -2698,8 +2698,8 @@ class AgentManager:
end_date=end_date,
)
# Return just the passages (without scores)
return [passage for passage, _ in passages_with_scores]
# Return full tuples with metadata
return passages_with_scores
else:
return []
@@ -2750,9 +2750,11 @@ class AgentManager:
if query_tags.intersection(passage_tags):
filtered_passages.append(passage)
return filtered_passages
# Return as tuples with empty metadata for SQL path
return [(p, 0.0, {}) for p in filtered_passages]
return pydantic_passages
# Return as tuples with empty metadata for SQL path
return [(p, 0.0, {}) for p in pydantic_passages]
@enforce_types
@trace_method
@@ -2766,7 +2768,7 @@ class AgentManager:
top_k: Optional[int] = None,
start_datetime: Optional[str] = None,
end_datetime: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
) -> List[Dict[str, Any]]:
"""
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
@@ -2783,11 +2785,11 @@ class AgentManager:
end_datetime: Filter results before this datetime (ISO 8601 format)
Returns:
Tuple of (formatted_results, count)
List of formatted results with relevance metadata
"""
# Handle empty or whitespace-only queries
if not query or not query.strip():
return [], 0
return []
# Get the agent to access timezone and embedding config
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
@@ -2839,7 +2841,7 @@ class AgentManager:
# Get results using existing passage query method
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
all_results = await self.query_agent_passages_async(
passages_with_metadata = await self.query_agent_passages_async(
actor=actor,
agent_id=agent_id,
query_text=query,
@@ -2852,11 +2854,11 @@ class AgentManager:
end_date=end_date,
)
# Format results to include tags with friendly timestamps
# Format results to include tags with friendly timestamps and relevance metadata
formatted_results = []
for result in all_results:
for passage, score, metadata in passages_with_metadata:
# Format timestamp in agent's timezone if available
timestamp = result.created_at
timestamp = passage.created_at
if timestamp and agent_state.timezone:
try:
# Convert to agent's timezone
@@ -2871,9 +2873,26 @@ class AgentManager:
# Use ISO format if no timezone is set
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})
result_dict = {"timestamp": formatted_timestamp, "content": passage.text, "tags": passage.tags or []}
return formatted_results, len(formatted_results)
# Add relevance metadata if available
if metadata:
relevance_info = {
k: v
for k, v in {
"rrf_score": metadata.get("combined_score"),
"vector_rank": metadata.get("vector_rank"),
"fts_rank": metadata.get("fts_rank"),
}.items()
if v is not None
}
if relevance_info: # Only add if we have metadata
result_dict["relevance"] = relevance_info
formatted_results.append(result_dict)
return formatted_results
@enforce_types
@trace_method

View File

@@ -302,7 +302,7 @@ class LettaCoreToolExecutor(ToolExecutor):
"""
try:
# Use the shared service method to get results
formatted_results, count = await self.agent_manager.search_agent_archival_memory_async(
formatted_results = await self.agent_manager.search_agent_archival_memory_async(
agent_id=agent_state.id,
actor=actor,
query=query,
@@ -313,7 +313,7 @@ class LettaCoreToolExecutor(ToolExecutor):
end_datetime=end_datetime,
)
return formatted_results, count
return formatted_results
except Exception as e:
raise e

View File

@@ -188,7 +188,7 @@ class TestTurbopufferIntegration:
sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert len(sql_passages) >= len(test_passages)
for text in test_passages:
assert any(p.text == text for p in sql_passages)
assert any(p.text == text for p, _, _ in sql_passages)
# Test vector search which should use Turbopuffer
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
@@ -206,15 +206,15 @@ class TestTurbopufferIntegration:
# Should find relevant passages via Turbopuffer vector search
assert len(vector_results) > 0
# The most relevant result should be about Turbopuffer
assert any("Turbopuffer" in p.text or "vector" in p.text for p in vector_results)
assert any("Turbopuffer" in p.text or "vector" in p.text for p, _, _ in vector_results)
# Test deletion - should delete from both
passage_to_delete = sql_passages[0]
passage_to_delete = sql_passages[0][0] # Extract passage from tuple
await server.passage_manager.delete_agent_passages_async([passage_to_delete], default_user, strict_mode=True)
# Verify deleted from SQL
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert not any(p.id == passage_to_delete.id for p in remaining)
assert not any(p.id == passage_to_delete.id for p, _, _ in remaining)
# Verify vector search no longer returns deleted passage
vector_results_after_delete = await server.agent_manager.query_agent_passages_async(
@@ -225,7 +225,7 @@ class TestTurbopufferIntegration:
embed_query=True,
limit=10,
)
assert not any(p.id == passage_to_delete.id for p in vector_results_after_delete)
assert not any(p.id == passage_to_delete.id for p, _, _ in vector_results_after_delete)
finally:
# TODO: Clean up archive when delete_archive method is available
@@ -286,7 +286,7 @@ class TestTurbopufferIntegration:
# Should get all passages
assert len(results) == 3 # All three passages
for passage, score in results:
for passage, score, metadata in results:
assert passage.organization_id is not None
# Clean up
@@ -321,7 +321,7 @@ class TestTurbopufferIntegration:
# List passages - should work from SQL
sql_passages = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert any(p.text == text_content for p in sql_passages)
assert any(p.text == text_content for p, _, _ in sql_passages)
# Vector search should use PostgreSQL pgvector
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
@@ -377,7 +377,7 @@ class TestTurbopufferIntegration:
)
assert 0 < len(vector_results) <= 3
# all results should have scores
assert all(isinstance(score, float) for _, score in vector_results)
assert all(isinstance(score, float) for _, score, _ in vector_results)
# Test FTS-only search
fts_results = await client.query_passages(
@@ -385,9 +385,9 @@ class TestTurbopufferIntegration:
)
assert 0 < len(fts_results) <= 3
# should find passages mentioning Turbopuffer
assert any("Turbopuffer" in passage.text for passage, _ in fts_results)
assert any("Turbopuffer" in passage.text for passage, _, _ in fts_results)
# all results should have scores
assert all(isinstance(score, float) for _, score in fts_results)
assert all(isinstance(score, float) for _, score, _ in fts_results)
# Test hybrid search
hybrid_results = await client.query_passages(
@@ -401,11 +401,11 @@ class TestTurbopufferIntegration:
)
assert 0 < len(hybrid_results) <= 3
# hybrid should combine both vector and text relevance
assert any("Turbopuffer" in passage.text or "vector" in passage.text for passage, _ in hybrid_results)
assert any("Turbopuffer" in passage.text or "vector" in passage.text for passage, _, _ in hybrid_results)
# all results should have scores
assert all(isinstance(score, float) for _, score in hybrid_results)
assert all(isinstance(score, float) for _, score, _ in hybrid_results)
# results should be sorted by score (highest first)
scores = [score for _, score in hybrid_results]
scores = [score for _, score, _ in hybrid_results]
assert scores == sorted(scores, reverse=True)
# Test with different weights
@@ -420,7 +420,7 @@ class TestTurbopufferIntegration:
)
assert 0 < len(vector_heavy_results) <= 3
# all results should have scores
assert all(isinstance(score, float) for _, score in vector_heavy_results)
assert all(isinstance(score, float) for _, score, _ in vector_heavy_results)
# Test error handling - missing text for hybrid mode (embedding provided but text missing)
with pytest.raises(ValueError, match="Both query_embedding and query_text are required"):
@@ -434,7 +434,7 @@ class TestTurbopufferIntegration:
timestamp_results = await client.query_passages(archive_id=archive_id, search_mode="timestamp", top_k=3)
assert len(timestamp_results) <= 3
# Should return passages ordered by timestamp (most recent first)
assert all(isinstance(passage, Passage) for passage, _ in timestamp_results)
assert all(isinstance(passage, Passage) for passage, _, _ in timestamp_results)
finally:
# Clean up
@@ -500,7 +500,7 @@ class TestTurbopufferIntegration:
)
# Should find 3 passages with python tag
python_passages = [passage for passage, _ in python_any_results]
python_passages = [passage for passage, _, _ in python_any_results]
python_texts = [p.text for p in python_passages]
assert len(python_passages) == 3
assert "Python programming tutorial" in python_texts
@@ -518,7 +518,7 @@ class TestTurbopufferIntegration:
)
# Should find 2 passages that have both python AND tutorial tags
tutorial_passages = [passage for passage, _ in python_tutorial_all_results]
tutorial_passages = [passage for passage, _, _ in python_tutorial_all_results]
tutorial_texts = [p.text for p in tutorial_passages]
assert len(tutorial_passages) == 2
assert "Python programming tutorial" in tutorial_texts
@@ -535,7 +535,7 @@ class TestTurbopufferIntegration:
)
# Should find 2 passages with javascript tag
js_passages = [passage for passage, _ in js_fts_results]
js_passages = [passage for passage, _, _ in js_fts_results]
js_texts = [p.text for p in js_passages]
assert len(js_passages) == 2
assert "JavaScript web development" in js_texts
@@ -555,7 +555,7 @@ class TestTurbopufferIntegration:
)
# Should find python-tagged passages
hybrid_passages = [passage for passage, _ in python_hybrid_results]
hybrid_passages = [passage for passage, _, _ in python_hybrid_results]
hybrid_texts = [p.text for p in hybrid_passages]
assert len(hybrid_passages) == 3
assert all("Python" in text for text in hybrid_texts)
@@ -624,7 +624,7 @@ class TestTurbopufferIntegration:
)
# Should only get today's and yesterday's passages
passages = [p for p, _ in results]
passages = [p for p, _, _ in results]
texts = [p.text for p in passages]
assert len(passages) == 2
assert "Today's meeting notes" in texts[0] or "Today's meeting notes" in texts[1]
@@ -643,7 +643,7 @@ class TestTurbopufferIntegration:
)
# Should get all except last month's passage
passages = [p for p, _ in results]
passages = [p for p, _, _ in results]
assert len(passages) == 3
texts = [p.text for p in passages]
assert "Last month's quarterly" not in str(texts)
@@ -658,7 +658,7 @@ class TestTurbopufferIntegration:
)
# Should get yesterday and older passages
passages = [p for p, _ in results]
passages = [p for p, _, _ in results]
assert len(passages) >= 3 # yesterday, last week, last month
texts = [p.text for p in passages]
assert "Today's meeting notes" not in str(texts)
@@ -673,7 +673,7 @@ class TestTurbopufferIntegration:
)
# Should only find today's meeting notes
passages = [p for p, _ in results]
passages = [p for p, _, _ in results]
if len(passages) > 0: # FTS might not match if text search doesn't find keywords
texts = [p.text for p in passages]
assert "Today's meeting notes" in texts[0]
@@ -690,7 +690,7 @@ class TestTurbopufferIntegration:
)
# Should find last week's sprint review
passages = [p for p, _ in results]
passages = [p for p, _, _ in results]
if len(passages) > 0:
texts = [p.text for p in passages]
assert "Last week's sprint review" in texts[0]
@@ -735,14 +735,14 @@ class TestTurbopufferParametrized:
# List passages should work in both modes
listed = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert any(p.text == test_text for p in listed)
assert any(p.text == test_text for p, _, _ in listed)
# Delete should work in both modes
await server.passage_manager.delete_agent_passages_async(passages, default_user, strict_mode=True)
# Verify deletion
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert not any(p.id == passages[0].id for p in remaining)
assert not any(p.id == passages[0].id for p, _, _ in remaining)
@pytest.mark.asyncio
async def test_temporal_filtering_in_both_modes(self, turbopuffer_mode, server, default_user, sarah_agent):
@@ -774,8 +774,8 @@ class TestTurbopufferParametrized:
# Should find only the recent passage, not the old one
assert len(results) >= 1
assert any("Recent update from today" in p.text for p in results)
assert not any("Old update from last week" in p.text for p in results)
assert any("Recent update from today" in p.text for p, _, _ in results)
assert not any("Old update from last week" in p.text for p, _, _ in results)
# Query with date range that includes only the old passage
old_start = last_week - timedelta(days=1)
@@ -787,8 +787,8 @@ class TestTurbopufferParametrized:
# Should find only the old passage
assert len(old_results) >= 1
assert any("Old update from last week" in p.text for p in old_results)
assert not any("Recent update from today" in p.text for p in old_results)
assert any("Old update from last week" in p.text for p, _, _ in old_results)
assert not any("Recent update from today" in p.text for p, _, _ in old_results)
# Clean up
await server.passage_manager.delete_agent_passages_async(recent_passage, default_user, strict_mode=True)

View File

@@ -3658,7 +3658,7 @@ async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServe
tag_match_mode=TagMatchMode.ANY,
)
python_texts = [p.text for p in python_results]
python_texts = [p.text for p, _, _ in python_results]
assert len([t for t in python_texts if "Python" in t]) >= 2
# Test querying with multiple tags using ALL mode
@@ -3669,7 +3669,7 @@ async def test_passage_tags_functionality(disable_turbopuffer, server: SyncServe
tag_match_mode=TagMatchMode.ALL,
)
tutorial_texts = [p.text for p in tutorial_python_results]
tutorial_texts = [p.text for p, _, _ in tutorial_python_results]
expected_matches = [t for t in tutorial_texts if "tutorial" in t and "Python" in t]
assert len(expected_matches) >= 1
@@ -3747,7 +3747,7 @@ async def test_comprehensive_tag_functionality(disable_turbopuffer, server: Sync
)
# Should match passages with "important" OR "api" tags (passages 1, 2, 3, 4)
[p.text for p in any_results]
[p.text for p, _, _ in any_results]
assert len(any_results) >= 4
# Test 5: Query passages with ALL tag matching
@@ -3761,7 +3761,7 @@ async def test_comprehensive_tag_functionality(disable_turbopuffer, server: Sync
)
# Should only match passage4 which has both "python" AND "testing"
all_passage_texts = [p.text for p in all_results]
all_passage_texts = [p.text for p, _, _ in all_results]
assert any("Test passage 4" in text for text in all_passage_texts)
# Test 6: Query with non-existent tags
@@ -4029,12 +4029,11 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
created_passages.append(passage)
# Test 1: Basic search by query text
results, count = await server.agent_manager.search_agent_archival_memory_async(
results = await server.agent_manager.search_agent_archival_memory_async(
agent_id=sarah_agent.id, actor=default_user, query="Python programming"
)
assert count > 0
assert len(results) == count
assert len(results) > 0
# Check structure of results
for result in results:
@@ -4044,27 +4043,27 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
assert isinstance(result["tags"], list)
# Test 2: Search with tag filtering - single tag
results, count = await server.agent_manager.search_agent_archival_memory_async(
results = await server.agent_manager.search_agent_archival_memory_async(
agent_id=sarah_agent.id, actor=default_user, query="programming", tags=["python"]
)
assert count > 0
assert len(results) > 0
# All results should have "python" tag
for result in results:
assert "python" in result["tags"]
# Test 3: Search with tag filtering - multiple tags with "any" mode
results, count = await server.agent_manager.search_agent_archival_memory_async(
results = await server.agent_manager.search_agent_archival_memory_async(
agent_id=sarah_agent.id, actor=default_user, query="development", tags=["web", "database"], tag_match_mode="any"
)
assert count > 0
assert len(results) > 0
# All results should have at least one of the specified tags
for result in results:
assert any(tag in result["tags"] for tag in ["web", "database"])
# Test 4: Search with tag filtering - multiple tags with "all" mode
results, count = await server.agent_manager.search_agent_archival_memory_async(
results = await server.agent_manager.search_agent_archival_memory_async(
agent_id=sarah_agent.id, actor=default_user, query="Python", tags=["python", "web"], tag_match_mode="all"
)
@@ -4074,15 +4073,14 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
assert "web" in result["tags"]
# Test 5: Search with top_k limit
results, count = await server.agent_manager.search_agent_archival_memory_async(
results = await server.agent_manager.search_agent_archival_memory_async(
agent_id=sarah_agent.id, actor=default_user, query="programming", top_k=2
)
assert count <= 2
assert len(results) <= 2
# Test 6: Search with datetime filtering
results, count = await server.agent_manager.search_agent_archival_memory_async(
results = await server.agent_manager.search_agent_archival_memory_async(
agent_id=sarah_agent.id, actor=default_user, query="programming", start_datetime="2024-01-16", end_datetime="2024-01-17"
)
@@ -4094,7 +4092,7 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
assert "2024-01-16" in timestamp_str or "2024-01-17" in timestamp_str
# Test 7: Search with ISO datetime format
results, count = await server.agent_manager.search_agent_archival_memory_async(
results = await server.agent_manager.search_agent_archival_memory_async(
agent_id=sarah_agent.id,
actor=default_user,
query="algorithms",
@@ -4103,7 +4101,7 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
)
# Should include the machine learning passage created at 14:45
assert count >= 0 # Might be 0 if no results, but shouldn't error
assert len(results) >= 0 # Might be 0 if no results, but shouldn't error
# Test 8: Search with non-existent agent should raise error
non_existent_agent_id = "agent-00000000-0000-4000-8000-000000000000"
@@ -4118,18 +4116,14 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
)
# Test 10: Empty query should return empty results
results, count = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query="")
results = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query="")
assert count == 0 # Empty query should return 0 results
assert len(results) == 0
assert len(results) == 0 # Empty query should return 0 results
# Test 11: Whitespace-only query should also return empty results
results, count = await server.agent_manager.search_agent_archival_memory_async(
agent_id=sarah_agent.id, actor=default_user, query=" \n\t "
)
results = await server.agent_manager.search_agent_archival_memory_async(agent_id=sarah_agent.id, actor=default_user, query=" \n\t ")
assert count == 0 # Whitespace-only query should return 0 results
assert len(results) == 0
assert len(results) == 0 # Whitespace-only query should return 0 results
# Cleanup - delete the created passages
for passage in created_passages: