feat: Support timestamp filtering for archival memories [LET-3469] (#4330)
Finish temporal filtering
This commit is contained in:
@@ -78,16 +78,37 @@ async def archival_memory_insert(self: "Agent", content: str, tags: Optional[lis
|
||||
|
||||
|
||||
async def archival_memory_search(
|
||||
self: "Agent", query: str, tags: Optional[list[str]] = None, tag_match_mode: Literal["any", "all"] = "any", top_k: Optional[int] = None
|
||||
self: "Agent",
|
||||
query: str,
|
||||
tags: Optional[list[str]] = None,
|
||||
tag_match_mode: Literal["any", "all"] = "any",
|
||||
top_k: Optional[int] = None,
|
||||
start_datetime: Optional[str] = None,
|
||||
end_datetime: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search.
|
||||
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
|
||||
|
||||
Args:
|
||||
query (str): String to search for using semantic similarity.
|
||||
tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned.
|
||||
tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any".
|
||||
top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified.
|
||||
start_datetime (Optional[str]): Filter results to passages created after this datetime. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30".
|
||||
end_datetime (Optional[str]): Filter results to passages created before this datetime. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00".
|
||||
|
||||
Examples:
|
||||
# Search all passages
|
||||
archival_memory_search(query="project updates")
|
||||
|
||||
# Search with date range (full days)
|
||||
archival_memory_search(query="meetings", start_datetime="2024-01-15", end_datetime="2024-01-20")
|
||||
|
||||
# Search with specific time range
|
||||
archival_memory_search(query="error logs", start_datetime="2024-01-15T09:30", end_datetime="2024-01-15T17:30")
|
||||
|
||||
# Search from a specific point in time onwards
|
||||
archival_memory_search(query="customer feedback", start_datetime="2024-01-15T14:00")
|
||||
|
||||
Returns:
|
||||
str: Query result string containing matching passages with timestamps and content.
|
||||
|
||||
@@ -167,6 +167,8 @@ class TurbopufferClient:
|
||||
tag_match_mode: TagMatchMode = TagMatchMode.ANY,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
) -> List[Tuple[PydanticPassage, float]]:
|
||||
"""Query passages from Turbopuffer using vector search, full-text search, or hybrid search.
|
||||
|
||||
@@ -180,6 +182,8 @@ class TurbopufferClient:
|
||||
tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime lower bound (inclusive); only passages whose created_at is at or after this instant are returned
|
||||
end_date: Optional datetime upper bound (inclusive); only passages whose created_at is at or before this instant are returned
|
||||
|
||||
Returns:
|
||||
List of (passage, score) tuples
|
||||
@@ -225,6 +229,29 @@ class TurbopufferClient:
|
||||
# For ANY mode, use ContainsAny to match any of the tags
|
||||
tag_filter = ("tags", "ContainsAny", tags)
|
||||
|
||||
# build date filter conditions
|
||||
date_filters = []
|
||||
if start_date:
|
||||
# Turbopuffer expects datetime objects directly for comparison
|
||||
date_filters.append(("created_at", "Gte", start_date))
|
||||
if end_date:
|
||||
# Turbopuffer expects datetime objects directly for comparison
|
||||
date_filters.append(("created_at", "Lte", end_date))
|
||||
|
||||
# combine all filters
|
||||
all_filters = []
|
||||
if tag_filter:
|
||||
all_filters.append(tag_filter)
|
||||
if date_filters:
|
||||
all_filters.extend(date_filters)
|
||||
|
||||
# create final filter expression
|
||||
final_filter = None
|
||||
if len(all_filters) == 1:
|
||||
final_filter = all_filters[0]
|
||||
elif len(all_filters) > 1:
|
||||
final_filter = ("And", all_filters)
|
||||
|
||||
if search_mode == "timestamp":
|
||||
# Fallback: retrieve most recent passages by timestamp
|
||||
query_params = {
|
||||
@@ -232,8 +259,8 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
query_params["filters"] = tag_filter
|
||||
if final_filter:
|
||||
query_params["filters"] = final_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, tags)
|
||||
@@ -245,8 +272,8 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
query_params["filters"] = tag_filter
|
||||
if final_filter:
|
||||
query_params["filters"] = final_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, tags)
|
||||
@@ -258,8 +285,8 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
query_params["filters"] = tag_filter
|
||||
if final_filter:
|
||||
query_params["filters"] = final_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, tags, is_fts=True)
|
||||
@@ -274,8 +301,8 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
vector_query["filters"] = tag_filter
|
||||
if final_filter:
|
||||
vector_query["filters"] = final_filter
|
||||
queries.append(vector_query)
|
||||
|
||||
# full-text search query
|
||||
@@ -284,8 +311,8 @@ class TurbopufferClient:
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
fts_query["filters"] = tag_filter
|
||||
if final_filter:
|
||||
fts_query["filters"] = final_filter
|
||||
queries.append(fts_query)
|
||||
|
||||
# execute multi-query
|
||||
|
||||
@@ -2690,6 +2690,8 @@ class AgentManager:
|
||||
top_k=limit,
|
||||
tags=tags,
|
||||
tag_match_mode=tag_match_mode or TagMatchMode.ANY,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Return just the passages (without scores)
|
||||
|
||||
@@ -126,20 +126,70 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
tags: Optional[list[str]] = None,
|
||||
tag_match_mode: Literal["any", "all"] = "any",
|
||||
top_k: Optional[int] = None,
|
||||
start_datetime: Optional[str] = None,
|
||||
end_datetime: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search.
|
||||
Search archival memory using semantic (embedding-based) search with optional temporal filtering.
|
||||
|
||||
Args:
|
||||
query (str): String to search for using semantic similarity.
|
||||
tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned.
|
||||
tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any".
|
||||
top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified.
|
||||
start_datetime (Optional[str]): Filter results to passages created after this datetime. ISO 8601 format.
|
||||
end_datetime (Optional[str]): Filter results to passages created before this datetime. ISO 8601 format.
|
||||
|
||||
Returns:
|
||||
str: Query result string containing matching passages with timestamps, content, and tags.
|
||||
"""
|
||||
try:
|
||||
# Parse datetime parameters if provided
|
||||
from datetime import datetime
|
||||
|
||||
start_date = None
|
||||
end_date = None
|
||||
|
||||
# Resolve the optional lower bound of the search window.
if start_datetime:
    try:
        # Date-only input ("YYYY-MM-DD") means the very start of that day;
        # strptime with a date-only format already yields midnight.
        start_date = datetime.strptime(start_datetime, "%Y-%m-%d")
    except ValueError:
        try:
            # Anything else must be a full ISO 8601 datetime (with a time part).
            start_date = datetime.fromisoformat(start_datetime)
        except ValueError as err:
            raise ValueError(
                f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
            ) from err

    # Naive values are interpreted in the agent's configured timezone (left naive if none is set).
    if start_date.tzinfo is None and agent_state.timezone:
        start_date = start_date.replace(tzinfo=ZoneInfo(agent_state.timezone))

# Resolve the optional upper bound of the search window.
if end_datetime:
    try:
        # The date-only format MUST be tried first: datetime.fromisoformat() also
        # accepts a bare "YYYY-MM-DD" and returns midnight, which would silently
        # skip the end-of-day adjustment and exclude the whole end day from a
        # "full days" range.
        end_date = datetime.strptime(end_datetime, "%Y-%m-%d")
        # A date-only end bound covers the entire day.
        end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999)
    except ValueError:
        try:
            # An explicit time was supplied; use it exactly as given.
            end_date = datetime.fromisoformat(end_datetime)
        except ValueError as err:
            raise ValueError(
                f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)"
            ) from err

    # Naive values are interpreted in the agent's configured timezone (left naive if none is set).
    if end_date.tzinfo is None and agent_state.timezone:
        end_date = end_date.replace(tzinfo=ZoneInfo(agent_state.timezone))
|
||||
|
||||
# Convert string to TagMatchMode enum
|
||||
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
|
||||
|
||||
@@ -154,6 +204,8 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
embed_query=True,
|
||||
tags=tags,
|
||||
tag_match_mode=tag_mode,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
# Format results to include tags with friendly timestamps
|
||||
|
||||
@@ -527,6 +527,141 @@ class TestTurbopufferIntegration:
|
||||
except:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_filtering_with_real_tpuf(self, enable_turbopuffer):
    """Exercise created_at range filtering against a live Turbopuffer namespace.

    Four passages are inserted with timestamps of now, one day ago, one week ago
    and thirty days ago. The test then queries with start/end bounds in vector,
    FTS and hybrid modes and checks which passages come back. The test is skipped
    when Turbopuffer is not configured, and the temporary archive is removed on
    exit whether or not the assertions pass.
    """
    from datetime import datetime, timedelta, timezone

    # Skip if Turbopuffer is not properly configured
    if not should_use_tpuf():
        pytest.skip("Turbopuffer not configured - skipping TPUF temporal filtering test")

    client = TurbopufferClient()

    # Unique archive per run so runs cannot see each other's passages
    archive_id = f"test-temporal-{uuid.uuid4()}"

    try:
        now = datetime.now(timezone.utc)
        yesterday = now - timedelta(days=1)
        last_week = now - timedelta(days=7)
        last_month = now - timedelta(days=30)

        test_passages = [
            ("Today's meeting notes about project Alpha", now),
            ("Yesterday's standup summary", yesterday),
            ("Last week's sprint review", last_week),
            ("Last month's quarterly planning", last_month),
        ]

        # Every passage gets the same dummy embedding, so vector similarity does
        # not discriminate between them and only the date filter decides the result.
        for text, timestamp in test_passages:
            dummy_embedding = [1.0, 2.0, 3.0]
            passage_id = f"passage-{uuid.uuid4()}"

            await client.insert_archival_memories(
                archive_id=archive_id,
                text_chunks=[text],
                embeddings=[dummy_embedding],
                passage_ids=[passage_id],
                organization_id="test-org",
                created_at=timestamp,
            )

        # Test 1: bounded range covering the last 3 days
        three_days_ago = now - timedelta(days=3)
        results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[1.0, 2.0, 3.0],
            search_mode="vector",
            top_k=10,
            start_date=three_days_ago,
            end_date=now,
        )

        # Should only get today's and yesterday's passages (order is not asserted)
        passages = [p for p, _ in results]
        texts = [p.text for p in passages]
        assert len(passages) == 2
        assert any("Today's meeting notes" in t for t in texts)
        assert any("Yesterday's standup" in t for t in texts)
        assert "Last week's sprint" not in str(texts)
        assert "Last month's quarterly" not in str(texts)

        # Test 2: only start_date (everything from 2 weeks ago onwards)
        two_weeks_ago = now - timedelta(days=14)
        results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[1.0, 2.0, 3.0],
            search_mode="vector",
            top_k=10,
            start_date=two_weeks_ago,
        )

        # Should get all except last month's passage
        passages = [p for p, _ in results]
        assert len(passages) == 3
        texts = [p.text for p in passages]
        assert "Last month's quarterly" not in str(texts)

        # Test 3: only end_date
        results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[1.0, 2.0, 3.0],
            search_mode="vector",
            top_k=10,
            end_date=yesterday + timedelta(hours=12),  # 12 hours ago: after yesterday's passage, before today's
        )

        # Should get yesterday and older passages
        passages = [p for p, _ in results]
        assert len(passages) >= 3  # yesterday, last week, last month
        texts = [p.text for p in passages]
        assert "Today's meeting notes" not in str(texts)

        # Test 4: FTS mode combined with a start_date
        results = await client.query_passages(
            archive_id=archive_id,
            query_text="meeting notes project",
            search_mode="fts",
            top_k=10,
            start_date=yesterday,
        )

        # Should rank today's meeting notes first
        passages = [p for p, _ in results]
        if passages:  # FTS might not match if text search doesn't find keywords
            texts = [p.text for p in passages]
            assert "Today's meeting notes" in texts[0]

        # Test 5: hybrid mode with a window of +/- 1 day around last week's passage
        results = await client.query_passages(
            archive_id=archive_id,
            query_embedding=[1.0, 2.0, 3.0],
            query_text="sprint review",
            search_mode="hybrid",
            top_k=10,
            start_date=last_week - timedelta(days=1),
            end_date=last_week + timedelta(days=1),
        )

        # Should find last week's sprint review
        passages = [p for p, _ in results]
        if passages:
            texts = [p.text for p in passages]
            assert "Last week's sprint review" in texts[0]

    finally:
        # Best-effort cleanup: a failed delete must not mask the real test outcome.
        # Catch Exception rather than using a bare except so that
        # asyncio.CancelledError and KeyboardInterrupt still propagate.
        try:
            await client.delete_all_passages(archive_id)
        except Exception:
            pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True)
|
||||
class TestTurbopufferParametrized:
|
||||
@@ -566,3 +701,53 @@ class TestTurbopufferParametrized:
|
||||
# Verify deletion
|
||||
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
|
||||
assert not any(p.id == passages[0].id for p in remaining)
|
||||
|
||||
@pytest.mark.asyncio
async def test_temporal_filtering_in_both_modes(self, turbopuffer_mode, server, default_user, sarah_agent):
    """Check that start_date/end_date filtering works in both NATIVE and TPUF modes.

    Two passages are inserted for the agent, one stamped now and one stamped a
    week ago. Two date windows are then queried, and each must return its own
    passage and exclude the other. The inserted passages are deleted in a
    ``finally`` block so a failing assertion does not leave them attached to the
    agent fixture.
    """
    from datetime import datetime, timedelta, timezone

    now = datetime.now(timezone.utc)
    yesterday = now - timedelta(days=1)
    last_week = now - timedelta(days=7)

    # Track whatever was actually inserted so cleanup is correct even if the
    # second insert raises.
    inserted = []
    try:
        recent_passage = await server.passage_manager.insert_passage(
            agent_state=sarah_agent, text="Recent update from today", actor=default_user, created_at=now
        )
        inserted.append(recent_passage)

        old_passage = await server.passage_manager.insert_passage(
            agent_state=sarah_agent, text="Old update from last week", actor=default_user, created_at=last_week
        )
        inserted.append(old_passage)

        # Window that includes only the recent passage
        start_date = yesterday
        end_date = now + timedelta(hours=1)  # Slightly in the future to ensure we catch it

        results = await server.agent_manager.query_agent_passages_async(
            actor=default_user, agent_id=sarah_agent.id, start_date=start_date, end_date=end_date, limit=10
        )

        # Should find only the recent passage, not the old one
        assert len(results) >= 1
        assert any("Recent update from today" in p.text for p in results)
        assert not any("Old update from last week" in p.text for p in results)

        # Window of +/- 1 day around the old passage
        old_start = last_week - timedelta(days=1)
        old_end = last_week + timedelta(days=1)

        old_results = await server.agent_manager.query_agent_passages_async(
            actor=default_user, agent_id=sarah_agent.id, start_date=old_start, end_date=old_end, limit=10
        )

        # Should find only the old passage
        assert len(old_results) >= 1
        assert any("Old update from last week" in p.text for p in old_results)
        assert not any("Recent update from today" in p.text for p in old_results)
    finally:
        # Clean up everything that was inserted, pass or fail
        for passage in inserted:
            await server.passage_manager.delete_agent_passages_async(passage, default_user)
|
||||
|
||||
Reference in New Issue
Block a user