From 0854ba0d0159990a24334013bbce76aea1db6b59 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Sat, 30 Aug 2025 19:31:07 -0700 Subject: [PATCH] feat: Support timestamp filtering for archival memories [LET-3469] (#4330) Finish temporal filtering --- letta/functions/function_sets/base.py | 25 ++- letta/helpers/tpuf_client.py | 47 ++++- letta/services/agent_manager.py | 2 + .../tool_executor/core_tool_executor.py | 54 ++++- tests/integration_test_turbopuffer.py | 185 ++++++++++++++++++ 5 files changed, 300 insertions(+), 13 deletions(-) diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 53f9e180..5ae5ab9a 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -78,16 +78,37 @@ async def archival_memory_insert(self: "Agent", content: str, tags: Optional[lis async def archival_memory_search( - self: "Agent", query: str, tags: Optional[list[str]] = None, tag_match_mode: Literal["any", "all"] = "any", top_k: Optional[int] = None + self: "Agent", + query: str, + tags: Optional[list[str]] = None, + tag_match_mode: Literal["any", "all"] = "any", + top_k: Optional[int] = None, + start_datetime: Optional[str] = None, + end_datetime: Optional[str] = None, ) -> Optional[str]: """ - Search archival memory using semantic (embedding-based) search. + Search archival memory using semantic (embedding-based) search with optional temporal filtering. Args: query (str): String to search for using semantic similarity. tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned. tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any". top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified. + start_datetime (Optional[str]): Filter results to passages created after this datetime. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30". + end_datetime (Optional[str]): Filter results to passages created before this datetime. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00". + + Examples: + # Search all passages + archival_memory_search(query="project updates") + + # Search with date range (full days) + archival_memory_search(query="meetings", start_datetime="2024-01-15", end_datetime="2024-01-20") + + # Search with specific time range + archival_memory_search(query="error logs", start_datetime="2024-01-15T09:30", end_datetime="2024-01-15T17:30") + + # Search from a specific point in time onwards + archival_memory_search(query="customer feedback", start_datetime="2024-01-15T14:00") Returns: str: Query result string containing matching passages with timestamps and content. diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 7ed1e915..144e443b 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -167,6 +167,8 @@ class TurbopufferClient: tag_match_mode: TagMatchMode = TagMatchMode.ANY, vector_weight: float = 0.5, fts_weight: float = 0.5, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> List[Tuple[PydanticPassage, float]]: """Query passages from Turbopuffer using vector search, full-text search, or hybrid search. @@ -180,6 +182,8 @@ class TurbopufferClient: tag_match_mode: TagMatchMode.ANY (match any tag) or TagMatchMode.ALL (match all tags) - default: TagMatchMode.ANY vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) + start_date: Optional datetime to filter passages created after this date + end_date: Optional datetime to filter passages created before this date Returns: List of (passage, score) tuples @@ -225,6 +229,29 @@ class TurbopufferClient: # For ANY mode, use ContainsAny to match any of the tags tag_filter = ("tags", "ContainsAny", tags) + # build date filter conditions + date_filters = [] + if start_date: + # Turbopuffer expects datetime objects directly for comparison + date_filters.append(("created_at", "Gte", start_date)) + if end_date: + # Turbopuffer expects datetime objects directly for comparison + date_filters.append(("created_at", "Lte", end_date)) + + # combine all filters + all_filters = [] + if tag_filter: + all_filters.append(tag_filter) + if date_filters: + all_filters.extend(date_filters) + + # create final filter expression + final_filter = None + if len(all_filters) == 1: + final_filter = all_filters[0] + elif len(all_filters) > 1: + final_filter = ("And", all_filters) + if search_mode == "timestamp": # Fallback: retrieve most recent passages by timestamp query_params = { @@ -232,8 +259,8 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } - if tag_filter: - query_params["filters"] = tag_filter + if final_filter: + query_params["filters"] = final_filter result = await namespace.query(**query_params) return self._process_single_query_results(result, archive_id, tags) @@ -245,8 +272,8 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } - if tag_filter: - query_params["filters"] = tag_filter + if final_filter: + query_params["filters"] = final_filter result = await namespace.query(**query_params) return self._process_single_query_results(result, archive_id, tags) @@ -258,8 +285,8 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } - if tag_filter: - query_params["filters"] = tag_filter + if final_filter: + query_params["filters"] = final_filter result = await namespace.query(**query_params) return self._process_single_query_results(result, archive_id, tags, is_fts=True) @@ -274,8 +301,8 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } - if tag_filter: - vector_query["filters"] = tag_filter + if final_filter: + vector_query["filters"] = final_filter queries.append(vector_query) # full-text search query @@ -284,8 +311,8 @@ class TurbopufferClient: "top_k": top_k, "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } - if tag_filter: - fts_query["filters"] = tag_filter + if final_filter: + fts_query["filters"] = final_filter queries.append(fts_query) # execute multi-query diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 239916df..05a9fa9d 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -2690,6 +2690,8 @@ class AgentManager: top_k=limit, tags=tags, tag_match_mode=tag_match_mode or TagMatchMode.ANY, + start_date=start_date, + end_date=end_date, ) # Return just the passages (without scores) diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index 8db4ea6c..695041b2 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -126,20 +126,70 @@ class LettaCoreToolExecutor(ToolExecutor): tags: Optional[list[str]] = None, tag_match_mode: Literal["any", "all"] = "any", top_k: Optional[int] = None, + start_datetime: Optional[str] = None, + end_datetime: Optional[str] = None, ) -> Optional[str]: """ - Search archival memory using semantic (embedding-based) search. + Search archival memory using semantic (embedding-based) search with optional temporal filtering. Args: query (str): String to search for using semantic similarity. tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned. tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any". top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified. + start_datetime (Optional[str]): Filter results to passages created after this datetime. ISO 8601 format. + end_datetime (Optional[str]): Filter results to passages created before this datetime. ISO 8601 format. Returns: str: Query result string containing matching passages with timestamps, content, and tags. """ try: + # Parse datetime parameters if provided + from datetime import datetime + + start_date = None + end_date = None + + if start_datetime: + try: + # Try parsing as full datetime first (with time) + start_date = datetime.fromisoformat(start_datetime) + except ValueError: + try: + # Fall back to date-only format + start_date = datetime.strptime(start_datetime, "%Y-%m-%d") + # Set to beginning of day + start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) + except ValueError: + raise ValueError( + f"Invalid start_datetime format: {start_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)" + ) + + # Apply agent's timezone if datetime is naive + if start_date.tzinfo is None and agent_state.timezone: + tz = ZoneInfo(agent_state.timezone) + start_date = start_date.replace(tzinfo=tz) + + if end_datetime: + try: + # Try parsing as full datetime first (with time) + end_date = datetime.fromisoformat(end_datetime) + except ValueError: + try: + # Fall back to date-only format + end_date = datetime.strptime(end_datetime, "%Y-%m-%d") + # Set to end of day for end dates + end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999) + except ValueError: + raise ValueError( + f"Invalid end_datetime format: {end_datetime}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)" + ) + + # Apply agent's timezone if datetime is naive + if end_date.tzinfo is None and agent_state.timezone: + tz = ZoneInfo(agent_state.timezone) + end_date = end_date.replace(tzinfo=tz) + # Convert string to TagMatchMode enum tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL @@ -154,6 +204,8 @@ class LettaCoreToolExecutor(ToolExecutor): embed_query=True, tags=tags, tag_match_mode=tag_mode, + start_date=start_date, + end_date=end_date, ) # Format results to include tags with friendly timestamps diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index 0bc8555d..5cbda488 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -527,6 +527,141 @@ class TestTurbopufferIntegration: except: pass + @pytest.mark.asyncio + async def test_temporal_filtering_with_real_tpuf(self, enable_turbopuffer): + """Test temporal filtering with date ranges""" + from datetime import datetime, timedelta, timezone + + # Skip if Turbopuffer is not properly configured + if not should_use_tpuf(): + pytest.skip("Turbopuffer not configured - skipping TPUF temporal filtering test") + + # Create client + client = TurbopufferClient() + + # Create a unique archive ID for this test + archive_id = f"test-temporal-{uuid.uuid4()}" + + try: + # Create passages with different timestamps + now = datetime.now(timezone.utc) + yesterday = now - timedelta(days=1) + last_week = now - timedelta(days=7) + last_month = now - timedelta(days=30) + + # Insert passages with specific timestamps + test_passages = [ + ("Today's meeting notes about project Alpha", now), + ("Yesterday's standup summary", yesterday), + ("Last week's sprint review", last_week), + ("Last month's quarterly planning", last_month), + ] + + # We need to generate embeddings for the passages + # For testing, we'll use simple dummy embeddings + for text, timestamp in test_passages: + dummy_embedding = [1.0, 2.0, 3.0] # Simple test embedding + passage_id = f"passage-{uuid.uuid4()}" + + await client.insert_archival_memories( + archive_id=archive_id, + text_chunks=[text], + embeddings=[dummy_embedding], + passage_ids=[passage_id], + organization_id="test-org", + created_at=timestamp, + ) + + # Test 1: Query with date range (last 3 days) + three_days_ago = now - timedelta(days=3) + results = await client.query_passages( + archive_id=archive_id, + query_embedding=[1.0, 2.0, 3.0], + search_mode="vector", + top_k=10, + start_date=three_days_ago, + end_date=now, + ) + + # Should only get today's and yesterday's passages + passages = [p for p, _ in results] + texts = [p.text for p in passages] + assert len(passages) == 2 + assert "Today's meeting notes" in texts[0] or "Today's meeting notes" in texts[1] + assert "Yesterday's standup" in texts[0] or "Yesterday's standup" in texts[1] + assert "Last week's sprint" not in str(texts) + assert "Last month's quarterly" not in str(texts) + + # Test 2: Query with only start_date (everything after 2 weeks ago) + two_weeks_ago = now - timedelta(days=14) + results = await client.query_passages( + archive_id=archive_id, + query_embedding=[1.0, 2.0, 3.0], + search_mode="vector", + top_k=10, + start_date=two_weeks_ago, + ) + + # Should get all except last month's passage + passages = [p for p, _ in results] + assert len(passages) == 3 + texts = [p.text for p in passages] + assert "Last month's quarterly" not in str(texts) + + # Test 3: Query with only end_date (everything before yesterday) + results = await client.query_passages( + archive_id=archive_id, + query_embedding=[1.0, 2.0, 3.0], + search_mode="vector", + top_k=10, + end_date=yesterday + timedelta(hours=12), # Middle of yesterday + ) + + # Should get yesterday and older passages + passages = [p for p, _ in results] + assert len(passages) >= 3 # yesterday, last week, last month + texts = [p.text for p in passages] + assert "Today's meeting notes" not in str(texts) + + # Test 4: Test with FTS mode and date filtering + results = await client.query_passages( + archive_id=archive_id, + query_text="meeting notes project", + search_mode="fts", + top_k=10, + start_date=yesterday, + ) + + # Should only find today's meeting notes + passages = [p for p, _ in results] + if len(passages) > 0: # FTS might not match if text search doesn't find keywords + texts = [p.text for p in passages] + assert "Today's meeting notes" in texts[0] + + # Test 5: Test with hybrid mode and date filtering + results = await client.query_passages( + archive_id=archive_id, + query_embedding=[1.0, 2.0, 3.0], + query_text="sprint review", + search_mode="hybrid", + top_k=10, + start_date=last_week - timedelta(days=1), + end_date=last_week + timedelta(days=1), + ) + + # Should find last week's sprint review + passages = [p for p, _ in results] + if len(passages) > 0: + texts = [p.text for p in passages] + assert "Last week's sprint review" in texts[0] + + finally: + # Clean up + try: + await client.delete_all_passages(archive_id) + except: + pass + @pytest.mark.parametrize("turbopuffer_mode", [True, False], indirect=True) class TestTurbopufferParametrized: @@ -566,3 +701,53 @@ class TestTurbopufferParametrized: # Verify deletion remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10) assert not any(p.id == passages[0].id for p in remaining) + + @pytest.mark.asyncio + async def test_temporal_filtering_in_both_modes(self, turbopuffer_mode, server, default_user, sarah_agent): + """Test that temporal filtering works in both NATIVE and TPUF modes""" + from datetime import datetime, timedelta, timezone + + # Insert passages with different timestamps + now = datetime.now(timezone.utc) + yesterday = now - timedelta(days=1) + last_week = now - timedelta(days=7) + + # Insert passages with specific timestamps + recent_passage = await server.passage_manager.insert_passage( + agent_state=sarah_agent, text="Recent update from today", actor=default_user, created_at=now + ) + + old_passage = await server.passage_manager.insert_passage( + agent_state=sarah_agent, text="Old update from last week", actor=default_user, created_at=last_week + ) + + # Query with date range that includes only recent passage + start_date = yesterday + end_date = now + timedelta(hours=1) # Slightly in the future to ensure we catch it + + # Query with date filtering + results = await server.agent_manager.query_agent_passages_async( + actor=default_user, agent_id=sarah_agent.id, start_date=start_date, end_date=end_date, limit=10 + ) + + # Should find only the recent passage, not the old one + assert len(results) >= 1 + assert any("Recent update from today" in p.text for p in results) + assert not any("Old update from last week" in p.text for p in results) + + # Query with date range that includes only the old passage + old_start = last_week - timedelta(days=1) + old_end = last_week + timedelta(days=1) + + old_results = await server.agent_manager.query_agent_passages_async( + actor=default_user, agent_id=sarah_agent.id, start_date=old_start, end_date=old_end, limit=10 + ) + + # Should find only the old passage + assert len(old_results) >= 1 + assert any("Old update from last week" in p.text for p in old_results) + assert not any("Recent update from today" in p.text for p in old_results) + + # Clean up + await server.passage_manager.delete_agent_passages_async(recent_passage, default_user) + await server.passage_manager.delete_agent_passages_async(old_passage, default_user)