From b1f2e8b2bc58d9ddff9f645c667aa04e959d34eb Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 29 Aug 2025 11:55:06 -0700 Subject: [PATCH] feat: Allow agent archival tools to insert/search with tags [LET-4072] (#4300) * Finish modifying archival memory tools * Add tags * Add disabled test --- letta/agents/base_agent.py | 18 ++++ letta/functions/function_sets/base.py | 60 +++--------- letta/helpers/tpuf_client.py | 92 +++++++++++-------- letta/prompts/prompt_generator.py | 9 ++ letta/services/archive_manager.py | 31 +++++++ .../tool_executor/core_tool_executor.py | 66 ++++++++----- tests/test_sdk_client.py | 85 +++++++++++++++++ 7 files changed, 255 insertions(+), 106 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 3355076b..6a03b216 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -102,6 +102,23 @@ class BaseAgent(ABC): if tool_rules_solver is not None: tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts() + # compile archive tags if there's an attached archive + from letta.services.archive_manager import ArchiveManager + + archive_manager = ArchiveManager() + archive = await archive_manager.get_default_archive_for_agent_async( + agent_id=agent_state.id, + actor=self.actor, + ) + + if archive: + archive_tags = await self.passage_manager.get_unique_tags_for_archive_async( + archive_id=archive.id, + actor=self.actor, + ) + else: + archive_tags = None + # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this curr_system_message = in_context_messages[0] curr_system_message_text = curr_system_message.content[0].text @@ -149,6 +166,7 @@ class BaseAgent(ABC): timezone=agent_state.timezone, previous_message_count=num_messages - len(in_context_messages), archival_memory_size=num_archival_memories, + archive_tags=archive_tags, ) diff = united_diff(curr_system_message_text, new_system_message_str) diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 9fcb2fdb..53f9e180 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from letta.agent import Agent from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING @@ -63,70 +63,36 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O return results_str -async def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: +async def archival_memory_insert(self: "Agent", content: str, tags: Optional[list[str]] = None) -> Optional[str]: """ Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. Args: content (str): Content to write to the memory. All unicode (including emojis) are supported. + tags (Optional[list[str]]): Optional list of tags to associate with this memory for better organization and filtering. Returns: Optional[str]: None is always returned as this function does not produce a response. """ - await self.passage_manager.insert_passage( - agent_state=self.agent_state, - text=content, - actor=self.user, - ) - self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user, force=True) - return None + raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.") -async def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]: +async def archival_memory_search( + self: "Agent", query: str, tags: Optional[list[str]] = None, tag_match_mode: Literal["any", "all"] = "any", top_k: Optional[int] = None +) -> Optional[str]: """ Search archival memory using semantic (embedding-based) search. Args: - query (str): String to search for. - page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). - start (Optional[int]): Starting index for the search results. Defaults to 0. + query (str): String to search for using semantic similarity. + tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned. + tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any". + top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified. Returns: - str: Query result string + str: Query result string containing matching passages with timestamps and content. """ - - from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - - if page is None or (isinstance(page, str) and page.lower().strip() == "none"): - page = 0 - try: - page = int(page) - except: - raise ValueError("'page' argument must be an integer") - count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - - try: - # Get results using passage manager - all_results = await self.agent_manager.query_agent_passages_async( - actor=self.user, - agent_id=self.agent_state.id, - query_text=query, - limit=count + start, # Request enough results to handle offset - embedding_config=self.agent_state.embedding_config, - embed_query=True, - ) - - # Apply pagination - end = min(count + start, len(all_results)) - paged_results = all_results[start:end] - - # Format results to match previous implementation - formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results] - - return formatted_results, len(formatted_results) - - except Exception as e: - raise e + raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.") def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 3519dd60..69884579 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -95,11 +95,9 @@ class TurbopufferClient: organization_ids = [] archive_ids = [] created_ats = [] + tags_arrays = [] # Store tags as arrays passages = [] - # prepare tag columns - tag_columns = {tag: [] for tag in (tags or [])} - for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)): passage_id = passage_ids[idx] @@ -110,10 +108,7 @@ class TurbopufferClient: organization_ids.append(organization_id) archive_ids.append(archive_id) created_ats.append(timestamp) - - # append tag values - for tag in tag_columns: - tag_columns[tag].append(True) + tags_arrays.append(tags or []) # Store tags as array # Create PydanticPassage object passage = PydanticPassage( @@ -123,6 +118,7 @@ class TurbopufferClient: archive_id=archive_id, created_at=timestamp, metadata_={}, + tags=tags or [], # Include tags in the passage embedding=embedding, embedding_config=None, # Will be set by caller if needed ) @@ -136,11 +132,9 @@ class TurbopufferClient: "organization_id": organization_ids, "archive_id": archive_ids, "created_at": created_ats, + "tags": tags_arrays, # Add tags as array column } - # add tag columns if any - upsert_columns.update(tag_columns) - try: # Use AsyncTurbopuffer as a context manager for proper resource cleanup async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: @@ -193,16 +187,21 @@ class TurbopufferClient: from turbopuffer import AsyncTurbopuffer from turbopuffer.types import QueryParam - # validate inputs based on search mode - if search_mode == "vector" and query_embedding is None: - raise ValueError("query_embedding is required for vector search mode") - if search_mode == "fts" and query_text is None: - raise ValueError("query_text is required for FTS search mode") - if search_mode == "hybrid": - if query_embedding is None or query_text is None: - raise ValueError("Both query_embedding and query_text are required for hybrid search mode") - if search_mode not in ["vector", "fts", "hybrid"]: - raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', or 'hybrid'") + # Check if we should fallback to timestamp-based retrieval + if query_embedding is None and query_text is None: + # Fallback to retrieving most recent passages when no search query is provided + search_mode = "timestamp" + else: + # validate inputs based on search mode + if search_mode == "vector" and query_embedding is None: + raise ValueError("query_embedding is required for vector search mode") + if search_mode == "fts" and query_text is None: + raise ValueError("query_text is required for FTS search mode") + if search_mode == "hybrid": + if query_embedding is None or query_text is None: + raise ValueError("Both query_embedding and query_text are required for hybrid search mode") + if search_mode not in ["vector", "fts", "hybrid"]: + raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', or 'hybrid'") namespace_name = self._get_namespace_name(archive_id) @@ -213,23 +212,38 @@ class TurbopufferClient: # build tag filter conditions tag_filter = None if tags: - tag_conditions = [] - for tag in tags: - tag_conditions.append((tag, "Eq", True)) - - if len(tag_conditions) == 1: - tag_filter = tag_conditions[0] - elif tag_match_mode == TagMatchMode.ALL: - tag_filter = ("And", tag_conditions) + if tag_match_mode == TagMatchMode.ALL: + # For ALL mode, need to check each tag individually with Contains + tag_conditions = [] + for tag in tags: + tag_conditions.append(("tags", "Contains", tag)) + if len(tag_conditions) == 1: + tag_filter = tag_conditions[0] + else: + tag_filter = ("And", tag_conditions) else: # tag_match_mode == TagMatchMode.ANY - tag_filter = ("Or", tag_conditions) + # For ANY mode, use ContainsAny to match any of the tags + tag_filter = ("tags", "ContainsAny", tags) - if search_mode == "vector": + if search_mode == "timestamp": + # Fallback: retrieve most recent passages by timestamp + query_params = { + "rank_by": ("created_at", "desc"), # Order by created_at in descending order + "top_k": top_k, + "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], + } + if tag_filter: + query_params["filters"] = tag_filter + + result = await namespace.query(**query_params) + return self._process_single_query_results(result, archive_id, tags) + + elif search_mode == "vector": # single vector search query query_params = { "rank_by": ("vector", "ANN", query_embedding), "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at"], + "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } if tag_filter: query_params["filters"] = tag_filter @@ -242,7 +256,7 @@ class TurbopufferClient: query_params = { "rank_by": ("text", "BM25", query_text), "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at"], + "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } if tag_filter: query_params["filters"] = tag_filter @@ -258,7 +272,7 @@ class TurbopufferClient: vector_query = { "rank_by": ("vector", "ANN", query_embedding), "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at"], + "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } if tag_filter: vector_query["filters"] = tag_filter @@ -268,7 +282,7 @@ class TurbopufferClient: fts_query = { "rank_by": ("text", "BM25", query_text), "top_k": top_k, - "include_attributes": ["text", "organization_id", "archive_id", "created_at"], + "include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"], } if tag_filter: fts_query["filters"] = tag_filter @@ -295,10 +309,11 @@ class TurbopufferClient: passages_with_scores = [] for row in result.rows: - # Build metadata including any tag filters that were applied + # Extract tags from the result row + passage_tags = getattr(row, "tags", []) or [] + + # Build metadata metadata = {} - if tags: - metadata["applied_tags"] = tags # Create a passage with minimal fields - embeddings are not returned from Turbopuffer passage = PydanticPassage( @@ -307,7 +322,8 @@ class TurbopufferClient: organization_id=getattr(row, "organization_id", None), archive_id=archive_id, # use the archive_id from the query created_at=getattr(row, "created_at", None), - metadata_=metadata, # Include tag filters in metadata + metadata_=metadata, + tags=passage_tags, # Set the actual tags from the passage # Set required fields to empty/default values since we don't store embeddings embedding=[], # Empty embedding since we don't return it from Turbopuffer embedding_config=None, # No embedding config needed for retrieved passages diff --git a/letta/prompts/prompt_generator.py b/letta/prompts/prompt_generator.py index 267cf7c1..c0e276ca 100644 --- a/letta/prompts/prompt_generator.py +++ b/letta/prompts/prompt_generator.py @@ -17,6 +17,7 @@ class PromptGenerator: timezone: str, previous_message_count: int = 0, archival_memory_size: Optional[int] = 0, + archive_tags: Optional[List[str]] = None, ) -> str: """ Generate a memory metadata block for the agent's system prompt. @@ -31,6 +32,7 @@ class PromptGenerator: timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles') previous_message_count: Number of messages in recall memory (conversation history) archival_memory_size: Number of items in archival memory (long-term storage) + archive_tags: List of unique tags available in archival memory Returns: A formatted string containing the memory metadata block with XML-style tags @@ -41,6 +43,7 @@ class PromptGenerator: - Memory blocks were last modified: 2024-01-15 09:00 AM PST - 42 previous messages between you and the user are stored in recall memory (use tools to access them) - 156 total memories you created are stored in archival memory (use tools to access them) + - Available archival memory tags: project_x, meeting_notes, research, ideas """ # Put the timestamp in the local timezone (mimicking get_local_time()) @@ -60,6 +63,10 @@ class PromptGenerator: f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)" ) + # Include archive tags if available + if archive_tags: + metadata_lines.append(f"- Available archival memory tags: {', '.join(archive_tags)}") + metadata_lines.append("") memory_metadata_block = "\n".join(metadata_lines) return memory_metadata_block @@ -90,6 +97,7 @@ class PromptGenerator: template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", previous_message_count: int = 0, archival_memory_size: int = 0, + archive_tags: Optional[List[str]] = None, ) -> str: """Prepare the final/full system message that will be fed into the LLM API @@ -114,6 +122,7 @@ class PromptGenerator: previous_message_count=previous_message_count, archival_memory_size=archival_memory_size, timezone=timezone, + archive_tags=archive_tags, ) full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index d006021a..5d48afe7 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -146,6 +146,37 @@ class ArchiveManager: session.add(archives_agents) await session.commit() + @enforce_types + async def get_default_archive_for_agent_async( + self, + agent_id: str, + actor: PydanticUser = None, + ) -> Optional[PydanticArchive]: + """Get the agent's default archive if it exists, return None otherwise.""" + # First check if agent has any archives + from letta.services.agent_manager import AgentManager + + agent_manager = AgentManager() + + archive_ids = await agent_manager.get_agent_archive_ids_async( + agent_id=agent_id, + actor=actor, + ) + + if archive_ids: + # TODO: Remove this check once we support multiple archives per agent + if len(archive_ids) > 1: + raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported") + # Get the archive + archive = await self.get_archive_by_id_async( + archive_id=archive_ids[0], + actor=actor, + ) + return archive + + # No archive found, return None + return None + @enforce_types async def get_or_create_default_archive_for_agent_async( self, diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index bc5438b4..8db4ea6c 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -1,5 +1,6 @@ import math -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional +from zoneinfo import ZoneInfo from letta.constants import ( CORE_MEMORY_LINE_NUMBER_WARNING, @@ -9,6 +10,7 @@ from letta.constants import ( ) from letta.helpers.json_helpers import json_dumps from letta.schemas.agent import AgentState +from letta.schemas.enums import TagMatchMode from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult @@ -117,57 +119,78 @@ class LettaCoreToolExecutor(ToolExecutor): return results_str async def archival_memory_search( - self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0, start: Optional[int] = 0 + self, + agent_state: AgentState, + actor: User, + query: str, + tags: Optional[list[str]] = None, + tag_match_mode: Literal["any", "all"] = "any", + top_k: Optional[int] = None, ) -> Optional[str]: """ Search archival memory using semantic (embedding-based) search. Args: - query (str): String to search for. - page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). - start (Optional[int]): Starting index for the search results. Defaults to 0. + query (str): String to search for using semantic similarity. + tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned. + tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any". + top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified. Returns: - str: Query result string + str: Query result string containing matching passages with timestamps, content, and tags. """ - if page is None or (isinstance(page, str) and page.lower().strip() == "none"): - page = 0 try: - page = int(page) - except: - raise ValueError("'page' argument must be an integer") + # Convert string to TagMatchMode enum + tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL - count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - - try: # Get results using passage manager + limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE all_results = await self.agent_manager.query_agent_passages_async( actor=actor, agent_id=agent_state.id, query_text=query, - limit=count + start, # Request enough results to handle offset + limit=limit, embedding_config=agent_state.embedding_config, embed_query=True, + tags=tags, + tag_match_mode=tag_mode, ) - # Apply pagination - end = min(count + start, len(all_results)) - paged_results = all_results[start:end] + # Format results to include tags with friendly timestamps + formatted_results = [] + for result in all_results: + # Format timestamp in agent's timezone if available + timestamp = result.created_at + if timestamp and agent_state.timezone: + try: + # Convert to agent's timezone + tz = ZoneInfo(agent_state.timezone) + local_time = timestamp.astimezone(tz) + # Format as ISO string with timezone + formatted_timestamp = local_time.isoformat() + except Exception: + # Fallback to ISO format if timezone conversion fails + formatted_timestamp = str(timestamp) + else: + # Use ISO format if no timezone is set + formatted_timestamp = str(timestamp) if timestamp else "Unknown" - # Format results to match previous implementation - formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results] + formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []}) return formatted_results, len(formatted_results) except Exception as e: raise e - async def archival_memory_insert(self, agent_state: AgentState, actor: User, content: str) -> Optional[str]: + async def archival_memory_insert( + self, agent_state: AgentState, actor: User, content: str, tags: Optional[list[str]] = None + ) -> Optional[str]: """ Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. Args: content (str): Content to write to the memory. All unicode (including emojis) are supported. + tags (Optional[list[str]]): Optional list of tags to associate with this memory for better organization and filtering. Returns: Optional[str]: None is always returned as this function does not produce a response. @@ -176,6 +199,7 @@ class LettaCoreToolExecutor(ToolExecutor): agent_state=agent_state, text=content, actor=actor, + tags=tags, ) await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True) return None diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 480bf221..69c0b87b 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -1224,6 +1224,91 @@ def test_preview_payload(client: LettaSDKClient): client.agents.delete(agent_id=temp_agent.id) +# TODO: Re-enable +# def test_archive_tags_in_system_prompt(client: LettaSDKClient): +# """Test that archive tags are correctly compiled into the system prompt.""" +# # Create a test agent +# temp_agent = client.agents.create( +# memory_blocks=[ +# CreateBlock( +# label="human", +# value="username: test_user", +# ), +# ], +# model="openai/gpt-4o-mini", +# embedding="openai/text-embedding-3-small", +# ) +# +# try: +# # Add passages with different tags to the agent's archive +# test_tags = ["project_alpha", "meeting_notes", "research", "ideas", "todo_items"] +# +# # Create passages with tags +# for i, tag in enumerate(test_tags): +# client.agents.passages.create( +# agent_id=temp_agent.id, +# text=f"Test passage {i} with tag {tag}", +# tags=[tag] +# ) +# +# # Also create a passage with multiple tags +# client.agents.passages.create( +# agent_id=temp_agent.id, +# text="Passage with multiple tags", +# tags=["multi_tag_1", "multi_tag_2"] +# ) +# +# # Get the raw payload to check the system prompt +# payload = client.agents.messages.preview_raw_payload( +# agent_id=temp_agent.id, +# request=LettaRequest( +# messages=[ +# MessageCreate( +# role="user", +# content=[ +# TextContent( +# text="Hello", +# ) +# ], +# ) +# ], +# ), +# ) +# +# # Extract the system message +# assert isinstance(payload, dict) +# assert "messages" in payload +# assert len(payload["messages"]) > 0 +# +# system_message = payload["messages"][0] +# assert system_message["role"] == "system" +# system_content = system_message["content"] +# +# # Check that the archive tags are included in the metadata +# assert "Available archival memory tags:" in system_content +# +# # Check that all unique tags are present +# all_unique_tags = set(test_tags + ["multi_tag_1", "multi_tag_2"]) +# for tag in all_unique_tags: +# assert tag in system_content, f"Tag '{tag}' not found in system prompt" +# +# # Verify the tags are in the memory_metadata section +# assert "" in system_content +# assert "" in system_content +# +# # Extract the metadata section to verify format +# metadata_start = system_content.index("") +# metadata_end = system_content.index("") +# metadata_section = system_content[metadata_start:metadata_end] +# +# # Verify the tags line is properly formatted +# assert "- Available archival memory tags:" in metadata_section +# +# finally: +# # Clean up the agent +# client.agents.delete(agent_id=temp_agent.id) + + def test_agent_tools_list(client: LettaSDKClient): """Test the optimized agent tools list endpoint for correctness.""" # Create a test agent