feat: Allow agent archival tools to insert/search with tags [LET-4072] (#4300)
* Finish modifying archival memory tools * Add tags * Add disabled test
This commit is contained in:
@@ -102,6 +102,23 @@ class BaseAgent(ABC):
|
||||
if tool_rules_solver is not None:
|
||||
tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
|
||||
|
||||
# compile archive tags if there's an attached archive
|
||||
from letta.services.archive_manager import ArchiveManager
|
||||
|
||||
archive_manager = ArchiveManager()
|
||||
archive = await archive_manager.get_default_archive_for_agent_async(
|
||||
agent_id=agent_state.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
if archive:
|
||||
archive_tags = await self.passage_manager.get_unique_tags_for_archive_async(
|
||||
archive_id=archive.id,
|
||||
actor=self.actor,
|
||||
)
|
||||
else:
|
||||
archive_tags = None
|
||||
|
||||
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
|
||||
curr_system_message = in_context_messages[0]
|
||||
curr_system_message_text = curr_system_message.content[0].text
|
||||
@@ -149,6 +166,7 @@ class BaseAgent(ABC):
|
||||
timezone=agent_state.timezone,
|
||||
previous_message_count=num_messages - len(in_context_messages),
|
||||
archival_memory_size=num_archival_memories,
|
||||
archive_tags=archive_tags,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_text, new_system_message_str)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
|
||||
@@ -63,70 +63,36 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
|
||||
return results_str
|
||||
|
||||
|
||||
async def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
|
||||
async def archival_memory_insert(self: "Agent", content: str, tags: Optional[list[str]] = None) -> Optional[str]:
|
||||
"""
|
||||
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
|
||||
|
||||
Args:
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
tags (Optional[list[str]]): Optional list of tags to associate with this memory for better organization and filtering.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
await self.passage_manager.insert_passage(
|
||||
agent_state=self.agent_state,
|
||||
text=content,
|
||||
actor=self.user,
|
||||
)
|
||||
self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user, force=True)
|
||||
return None
|
||||
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
|
||||
|
||||
|
||||
async def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]:
|
||||
async def archival_memory_search(
|
||||
self: "Agent", query: str, tags: Optional[list[str]] = None, tag_match_mode: Literal["any", "all"] = "any", top_k: Optional[int] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search.
|
||||
|
||||
Args:
|
||||
query (str): String to search for.
|
||||
page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
|
||||
start (Optional[int]): Starting index for the search results. Defaults to 0.
|
||||
query (str): String to search for using semantic similarity.
|
||||
tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned.
|
||||
tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any".
|
||||
top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified.
|
||||
|
||||
Returns:
|
||||
str: Query result string
|
||||
str: Query result string containing matching passages with timestamps and content.
|
||||
"""
|
||||
|
||||
from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
|
||||
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
|
||||
page = 0
|
||||
try:
|
||||
page = int(page)
|
||||
except:
|
||||
raise ValueError("'page' argument must be an integer")
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
|
||||
try:
|
||||
# Get results using passage manager
|
||||
all_results = await self.agent_manager.query_agent_passages_async(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
query_text=query,
|
||||
limit=count + start, # Request enough results to handle offset
|
||||
embedding_config=self.agent_state.embedding_config,
|
||||
embed_query=True,
|
||||
)
|
||||
|
||||
# Apply pagination
|
||||
end = min(count + start, len(all_results))
|
||||
paged_results = all_results[start:end]
|
||||
|
||||
# Format results to match previous implementation
|
||||
formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results]
|
||||
|
||||
return formatted_results, len(formatted_results)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
raise NotImplementedError("This should never be invoked directly. Contact Letta if you see this error message.")
|
||||
|
||||
|
||||
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore
|
||||
|
||||
@@ -95,11 +95,9 @@ class TurbopufferClient:
|
||||
organization_ids = []
|
||||
archive_ids = []
|
||||
created_ats = []
|
||||
tags_arrays = [] # Store tags as arrays
|
||||
passages = []
|
||||
|
||||
# prepare tag columns
|
||||
tag_columns = {tag: [] for tag in (tags or [])}
|
||||
|
||||
for idx, (text, embedding) in enumerate(zip(text_chunks, embeddings)):
|
||||
passage_id = passage_ids[idx]
|
||||
|
||||
@@ -110,10 +108,7 @@ class TurbopufferClient:
|
||||
organization_ids.append(organization_id)
|
||||
archive_ids.append(archive_id)
|
||||
created_ats.append(timestamp)
|
||||
|
||||
# append tag values
|
||||
for tag in tag_columns:
|
||||
tag_columns[tag].append(True)
|
||||
tags_arrays.append(tags or []) # Store tags as array
|
||||
|
||||
# Create PydanticPassage object
|
||||
passage = PydanticPassage(
|
||||
@@ -123,6 +118,7 @@ class TurbopufferClient:
|
||||
archive_id=archive_id,
|
||||
created_at=timestamp,
|
||||
metadata_={},
|
||||
tags=tags or [], # Include tags in the passage
|
||||
embedding=embedding,
|
||||
embedding_config=None, # Will be set by caller if needed
|
||||
)
|
||||
@@ -136,11 +132,9 @@ class TurbopufferClient:
|
||||
"organization_id": organization_ids,
|
||||
"archive_id": archive_ids,
|
||||
"created_at": created_ats,
|
||||
"tags": tags_arrays, # Add tags as array column
|
||||
}
|
||||
|
||||
# add tag columns if any
|
||||
upsert_columns.update(tag_columns)
|
||||
|
||||
try:
|
||||
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
@@ -193,16 +187,21 @@ class TurbopufferClient:
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
from turbopuffer.types import QueryParam
|
||||
|
||||
# validate inputs based on search mode
|
||||
if search_mode == "vector" and query_embedding is None:
|
||||
raise ValueError("query_embedding is required for vector search mode")
|
||||
if search_mode == "fts" and query_text is None:
|
||||
raise ValueError("query_text is required for FTS search mode")
|
||||
if search_mode == "hybrid":
|
||||
if query_embedding is None or query_text is None:
|
||||
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
|
||||
if search_mode not in ["vector", "fts", "hybrid"]:
|
||||
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', or 'hybrid'")
|
||||
# Check if we should fallback to timestamp-based retrieval
|
||||
if query_embedding is None and query_text is None:
|
||||
# Fallback to retrieving most recent passages when no search query is provided
|
||||
search_mode = "timestamp"
|
||||
else:
|
||||
# validate inputs based on search mode
|
||||
if search_mode == "vector" and query_embedding is None:
|
||||
raise ValueError("query_embedding is required for vector search mode")
|
||||
if search_mode == "fts" and query_text is None:
|
||||
raise ValueError("query_text is required for FTS search mode")
|
||||
if search_mode == "hybrid":
|
||||
if query_embedding is None or query_text is None:
|
||||
raise ValueError("Both query_embedding and query_text are required for hybrid search mode")
|
||||
if search_mode not in ["vector", "fts", "hybrid"]:
|
||||
raise ValueError(f"Invalid search_mode: {search_mode}. Must be 'vector', 'fts', or 'hybrid'")
|
||||
|
||||
namespace_name = self._get_namespace_name(archive_id)
|
||||
|
||||
@@ -213,23 +212,38 @@ class TurbopufferClient:
|
||||
# build tag filter conditions
|
||||
tag_filter = None
|
||||
if tags:
|
||||
tag_conditions = []
|
||||
for tag in tags:
|
||||
tag_conditions.append((tag, "Eq", True))
|
||||
|
||||
if len(tag_conditions) == 1:
|
||||
tag_filter = tag_conditions[0]
|
||||
elif tag_match_mode == TagMatchMode.ALL:
|
||||
tag_filter = ("And", tag_conditions)
|
||||
if tag_match_mode == TagMatchMode.ALL:
|
||||
# For ALL mode, need to check each tag individually with Contains
|
||||
tag_conditions = []
|
||||
for tag in tags:
|
||||
tag_conditions.append(("tags", "Contains", tag))
|
||||
if len(tag_conditions) == 1:
|
||||
tag_filter = tag_conditions[0]
|
||||
else:
|
||||
tag_filter = ("And", tag_conditions)
|
||||
else: # tag_match_mode == TagMatchMode.ANY
|
||||
tag_filter = ("Or", tag_conditions)
|
||||
# For ANY mode, use ContainsAny to match any of the tags
|
||||
tag_filter = ("tags", "ContainsAny", tags)
|
||||
|
||||
if search_mode == "vector":
|
||||
if search_mode == "timestamp":
|
||||
# Fallback: retrieve most recent passages by timestamp
|
||||
query_params = {
|
||||
"rank_by": ("created_at", "desc"), # Order by created_at in descending order
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
query_params["filters"] = tag_filter
|
||||
|
||||
result = await namespace.query(**query_params)
|
||||
return self._process_single_query_results(result, archive_id, tags)
|
||||
|
||||
elif search_mode == "vector":
|
||||
# single vector search query
|
||||
query_params = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
query_params["filters"] = tag_filter
|
||||
@@ -242,7 +256,7 @@ class TurbopufferClient:
|
||||
query_params = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
query_params["filters"] = tag_filter
|
||||
@@ -258,7 +272,7 @@ class TurbopufferClient:
|
||||
vector_query = {
|
||||
"rank_by": ("vector", "ANN", query_embedding),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
vector_query["filters"] = tag_filter
|
||||
@@ -268,7 +282,7 @@ class TurbopufferClient:
|
||||
fts_query = {
|
||||
"rank_by": ("text", "BM25", query_text),
|
||||
"top_k": top_k,
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at"],
|
||||
"include_attributes": ["text", "organization_id", "archive_id", "created_at", "tags"],
|
||||
}
|
||||
if tag_filter:
|
||||
fts_query["filters"] = tag_filter
|
||||
@@ -295,10 +309,11 @@ class TurbopufferClient:
|
||||
passages_with_scores = []
|
||||
|
||||
for row in result.rows:
|
||||
# Build metadata including any tag filters that were applied
|
||||
# Extract tags from the result row
|
||||
passage_tags = getattr(row, "tags", []) or []
|
||||
|
||||
# Build metadata
|
||||
metadata = {}
|
||||
if tags:
|
||||
metadata["applied_tags"] = tags
|
||||
|
||||
# Create a passage with minimal fields - embeddings are not returned from Turbopuffer
|
||||
passage = PydanticPassage(
|
||||
@@ -307,7 +322,8 @@ class TurbopufferClient:
|
||||
organization_id=getattr(row, "organization_id", None),
|
||||
archive_id=archive_id, # use the archive_id from the query
|
||||
created_at=getattr(row, "created_at", None),
|
||||
metadata_=metadata, # Include tag filters in metadata
|
||||
metadata_=metadata,
|
||||
tags=passage_tags, # Set the actual tags from the passage
|
||||
# Set required fields to empty/default values since we don't store embeddings
|
||||
embedding=[], # Empty embedding since we don't return it from Turbopuffer
|
||||
embedding_config=None, # No embedding config needed for retrieved passages
|
||||
|
||||
@@ -17,6 +17,7 @@ class PromptGenerator:
|
||||
timezone: str,
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: Optional[int] = 0,
|
||||
archive_tags: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a memory metadata block for the agent's system prompt.
|
||||
@@ -31,6 +32,7 @@ class PromptGenerator:
|
||||
timezone: The timezone to use for formatting timestamps (e.g., 'America/Los_Angeles')
|
||||
previous_message_count: Number of messages in recall memory (conversation history)
|
||||
archival_memory_size: Number of items in archival memory (long-term storage)
|
||||
archive_tags: List of unique tags available in archival memory
|
||||
|
||||
Returns:
|
||||
A formatted string containing the memory metadata block with XML-style tags
|
||||
@@ -41,6 +43,7 @@ class PromptGenerator:
|
||||
- Memory blocks were last modified: 2024-01-15 09:00 AM PST
|
||||
- 42 previous messages between you and the user are stored in recall memory (use tools to access them)
|
||||
- 156 total memories you created are stored in archival memory (use tools to access them)
|
||||
- Available archival memory tags: project_x, meeting_notes, research, ideas
|
||||
</memory_metadata>
|
||||
"""
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
@@ -60,6 +63,10 @@ class PromptGenerator:
|
||||
f"- {archival_memory_size} total memories you created are stored in archival memory (use tools to access them)"
|
||||
)
|
||||
|
||||
# Include archive tags if available
|
||||
if archive_tags:
|
||||
metadata_lines.append(f"- Available archival memory tags: {', '.join(archive_tags)}")
|
||||
|
||||
metadata_lines.append("</memory_metadata>")
|
||||
memory_metadata_block = "\n".join(metadata_lines)
|
||||
return memory_metadata_block
|
||||
@@ -90,6 +97,7 @@ class PromptGenerator:
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
archive_tags: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""Prepare the final/full system message that will be fed into the LLM API
|
||||
|
||||
@@ -114,6 +122,7 @@ class PromptGenerator:
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
timezone=timezone,
|
||||
archive_tags=archive_tags,
|
||||
)
|
||||
|
||||
full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
|
||||
|
||||
@@ -146,6 +146,37 @@ class ArchiveManager:
|
||||
session.add(archives_agents)
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
async def get_default_archive_for_agent_async(
    self,
    agent_id: str,
    actor: PydanticUser = None,
) -> Optional[PydanticArchive]:
    """Return the single archive attached to an agent, or None if it has none.

    Unlike a get-or-create lookup, this never creates an archive.

    Args:
        agent_id: ID of the agent whose archive is looked up.
        actor: User forwarded unchanged to the manager calls below.

    Returns:
        The archive fetched for the agent's only archive ID, or None when
        the agent has no archive IDs.

    Raises:
        ValueError: If the agent has more than one archive attached.
    """
    # Imported inside the function, as the original does
    # (presumably to avoid a circular import -- confirm).
    from letta.services.agent_manager import AgentManager

    attached_ids = await AgentManager().get_agent_archive_ids_async(
        agent_id=agent_id,
        actor=actor,
    )

    # No archive attached: nothing to return.
    if not attached_ids:
        return None

    # TODO: Remove this check once we support multiple archives per agent
    if len(attached_ids) > 1:
        raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported")

    # Exactly one archive ID remains; fetch and return it.
    return await self.get_archive_by_id_async(
        archive_id=attached_ids[0],
        actor=actor,
    )
|
||||
|
||||
@enforce_types
|
||||
async def get_or_create_default_archive_for_agent_async(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from letta.constants import (
|
||||
CORE_MEMORY_LINE_NUMBER_WARNING,
|
||||
@@ -9,6 +10,7 @@ from letta.constants import (
|
||||
)
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import TagMatchMode
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
@@ -117,57 +119,78 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
return results_str
|
||||
|
||||
async def archival_memory_search(
|
||||
self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0, start: Optional[int] = 0
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
query: str,
|
||||
tags: Optional[list[str]] = None,
|
||||
tag_match_mode: Literal["any", "all"] = "any",
|
||||
top_k: Optional[int] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search.
|
||||
|
||||
Args:
|
||||
query (str): String to search for.
|
||||
page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
|
||||
start (Optional[int]): Starting index for the search results. Defaults to 0.
|
||||
query (str): String to search for using semantic similarity.
|
||||
tags (Optional[list[str]]): Optional list of tags to filter search results. Only passages with these tags will be returned.
|
||||
tag_match_mode (Literal["any", "all"]): How to match tags - "any" to match passages with any of the tags, "all" to match only passages with all tags. Defaults to "any".
|
||||
top_k (Optional[int]): Maximum number of results to return. Uses system default if not specified.
|
||||
|
||||
Returns:
|
||||
str: Query result string
|
||||
str: Query result string containing matching passages with timestamps, content, and tags.
|
||||
"""
|
||||
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
|
||||
page = 0
|
||||
try:
|
||||
page = int(page)
|
||||
except:
|
||||
raise ValueError("'page' argument must be an integer")
|
||||
# Convert string to TagMatchMode enum
|
||||
tag_mode = TagMatchMode.ANY if tag_match_mode == "any" else TagMatchMode.ALL
|
||||
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
|
||||
try:
|
||||
# Get results using passage manager
|
||||
limit = top_k if top_k is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
all_results = await self.agent_manager.query_agent_passages_async(
|
||||
actor=actor,
|
||||
agent_id=agent_state.id,
|
||||
query_text=query,
|
||||
limit=count + start, # Request enough results to handle offset
|
||||
limit=limit,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
embed_query=True,
|
||||
tags=tags,
|
||||
tag_match_mode=tag_mode,
|
||||
)
|
||||
|
||||
# Apply pagination
|
||||
end = min(count + start, len(all_results))
|
||||
paged_results = all_results[start:end]
|
||||
# Format results to include tags with friendly timestamps
|
||||
formatted_results = []
|
||||
for result in all_results:
|
||||
# Format timestamp in agent's timezone if available
|
||||
timestamp = result.created_at
|
||||
if timestamp and agent_state.timezone:
|
||||
try:
|
||||
# Convert to agent's timezone
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
local_time = timestamp.astimezone(tz)
|
||||
# Format as ISO string with timezone
|
||||
formatted_timestamp = local_time.isoformat()
|
||||
except Exception:
|
||||
# Fallback to ISO format if timezone conversion fails
|
||||
formatted_timestamp = str(timestamp)
|
||||
else:
|
||||
# Use ISO format if no timezone is set
|
||||
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
|
||||
|
||||
# Format results to match previous implementation
|
||||
formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results]
|
||||
formatted_results.append({"timestamp": formatted_timestamp, "content": result.text, "tags": result.tags or []})
|
||||
|
||||
return formatted_results, len(formatted_results)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def archival_memory_insert(self, agent_state: AgentState, actor: User, content: str) -> Optional[str]:
|
||||
async def archival_memory_insert(
|
||||
self, agent_state: AgentState, actor: User, content: str, tags: Optional[list[str]] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
|
||||
|
||||
Args:
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
tags (Optional[list[str]]): Optional list of tags to associate with this memory for better organization and filtering.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
@@ -176,6 +199,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
agent_state=agent_state,
|
||||
text=content,
|
||||
actor=actor,
|
||||
tags=tags,
|
||||
)
|
||||
await self.agent_manager.rebuild_system_prompt_async(agent_id=agent_state.id, actor=actor, force=True)
|
||||
return None
|
||||
|
||||
@@ -1224,6 +1224,91 @@ def test_preview_payload(client: LettaSDKClient):
|
||||
client.agents.delete(agent_id=temp_agent.id)
|
||||
|
||||
|
||||
# TODO: Re-enable
|
||||
# def test_archive_tags_in_system_prompt(client: LettaSDKClient):
|
||||
# """Test that archive tags are correctly compiled into the system prompt."""
|
||||
# # Create a test agent
|
||||
# temp_agent = client.agents.create(
|
||||
# memory_blocks=[
|
||||
# CreateBlock(
|
||||
# label="human",
|
||||
# value="username: test_user",
|
||||
# ),
|
||||
# ],
|
||||
# model="openai/gpt-4o-mini",
|
||||
# embedding="openai/text-embedding-3-small",
|
||||
# )
|
||||
#
|
||||
# try:
|
||||
# # Add passages with different tags to the agent's archive
|
||||
# test_tags = ["project_alpha", "meeting_notes", "research", "ideas", "todo_items"]
|
||||
#
|
||||
# # Create passages with tags
|
||||
# for i, tag in enumerate(test_tags):
|
||||
# client.agents.passages.create(
|
||||
# agent_id=temp_agent.id,
|
||||
# text=f"Test passage {i} with tag {tag}",
|
||||
# tags=[tag]
|
||||
# )
|
||||
#
|
||||
# # Also create a passage with multiple tags
|
||||
# client.agents.passages.create(
|
||||
# agent_id=temp_agent.id,
|
||||
# text="Passage with multiple tags",
|
||||
# tags=["multi_tag_1", "multi_tag_2"]
|
||||
# )
|
||||
#
|
||||
# # Get the raw payload to check the system prompt
|
||||
# payload = client.agents.messages.preview_raw_payload(
|
||||
# agent_id=temp_agent.id,
|
||||
# request=LettaRequest(
|
||||
# messages=[
|
||||
# MessageCreate(
|
||||
# role="user",
|
||||
# content=[
|
||||
# TextContent(
|
||||
# text="Hello",
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
# ],
|
||||
# ),
|
||||
# )
|
||||
#
|
||||
# # Extract the system message
|
||||
# assert isinstance(payload, dict)
|
||||
# assert "messages" in payload
|
||||
# assert len(payload["messages"]) > 0
|
||||
#
|
||||
# system_message = payload["messages"][0]
|
||||
# assert system_message["role"] == "system"
|
||||
# system_content = system_message["content"]
|
||||
#
|
||||
# # Check that the archive tags are included in the metadata
|
||||
# assert "Available archival memory tags:" in system_content
|
||||
#
|
||||
# # Check that all unique tags are present
|
||||
# all_unique_tags = set(test_tags + ["multi_tag_1", "multi_tag_2"])
|
||||
# for tag in all_unique_tags:
|
||||
# assert tag in system_content, f"Tag '{tag}' not found in system prompt"
|
||||
#
|
||||
# # Verify the tags are in the memory_metadata section
|
||||
# assert "<memory_metadata>" in system_content
|
||||
# assert "</memory_metadata>" in system_content
|
||||
#
|
||||
# # Extract the metadata section to verify format
|
||||
# metadata_start = system_content.index("<memory_metadata>")
|
||||
# metadata_end = system_content.index("</memory_metadata>")
|
||||
# metadata_section = system_content[metadata_start:metadata_end]
|
||||
#
|
||||
# # Verify the tags line is properly formatted
|
||||
# assert "- Available archival memory tags:" in metadata_section
|
||||
#
|
||||
# finally:
|
||||
# # Clean up the agent
|
||||
# client.agents.delete(agent_id=temp_agent.id)
|
||||
|
||||
|
||||
def test_agent_tools_list(client: LettaSDKClient):
|
||||
"""Test the optimized agent tools list endpoint for correctness."""
|
||||
# Create a test agent
|
||||
|
||||
Reference in New Issue
Block a user