feat: Modify conversation search tool to be hybrid (#4362)

* Modify conversation search functionality

* Gate the roles
This commit is contained in:
Matthew Zhou
2025-09-02 13:45:53 -07:00
committed by GitHub
parent 6b625cb039
commit 8f425aa024
3 changed files with 153 additions and 53 deletions

View File

@@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import List, Literal, Optional
from letta.agent import Agent
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
@@ -20,16 +20,39 @@ def send_message(self: "Agent", message: str) -> Optional[str]:
return None
def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
def conversation_search(
self: "Agent",
query: str,
roles: Optional[List[Literal["assistant", "user", "tool"]]] = None,
limit: Optional[int] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> Optional[str]:
"""
Search prior conversation history using case-insensitive string matching.
Search prior conversation history using hybrid search (text + semantic similarity).
Args:
query (str): String to search for.
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
query (str): String to search for using both text matching and semantic similarity.
roles (Optional[List[Literal["assistant", "user", "tool"]]]): Optional list of message roles to filter by.
limit (Optional[int]): Maximum number of results to return. Uses system default if not specified.
start_date (Optional[str]): Filter results to messages created after this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30".
end_date (Optional[str]): Filter results to messages created before this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00".
Examples:
# Search all messages
conversation_search(query="project updates")
# Search only assistant messages
conversation_search(query="error handling", roles=["assistant"])
# Search with date range
conversation_search(query="meetings", start_date="2024-01-15", end_date="2024-01-20")
# Search with limit
conversation_search(query="debugging", limit=10)
Returns:
str: Query result string
str: Query result string containing matching messages with timestamps and content.
"""
import math

View File

@@ -34,12 +34,18 @@ class MessageManager:
def _extract_message_text(self, message: PydanticMessage) -> str:
"""Extract text content from a message's complex content structure.
Only extracts text from searchable message roles (assistant, user, tool).
Args:
message: The message to extract text from
Returns:
Concatenated text content from the message
Concatenated text content from the message, or empty string for non-searchable roles
"""
# only extract text from searchable roles
if message.role not in [MessageRole.assistant, MessageRole.user, MessageRole.tool]:
return ""
if not message.content:
return ""
@@ -218,7 +224,7 @@ class MessageManager:
for msg in result:
text = self._extract_message_text(msg)
if text: # only embed messages with text content
if text: # only embed messages with text content (role filtering is handled in _extract_message_text)
message_texts.append(text)
message_ids.append(msg.id)
roles.append(msg.role)
@@ -228,15 +234,8 @@ class MessageManager:
# generate embeddings using provided config
from letta.llm_api.llm_client import LLMClient
# extract provider info from embedding_config
embedding_provider = embedding_config.get("provider", "openai")
embedding_api_key = embedding_config.get("api_key")
embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1")
embedding_client = LLMClient(
llm_provider_type=embedding_provider,
api_key=embedding_api_key,
endpoint=embedding_endpoint,
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
@@ -816,14 +815,8 @@ class MessageManager:
# generate embedding from query text
from letta.llm_api.llm_client import LLMClient
embedding_provider = embedding_config.get("provider", "openai")
embedding_api_key = embedding_config.get("api_key")
embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1")
embedding_client = LLMClient(
llm_provider_type=embedding_provider,
api_key=embedding_api_key,
endpoint=embedding_endpoint,
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)

View File

@@ -1,5 +1,6 @@
import math
from typing import Any, Dict, Literal, Optional
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from zoneinfo import ZoneInfo
from letta.constants import (
@@ -10,7 +11,7 @@ from letta.constants import (
)
from letta.helpers.json_helpers import json_dumps
from letta.schemas.agent import AgentState
from letta.schemas.enums import TagMatchMode
from letta.schemas.enums import MessageRole, TagMatchMode
from letta.schemas.sandbox_config import SandboxConfig
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
@@ -80,43 +81,126 @@ class LettaCoreToolExecutor(ToolExecutor):
"""
return "Sent message successfully."
async def conversation_search(self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0) -> Optional[str]:
async def conversation_search(
self,
agent_state: AgentState,
actor: User,
query: str,
roles: Optional[List[Literal["assistant", "user", "tool"]]] = None,
limit: Optional[int] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
) -> Optional[str]:
"""
Search prior conversation history using case-insensitive string matching.
Search prior conversation history using hybrid search (text + semantic similarity).
Args:
query (str): String to search for.
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
query (str): String to search for using both text matching and semantic similarity.
roles (Optional[List[Literal["assistant", "user", "tool"]]]): Optional list of message roles to filter by.
limit (Optional[int]): Maximum number of results to return. Uses system default if not specified.
start_date (Optional[str]): Filter results to messages created after this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30".
end_date (Optional[str]): Filter results to messages created before this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00".
Returns:
str: Query result string
str: Query result string containing matching messages with timestamps and content.
"""
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
except:
raise ValueError("'page' argument must be an integer")
# Parse datetime parameters if provided
start_datetime = None
end_datetime = None
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
messages = await MessageManager().list_user_messages_for_agent_async(
agent_id=agent_state.id,
actor=actor,
query_text=query,
limit=count,
)
if start_date:
try:
# Try parsing as full datetime first (with time)
start_datetime = datetime.fromisoformat(start_date)
except ValueError:
try:
# Fall back to date-only format
start_datetime = datetime.strptime(start_date, "%Y-%m-%d")
# Set to beginning of day
start_datetime = start_datetime.replace(hour=0, minute=0, second=0, microsecond=0)
except ValueError:
raise ValueError(f"Invalid start_date format: {start_date}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)")
total = len(messages)
num_pages = math.ceil(total / count) - 1 # 0 index
# Apply agent's timezone if datetime is naive
if start_datetime.tzinfo is None and agent_state.timezone:
tz = ZoneInfo(agent_state.timezone)
start_datetime = start_datetime.replace(tzinfo=tz)
if len(messages) == 0:
results_str = "No results found."
else:
results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):"
results_formatted = [message.content[0].text for message in messages]
results_str = f"{results_pref} {json_dumps(results_formatted)}"
if end_date:
try:
# Try parsing as full datetime first (with time)
end_datetime = datetime.fromisoformat(end_date)
except ValueError:
try:
# Fall back to date-only format
end_datetime = datetime.strptime(end_date, "%Y-%m-%d")
# Set to end of day for end dates
end_datetime = end_datetime.replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise ValueError(f"Invalid end_date format: {end_date}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)")
return results_str
# Apply agent's timezone if datetime is naive
if end_datetime.tzinfo is None and agent_state.timezone:
tz = ZoneInfo(agent_state.timezone)
end_datetime = end_datetime.replace(tzinfo=tz)
# Convert string roles to MessageRole enum if provided
message_roles = None
if roles:
message_roles = [MessageRole(role) for role in roles]
# Use provided limit or default
search_limit = limit if limit is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
# Search using the message manager's search_messages_async method
messages = await self.message_manager.search_messages_async(
agent_id=agent_state.id,
actor=actor,
query_text=query,
roles=message_roles,
limit=search_limit,
start_date=start_datetime,
end_date=end_datetime,
embedding_config=agent_state.embedding_config,
)
if len(messages) == 0:
results_str = "No results found."
else:
results_pref = f"Showing {len(messages)} results:"
results_formatted = []
for message in messages:
# Format timestamp in agent's timezone if available
timestamp = message.created_at
if timestamp and agent_state.timezone:
try:
# Convert to agent's timezone
tz = ZoneInfo(agent_state.timezone)
local_time = timestamp.astimezone(tz)
# Format as ISO string with timezone
formatted_timestamp = local_time.isoformat()
except Exception:
# Fallback to ISO format if timezone conversion fails
formatted_timestamp = str(timestamp)
else:
# Use ISO format if no timezone is set
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
results_formatted.append(
{
"timestamp": formatted_timestamp,
"role": message.role,
"content": message.content[0].text if message.content else "",
}
)
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str
except Exception as e:
raise e
async def archival_memory_search(
self,