feat: Support timestamp filtering for archival memories [LET-3469] (#4330)

Finish temporal filtering
This commit is contained in:
Matthew Zhou
2025-08-30 19:31:07 -07:00
committed by GitHub
parent 6c160e1d1d
commit 335e0c2be1
5 changed files with 300 additions and 13 deletions

View File

@@ -78,16 +78,37 @@ async def archival_memory_insert(self: "Agent", content: str, tags: Optional[lis
async def archival_memory_search(
self: "Agent", query: str, tags: Optional[list[str]] = None, tag_match_mode: Literal["any", "all"] = "any", top_k: Optional[int] = None
self: "Agent",
query: str,
tags: Optional[list[str]] = None,
tag_match_mode: Literal["any", "all"] = "any",
top_k: Optional[int] = None,
start_datetime: Optional[str] = None,
end_datetime: Optional[str] = None,
) -> Optional[str]:
"""
Search archival memory using semantic (embedding-based) search.
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
Args:
query (str): String to search for using semantic similarity.
tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned.
tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any".
top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified.
start_datetime (Optional[str]): Filter results to passages created after this datetime. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30".
end_datetime (Optional[str]): Filter results to passages created before this datetime. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00".
Examples:
# Search all passages
archival_memory_search(query="project updates")
# Search with date range (full days)
archival_memory_search(query="meetings", start_datetime="2024-01-15", end_datetime="2024-01-20")
# Search with specific time range
archival_memory_search(query="error logs", start_datetime="2024-01-15T09:30", end_datetime="2024-01-15T17:30")
# Search from a specific point in time onwards
archival_memory_search(query="customer feedback", start_datetime="2024-01-15T14:00")
Returns:
str: Query result string containing matching passages with timestamps and content.

View File

@@ -167,6 +167,8 @@ class TurbopufferClient:
tag_match_mode: TagMatchMode = TagMatchMode.ANY,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> List[Tuple[PydanticPassage, float]]:
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
@@ -180,6 +182,8 @@ class TurbopufferClient:
tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime; only passages created at or after this time are returned (inclusive lower bound)
end_date: Optional datetime; only passages created at or before this time are returned (inclusive upper bound)
Returns:
List of (passage, score) tuples
@@ -225,6 +229,29 @@ class TurbopufferClient:
# For ANY mode, use ContainsAny to match any of the tags
tag_filter = ("tags", "ContainsAny", tags)
# build date filter conditions
date_filters = []
if start_date:
# Turbopuffer expects datetime objects directly for comparison
date_filters.append(("created_at", "Gte", start_date))
if end_date:
# Turbopuffer expects datetime objects directly for comparison
date_filters.append(("created_at", "Lte", end_date))
# combine all filters
all_filters = []
if tag_filter:
all_filters.append(tag_filter)
if date_filters:
all_filters.extend(date_filters)
# create final filter expression
final_filter = None
if len(all_filters) == 1:
final_filter = all_filters[0]
elif len(all_filters) > 1:
final_filter = ("And", all_filters)
if search_mode == "timestamp":
# Fallback: retrieve most recent passages by timestamp
query_params = {
@@ -232,8 +259,8 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
query_params["filters"] = tag_filter
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags)
@@ -245,8 +272,8 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
query_params["filters"] = tag_filter
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags)
@@ -258,8 +285,8 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
query_params["filters"] = tag_filter
if final_filter:
query_params["filters"] = final_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags, is_fts=True)
@@ -274,8 +301,8 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
vector_query["filters"] = tag_filter
if final_filter:
vector_query["filters"] = final_filter
queries.append(vector_query)
# full-text search query
@@ -284,8 +311,8 @@ class TurbopufferClient:
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
fts_query["filters"] = tag_filter
if final_filter:
fts_query["filters"] = final_filter
queries.append(fts_query)
# execute multi-query

View File

@@ -2690,6 +2690,8 @@ class AgentManager:
top_k=limit,
tags=tags,
tag_match_mode=tag_match_mode or TagMatchMode.ANY,
start_date=start_date,
end_date=end_date,
)
# Return just the passages (without scores)

View File

@@ -126,20 +126,70 @@ class LettaCoreToolExecutor(ToolExecutor):
tags: Optional[list[str]] = None,
tag_match_mode: Literal["any", "all"] = "any",
top_k: Optional[int] = None,
start_datetime: Optional[str] = None,
end_datetime: Optional[str] = None,
) -> Optional[str]:
"""
Search archival memory using semantic (embedding-based) search.
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
Args:
query (str): String to search for using semantic similarity.
tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned.
tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any".
top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified.
start_datetime (Optional[str]): Filter results to passages created after this datetime. ISO 8601 format.
end_datetime (Optional[str]): Filter results to passages created before this datetime. ISO 8601 format.
Returns:
str: Query result string containing matching passages with timestamps, content, and tags.
"""
try:
# Parse datetime parameters if provided
from datetime import datetime

def _parse_temporal_bound(raw_value: str, param_name: str, end_of_day: bool) -> datetime:
    """Parse one temporal filter bound supplied by the caller.

    Args:
        raw_value: The string to parse, either date-only ("YYYY-MM-DD") or a
            full ISO 8601 datetime (e.g. "YYYY-MM-DDTHH:MM").
        param_name: Name of the tool parameter, used in the error message.
        end_of_day: When True, a date-only value is expanded to the last
            microsecond of that day instead of midnight.

    Returns:
        The parsed datetime. A naive result gets the agent's timezone attached
        when the agent has one configured.

    Raises:
        ValueError: If the value matches neither accepted format.
    """
    try:
        # Date-only must be tried FIRST: datetime.fromisoformat() also accepts
        # "YYYY-MM-DD" (returning midnight), so trying it first would hide
        # date-only inputs and an end date would exclude its entire final day.
        parsed = datetime.strptime(raw_value, "%Y-%m-%d")
    except ValueError:
        try:
            # Full datetime (with time, optionally with UTC offset)
            parsed = datetime.fromisoformat(raw_value)
        except ValueError:
            raise ValueError(
                f"Invalid {param_name} format: {raw_value}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
            )
    else:
        # strptime already yields 00:00:00, which is the start-of-day bound
        if end_of_day:
            parsed = parsed.replace(hour=23, minute=59, second=59, microsecond=999999)

    # Apply agent's timezone if datetime is naive
    if parsed.tzinfo is None and agent_state.timezone:
        parsed = parsed.replace(tzinfo=ZoneInfo(agent_state.timezone))
    return parsed

start_date = _parse_temporal_bound(start_datetime, "start_datetime", end_of_day=False) if start_datetime else None
end_date = _parse_temporal_bound(end_datetime, "end_datetime", end_of_day=True) if end_datetime else None
# Convert string to TagMatchMode enum
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
@@ -154,6 +204,8 @@ class LettaCoreToolExecutor(ToolExecutor):
embed_query=True,
tags=tags,
tag_match_mode=tag_mode,
start_date=start_date,
end_date=end_date,
)
# Format results to include tags with friendly timestamps

View File

@@ -527,6 +527,141 @@ class TestTurbopufferIntegration:
except:
pass
@pytest.mark.asyncio
async def test_temporal_filtering_with_real_tpuf(self, enable_turbopuffer):
    """Test temporal filtering with date ranges against a real Turbopuffer namespace.

    Inserts four passages created now, one day ago, one week ago and thirty
    days ago into a throwaway archive, then checks that `start_date` /
    `end_date` restrict the results in vector, FTS and hybrid search modes.
    The archive is deleted afterwards on a best-effort basis.
    """
    from datetime import datetime, timedelta, timezone

    # Skip if Turbopuffer is not properly configured
    if not should_use_tpuf():
        pytest.skip("Turbopuffer not configured - skipping TPUF temporal filtering test")

    # Create client
    client = TurbopufferClient()

    # Create a unique archive ID for this test
    archive_id = f"test-temporal-{uuid.uuid4()}"

    # Every passage and query shares the same simple dummy embedding, so vector
    # ranking is irrelevant and only the date filters decide what comes back.
    dummy_embedding = [1.0, 2.0, 3.0]

    try:
        # Create passages with different timestamps
        now = datetime.now(timezone.utc)
        yesterday = now - timedelta(days=1)
        last_week = now - timedelta(days=7)
        last_month = now - timedelta(days=30)

        # Insert passages with specific timestamps
        test_passages = [
            ("Today's meeting notes about project Alpha", now),
            ("Yesterday's standup summary", yesterday),
            ("Last week's sprint review", last_week),
            ("Last month's quarterly planning", last_month),
        ]

        for text, timestamp in test_passages:
            passage_id = f"passage-{uuid.uuid4()}"
            await client.insert_archival_memories(
                archive_id=archive_id,
                text_chunks=[text],
                embeddings=[dummy_embedding],
                passage_ids=[passage_id],
                organization_id="test-org",
                created_at=timestamp,
            )

        # Test 1: Query with date range (last 3 days)
        three_days_ago = now - timedelta(days=3)
        results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=dummy_embedding,
            search_mode="vector",
            top_k=10,
            start_date=three_days_ago,
            end_date=now,
        )

        # Should only get today's and yesterday's passages
        passages = [p for p, _ in results]
        texts = [p.text for p in passages]
        assert len(passages) == 2
        # Check each text individually rather than the list's repr, which can
        # escape quotes and hide a match.
        assert any("Today's meeting notes" in t for t in texts)
        assert any("Yesterday's standup" in t for t in texts)
        assert not any("Last week's sprint" in t for t in texts)
        assert not any("Last month's quarterly" in t for t in texts)

        # Test 2: Query with only start_date (everything after 2 weeks ago)
        two_weeks_ago = now - timedelta(days=14)
        results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=dummy_embedding,
            search_mode="vector",
            top_k=10,
            start_date=two_weeks_ago,
        )

        # Should get all except last month's passage
        passages = [p for p, _ in results]
        assert len(passages) == 3
        texts = [p.text for p in passages]
        assert not any("Last month's quarterly" in t for t in texts)

        # Test 3: Query with only end_date. The cut-off is 12 hours after
        # yesterday's passage, i.e. halfway between yesterday's and today's.
        results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=dummy_embedding,
            search_mode="vector",
            top_k=10,
            end_date=yesterday + timedelta(hours=12),
        )

        # Should get yesterday and older passages
        passages = [p for p, _ in results]
        assert len(passages) >= 3  # yesterday, last week, last month
        texts = [p.text for p in passages]
        assert not any("Today's meeting notes" in t for t in texts)

        # Test 4: Test with FTS mode and date filtering
        results = await client.query_passages(
            archive_id=archive_id,
            query_text="meeting notes project",
            search_mode="fts",
            top_k=10,
            start_date=yesterday,
        )

        # Top hit should be today's meeting notes
        passages = [p for p, _ in results]
        if passages:  # FTS might not match if text search doesn't find keywords
            texts = [p.text for p in passages]
            assert "Today's meeting notes" in texts[0]

        # Test 5: Test with hybrid mode and date filtering
        results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=dummy_embedding,
            query_text="sprint review",
            search_mode="hybrid",
            top_k=10,
            start_date=last_week - timedelta(days=1),
            end_date=last_week + timedelta(days=1),
        )

        # Should find last week's sprint review
        passages = [p for p, _ in results]
        if passages:
            texts = [p.text for p in passages]
            assert "Last week's sprint review" in texts[0]

    finally:
        # Clean up (best effort). Catch Exception rather than using a bare
        # except so task cancellation and KeyboardInterrupt still propagate.
        try:
            await client.delete_all_passages(archive_id)
        except Exception:
            pass
@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
class TestTurbopufferParametrized:
@@ -566,3 +701,53 @@ class TestTurbopufferParametrized:
# Verify deletion
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
assert not any(p.id == passages[0].id for p in remaining)
@pytest.mark.asyncio
async def test_temporal_filtering_in_both_modes(self, turbopuffer_mode, server, default_user, sarah_agent):
    """Test that temporal filtering works in both NATIVE and TPUF modes.

    Inserts one passage created now and one created a week ago, then checks
    that a date window around each timestamp returns that passage and not the
    other. The inserted passages are removed in a `finally` block so a failed
    assertion does not leave them behind on the agent passed in as `sarah_agent`.
    """
    from datetime import datetime, timedelta, timezone

    # Insert passages with different timestamps
    now = datetime.now(timezone.utc)
    yesterday = now - timedelta(days=1)
    last_week = now - timedelta(days=7)

    # Track what was actually inserted so cleanup only touches those entries
    inserted = []
    try:
        # Insert passages with specific timestamps
        recent_passage = await server.passage_manager.insert_passage(
            agent_state=sarah_agent, text="Recent update from today", actor=default_user, created_at=now
        )
        inserted.append(recent_passage)

        old_passage = await server.passage_manager.insert_passage(
            agent_state=sarah_agent, text="Old update from last week", actor=default_user, created_at=last_week
        )
        inserted.append(old_passage)

        # Query with date range that includes only recent passage
        start_date = yesterday
        end_date = now + timedelta(hours=1)  # Slightly in the future to ensure we catch it

        # Query with date filtering
        results = await server.agent_manager.query_agent_passages_async(
            actor=default_user, agent_id=sarah_agent.id, start_date=start_date, end_date=end_date, limit=10
        )

        # Should find only the recent passage, not the old one
        assert len(results) >= 1
        assert any("Recent update from today" in p.text for p in results)
        assert not any("Old update from last week" in p.text for p in results)

        # Query with date range that includes only the old passage
        old_start = last_week - timedelta(days=1)
        old_end = last_week + timedelta(days=1)

        old_results = await server.agent_manager.query_agent_passages_async(
            actor=default_user, agent_id=sarah_agent.id, start_date=old_start, end_date=old_end, limit=10
        )

        # Should find only the old passage
        assert len(old_results) >= 1
        assert any("Old update from last week" in p.text for p in old_results)
        assert not any("Recent update from today" in p.text for p in old_results)
    finally:
        # Clean up, even when an assertion above failed
        for inserted_passage in inserted:
            await server.passage_manager.delete_agent_passages_async(inserted_passage, default_user)