feat: Allow agent archival tools to insert/search with tags [LET-4072] (#4300)

* Finish modifying archival memory tools

* Add tags

* Add disabled test
This commit is contained in:
Matthew Zhou
2025-08-29 11:55:06 -07:00
committed by GitHub
parent 04767aa4fe
commit 23b2769dc4
7 changed files with 255 additions and 106 deletions

View File

@@ -102,6 +102,23 @@ class BaseAgent(ABC):
if tool_rules_solver is not None:
tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
# compile archive tags if there's an attached archive
from letta.services.archive_manager import ArchiveManager
archive_manager = ArchiveManager()
archive = await archive_manager.get_default_archive_for_agent_async(
agent_id=agent_state.id,
actor=self.actor,
)
if archive:
archive_tags = await self.passage_manager.get_unique_tags_for_archive_async(
archive_id=archive.id,
actor=self.actor,
)
else:
archive_tags = None
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_system_message_text = curr_system_message.content[0].text
@@ -149,6 +166,7 @@ class BaseAgent(ABC):
timezone=agent_state.timezone,
previous_message_count=num_messages - len(in_context_messages),
archival_memory_size=num_archival_memories,
archive_tags=archive_tags,
)
diff = united_diff(curr_system_message_text, new_system_message_str)

View File

@@ -1,4 +1,4 @@
from typing import Optional
from typing import Literal, Optional
from letta.agent import Agent
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
@@ -63,70 +63,36 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
return results_str
async def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
async def archival_memory_insert(self: "Agent", content: str, tags: Optional[list[str]] = None) -> Optional[str]:
"""
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
Args:
content (str): Content to write to the memory. All unicode (including emojis) are supported.
tags (Optional[list[str]]): Optional list of tags to associate with this memory for better organization and filtering.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
await self.passage_manager.insert_passage(
agent_state=self.agent_state,
text=content,
actor=self.user,
)
self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user, force=True)
return None
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
async def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]:
async def archival_memory_search(
self: "Agent", query: str, tags: Optional[list[str]] = None, tag_match_mode: Literal["any", "all"] = "any", top_k: Optional[int] = None
) -> Optional[str]:
"""
Search archival memory using semantic (embedding-based) search.
Args:
query (str): String to search for.
page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
start (Optional[int]): Starting index for the search results. Defaults to 0.
query (str): String to search for using semantic similarity.
tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned.
tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any".
top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified.
Returns:
str: Query result string
str: Query result string containing matching passages with timestamps and content.
"""
from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
except:
raise ValueError("'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
try:
# Get results using passage manager
all_results = await self.agent_manager.query_agent_passages_async(
actor=self.user,
agent_id=self.agent_state.id,
query_text=query,
limit=count + start, # Request enough results to handle offset
embedding_config=self.agent_state.embedding_config,
embed_query=True,
)
# Apply pagination
end = min(count + start, len(all_results))
paged_results = all_results[start:end]
# Format results to match previous implementation
formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results]
return formatted_results, len(formatted_results)
except Exception as e:
raise e
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore

View File

@@ -95,11 +95,9 @@ class TurbopufferClient:
organization_ids = []
archive_ids = []
created_ats = []
tags_arrays = [] # Store tags as arrays
passages = []
# prepare tag columns
tag_columns = {tag: [] for tag in (tags or [])}
for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
passage_id = passage_ids[idx]
@@ -110,10 +108,7 @@ class TurbopufferClient:
organization_ids.append(organization_id)
archive_ids.append(archive_id)
created_ats.append(timestamp)
# append tag values
for tag in tag_columns:
tag_columns[tag].append(True)
tags_arrays.append(tags or []) # Store tags as array
# Create PydanticPassage object
passage = PydanticPassage(
@@ -123,6 +118,7 @@ class TurbopufferClient:
archive_id=archive_id,
created_at=timestamp,
metadata_={},
tags=tags or [], # Include tags in the passage
embedding=embedding,
embedding_config=None, # Will be set by caller if needed
)
@@ -136,11 +132,9 @@ class TurbopufferClient:
"organization_id": organization_ids,
"archive_id": archive_ids,
"created_at": created_ats,
"tags": tags_arrays, # Add tags as array column
}
# add tag columns if any
upsert_columns.update(tag_columns)
try:
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
@@ -193,16 +187,21 @@ class TurbopufferClient:
from turbopuffer import AsyncTurbopuffer
from turbopuffer.types import QueryParam
# validate inputs based on search mode
if search_mode == "vector" and query_embedding is None:
raise ValueError("query_embedding is required for vector search mode")
if search_mode == "fts" and query_text is None:
raise ValueError("query_text is required for FTS search mode")
if search_mode == "hybrid":
if query_embedding is None or query_text is None:
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
if search_mode not in ["vector", "fts", "hybrid"]:
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', or 'hybrid'")
# Check if we should fallback to timestamp-based retrieval
if query_embedding is None and query_text is None:
# Fallback to retrieving most recent passages when no search query is provided
search_mode = "timestamp"
else:
# validate inputs based on search mode
if search_mode == "vector" and query_embedding is None:
raise ValueError("query_embedding is required for vector search mode")
if search_mode == "fts" and query_text is None:
raise ValueError("query_text is required for FTS search mode")
if search_mode == "hybrid":
if query_embedding is None or query_text is None:
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
if search_mode not in ["vector", "fts", "hybrid"]:
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', or 'hybrid'")
namespace_name = self._get_namespace_name(archive_id)
@@ -213,23 +212,38 @@ class TurbopufferClient:
# build tag filter conditions
tag_filter = None
if tags:
tag_conditions = []
for tag in tags:
tag_conditions.append((tag, "Eq", True))
if len(tag_conditions) == 1:
tag_filter = tag_conditions[0]
elif tag_match_mode == TagMatchMode.ALL:
tag_filter = ("And", tag_conditions)
if tag_match_mode == TagMatchMode.ALL:
# For ALL mode, need to check each tag individually with Contains
tag_conditions = []
for tag in tags:
tag_conditions.append(("tags", "Contains", tag))
if len(tag_conditions) == 1:
tag_filter = tag_conditions[0]
else:
tag_filter = ("And", tag_conditions)
else: # tag_match_mode == TagMatchMode.ANY
tag_filter = ("Or", tag_conditions)
# For ANY mode, use ContainsAny to match any of the tags
tag_filter = ("tags", "ContainsAny", tags)
if search_mode == "vector":
if search_mode == "timestamp":
# Fallback: retrieve most recent passages by timestamp
query_params = {
"rank_by": ("created_at", "desc"), # Order by created_at in descending order
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
query_params["filters"] = tag_filter
result = await namespace.query(**query_params)
return self._process_single_query_results(result, archive_id, tags)
elif search_mode == "vector":
# single vector search query
query_params = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
query_params["filters"] = tag_filter
@@ -242,7 +256,7 @@ class TurbopufferClient:
query_params = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
query_params["filters"] = tag_filter
@@ -258,7 +272,7 @@ class TurbopufferClient:
vector_query = {
"rank_by": ("vector", "ANN", query_embedding),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
vector_query["filters"] = tag_filter
@@ -268,7 +282,7 @@ class TurbopufferClient:
fts_query = {
"rank_by": ("text", "BM25", query_text),
"top_k": top_k,
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
}
if tag_filter:
fts_query["filters"] = tag_filter
@@ -295,10 +309,11 @@ class TurbopufferClient:
passages_with_scores = []
for row in result.rows:
# Build metadata including any tag filters that were applied
# Extract tags from the result row
passage_tags = getattr(row, "tags", []) or []
# Build metadata
metadata = {}
if tags:
metadata["applied_tags"] = tags
# Create a passage with minimal fields - embeddings are not returned from Turbopuffer
passage = PydanticPassage(
@@ -307,7 +322,8 @@ class TurbopufferClient:
organization_id=getattr(row, "organization_id", None),
archive_id=archive_id, # use the archive_id from the query
created_at=getattr(row, "created_at", None),
metadata_=metadata, # Include tag filters in metadata
metadata_=metadata,
tags=passage_tags, # Set the actual tags from the passage
# Set required fields to empty/default values since we don't store embeddings
embedding=[], # Empty embedding since we don't return it from Turbopuffer
embedding_config=None, # No embedding config needed for retrieved passages

View File

@@ -17,6 +17,7 @@ class PromptGenerator:
timezone: str,
previous_message_count: int = 0,
archival_memory_size: Optional[int] = 0,
archive_tags: Optional[List[str]] = None,
) -> str:
"""
Generate a memory metadata block for the agent's system prompt.
@@ -31,6 +32,7 @@ class PromptGenerator:
timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles')
previous_message_count: Number of messages in recall memory (conversation history)
archival_memory_size: Number of items in archival memory (long-term storage)
archive_tags: List of unique tags available in archival memory
Returns:
A formatted string containing the memory metadata block with XML-style tags
@@ -41,6 +43,7 @@ class PromptGenerator:
- Memory blocks were last modified: 2024-01-15 09:00 AM PST
- 42 previous messages between you and the user are stored in recall memory (use tools to access them)
- 156 total memories you created are stored in archival memory (use tools to access them)
- Available archival memory tags: project_x, meeting_notes, research, ideas
</memory_metadata>
"""
# Put the timestamp in the local timezone (mimicking get_local_time())
@@ -60,6 +63,10 @@ class PromptGenerator:
f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)"
)
# Include archive tags if available
if archive_tags:
metadata_lines.append(f"- Available archival memory tags: {', '.join(archive_tags)}")
metadata_lines.append("</memory_metadata>")
memory_metadata_block = "\n".join(metadata_lines)
return memory_metadata_block
@@ -90,6 +97,7 @@ class PromptGenerator:
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
previous_message_count: int = 0,
archival_memory_size: int = 0,
archive_tags: Optional[List[str]] = None,
) -> str:
"""Prepare the final/full system message that will be fed into the LLM API
@@ -114,6 +122,7 @@ class PromptGenerator:
previous_message_count=previous_message_count,
archival_memory_size=archival_memory_size,
timezone=timezone,
archive_tags=archive_tags,
)
full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string

View File

@@ -146,6 +146,37 @@ class ArchiveManager:
session.add(archives_agents)
await session.commit()
@enforce_types
async def get_default_archive_for_agent_async(
    self,
    agent_id: str,
    actor: Optional[PydanticUser] = None,
) -> Optional[PydanticArchive]:
    """Get the agent's default archive if it exists, return None otherwise.

    Args:
        agent_id: ID of the agent whose default archive is being looked up.
        actor: User performing the lookup; forwarded to the underlying
            manager calls for access control.

    Returns:
        The agent's single attached archive, or None when the agent has
        no archives attached.

    Raises:
        ValueError: If the agent has more than one attached archive
            (multiple archives per agent are not yet supported).
    """
    # First check if agent has any archives
    # Imported locally — presumably to avoid a circular import between
    # ArchiveManager and AgentManager (TODO confirm).
    from letta.services.agent_manager import AgentManager
    agent_manager = AgentManager()
    archive_ids = await agent_manager.get_agent_archive_ids_async(
        agent_id=agent_id,
        actor=actor,
    )
    if archive_ids:
        # TODO: Remove this check once we support multiple archives per agent
        if len(archive_ids) > 1:
            raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported")
        # Get the archive
        archive = await self.get_archive_by_id_async(
            archive_id=archive_ids[0],
            actor=actor,
        )
        return archive
    # No archive found, return None
    return None
@enforce_types
async def get_or_create_default_archive_for_agent_async(
self,

View File

@@ -1,5 +1,6 @@
import math
from typing import Any, Dict, Optional
from typing import Any, Dict, Literal, Optional
from zoneinfo import ZoneInfo
from letta.constants import (
CORE_MEMORY_LINE_NUMBER_WARNING,
@@ -9,6 +10,7 @@ from letta.constants import (
)
from letta.helpers.json_helpers import json_dumps
from letta.schemas.agent import AgentState
from letta.schemas.enums import TagMatchMode
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
@@ -117,57 +119,78 @@ class LettaCoreToolExecutor(ToolExecutor):
return results_str
async def archival_memory_search(
self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0, start: Optional[int] = 0
self,
agent_state: AgentState,
actor: User,
query: str,
tags: Optional[list[str]] = None,
tag_match_mode: Literal["any", "all"] = "any",
top_k: Optional[int] = None,
) -> Optional[str]:
"""
Search archival memory using semantic (embedding-based) search.
Args:
query (str): String to search for.
page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
start (Optional[int]): Starting index for the search results. Defaults to 0.
query (str): String to search for using semantic similarity.
tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned.
tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any".
top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified.
Returns:
str: Query result string
str: Query result string containing matching passages with timestamps, content, and tags.
"""
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
except:
raise ValueError("'page' argument must be an integer")
# Convert string to TagMatchMode enum
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
try:
# Get results using passage manager
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
all_results = await self.agent_manager.query_agent_passages_async(
actor=actor,
agent_id=agent_state.id,
query_text=query,
limit=count + start, # Request enough results to handle offset
limit=limit,
embedding_config=agent_state.embedding_config,
embed_query=True,
tags=tags,
tag_match_mode=tag_mode,
)
# Apply pagination
end = min(count + start, len(all_results))
paged_results = all_results[start:end]
# Format results to include tags with friendly timestamps
formatted_results = []
for result in all_results:
# Format timestamp in agent's timezone if available
timestamp = result.created_at
if timestamp and agent_state.timezone:
try:
# Convert to agent's timezone
tz = ZoneInfo(agent_state.timezone)
local_time = timestamp.astimezone(tz)
# Format as ISO string with timezone
formatted_timestamp = local_time.isoformat()
except Exception:
# Fallback to ISO format if timezone conversion fails
formatted_timestamp = str(timestamp)
else:
# Use ISO format if no timezone is set
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
# Format results to match previous implementation
formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results]
formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})
return formatted_results, len(formatted_results)
except Exception as e:
raise e
async def archival_memory_insert(self, agent_state: AgentState, actor: User, content: str) -> Optional[str]:
async def archival_memory_insert(
self, agent_state: AgentState, actor: User, content: str, tags: Optional[list[str]] = None
) -> Optional[str]:
"""
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
Args:
content (str): Content to write to the memory. All unicode (including emojis) are supported.
tags (Optional[list[str]]): Optional list of tags to associate with this memory for better organization and filtering.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
@@ -176,6 +199,7 @@ class LettaCoreToolExecutor(ToolExecutor):
agent_state=agent_state,
text=content,
actor=actor,
tags=tags,
)
await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True)
return None

View File

@@ -1224,6 +1224,91 @@ def test_preview_payload(client: LettaSDKClient):
client.agents.delete(agent_id=temp_agent.id)
# TODO: Re-enable
# def test_archive_tags_in_system_prompt(client: LettaSDKClient):
# """Test that archive tags are correctly compiled into the system prompt."""
# # Create a test agent
# temp_agent = client.agents.create(
# memory_blocks=[
# CreateBlock(
# label="human",
# value="username: test_user",
# ),
# ],
# model="openai/gpt-4o-mini",
# embedding="openai/text-embedding-3-small",
# )
#
# try:
# # Add passages with different tags to the agent's archive
# test_tags = ["project_alpha", "meeting_notes", "research", "ideas", "todo_items"]
#
# # Create passages with tags
# for i, tag in enumerate(test_tags):
# client.agents.passages.create(
# agent_id=temp_agent.id,
# text=f"Test passage {i} with tag {tag}",
# tags=[tag]
# )
#
# # Also create a passage with multiple tags
# client.agents.passages.create(
# agent_id=temp_agent.id,
# text="Passage with multiple tags",
# tags=["multi_tag_1", "multi_tag_2"]
# )
#
# # Get the raw payload to check the system prompt
# payload = client.agents.messages.preview_raw_payload(
# agent_id=temp_agent.id,
# request=LettaRequest(
# messages=[
# MessageCreate(
# role="user",
# content=[
# TextContent(
# text="Hello",
# )
# ],
# )
# ],
# ),
# )
#
# # Extract the system message
# assert isinstance(payload, dict)
# assert "messages" in payload
# assert len(payload["messages"]) > 0
#
# system_message = payload["messages"][0]
# assert system_message["role"] == "system"
# system_content = system_message["content"]
#
# # Check that the archive tags are included in the metadata
# assert "Available archival memory tags:" in system_content
#
# # Check that all unique tags are present
# all_unique_tags = set(test_tags + ["multi_tag_1", "multi_tag_2"])
# for tag in all_unique_tags:
# assert tag in system_content, f"Tag '{tag}' not found in system prompt"
#
# # Verify the tags are in the memory_metadata section
# assert "<memory_metadata>" in system_content
# assert "</memory_metadata>" in system_content
#
# # Extract the metadata section to verify format
# metadata_start = system_content.index("<memory_metadata>")
# metadata_end = system_content.index("</memory_metadata>")
# metadata_section = system_content[metadata_start:metadata_end]
#
# # Verify the tags line is properly formatted
# assert "- Available archival memory tags:" in metadata_section
#
# finally:
# # Clean up the agent
# client.agents.delete(agent_id=temp_agent.id)
def test_agent_tools_list(client: LettaSDKClient):
"""Test the optimized agent tools list endpoint for correctness."""
# Create a test agent