feat: Modify conversation search tool to be hybrid (#4362)
* Modify conversation search functionality * Gate the roles
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING
|
||||
@@ -20,16 +20,39 @@ def send_message(self: "Agent", message: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
def conversation_search(
|
||||
self: "Agent",
|
||||
query: str,
|
||||
roles: Optional[List[Literal["assistant", "user", "tool"]]] = None,
|
||||
limit: Optional[int] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Search prior conversation history using case-insensitive string matching.
|
||||
Search prior conversation history using hybrid search (text + semantic similarity).
|
||||
|
||||
Args:
|
||||
query (str): String to search for.
|
||||
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
|
||||
query (str): String to search for using both text matching and semantic similarity.
|
||||
roles (Optional[List[Literal["assistant", "user", "tool"]]]): Optional list of message roles to filter by.
|
||||
limit (Optional[int]): Maximum number of results to return. Uses system default if not specified.
|
||||
start_date (Optional[str]): Filter results to messages created after this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30".
|
||||
end_date (Optional[str]): Filter results to messages created before this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00".
|
||||
|
||||
Examples:
|
||||
# Search all messages
|
||||
conversation_search(query="project updates")
|
||||
|
||||
# Search only assistant messages
|
||||
conversation_search(query="error handling", roles=["assistant"])
|
||||
|
||||
# Search with date range
|
||||
conversation_search(query="meetings", start_date="2024-01-15", end_date="2024-01-20")
|
||||
|
||||
# Search with limit
|
||||
conversation_search(query="debugging", limit=10)
|
||||
|
||||
Returns:
|
||||
str: Query result string
|
||||
str: Query result string containing matching messages with timestamps and content.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
@@ -34,12 +34,18 @@ class MessageManager:
|
||||
def _extract_message_text(self, message: PydanticMessage) -> str:
|
||||
"""Extract text content from a message's complex content structure.
|
||||
|
||||
Only extracts text from searchable message roles (assistant, user, tool).
|
||||
|
||||
Args:
|
||||
message: The message to extract text from
|
||||
|
||||
Returns:
|
||||
Concatenated text content from the message
|
||||
Concatenated text content from the message, or empty string for non-searchable roles
|
||||
"""
|
||||
# only extract text from searchable roles
|
||||
if message.role not in [MessageRole.assistant, MessageRole.user, MessageRole.tool]:
|
||||
return ""
|
||||
|
||||
if not message.content:
|
||||
return ""
|
||||
|
||||
@@ -218,7 +224,7 @@ class MessageManager:
|
||||
|
||||
for msg in result:
|
||||
text = self._extract_message_text(msg)
|
||||
if text: # only embed messages with text content
|
||||
if text: # only embed messages with text content (role filtering is handled in _extract_message_text)
|
||||
message_texts.append(text)
|
||||
message_ids.append(msg.id)
|
||||
roles.append(msg.role)
|
||||
@@ -228,15 +234,8 @@ class MessageManager:
|
||||
# generate embeddings using provided config
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
# extract provider info from embedding_config
|
||||
embedding_provider = embedding_config.get("provider", "openai")
|
||||
embedding_api_key = embedding_config.get("api_key")
|
||||
embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1")
|
||||
|
||||
embedding_client = LLMClient(
|
||||
llm_provider_type=embedding_provider,
|
||||
api_key=embedding_api_key,
|
||||
endpoint=embedding_endpoint,
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
|
||||
@@ -816,14 +815,8 @@ class MessageManager:
|
||||
# generate embedding from query text
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_provider = embedding_config.get("provider", "openai")
|
||||
embedding_api_key = embedding_config.get("api_key")
|
||||
embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1")
|
||||
|
||||
embedding_client = LLMClient(
|
||||
llm_provider_type=embedding_provider,
|
||||
api_key=embedding_api_key,
|
||||
endpoint=embedding_endpoint,
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings([query_text], embedding_config)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from letta.constants import (
|
||||
@@ -10,7 +11,7 @@ from letta.constants import (
|
||||
)
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import TagMatchMode
|
||||
from letta.schemas.enums import MessageRole, TagMatchMode
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
@@ -80,43 +81,126 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
"""
|
||||
return "Sent message successfully."
|
||||
|
||||
async def conversation_search(self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
async def conversation_search(
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
query: str,
|
||||
roles: Optional[List[Literal["assistant", "user", "tool"]]] = None,
|
||||
limit: Optional[int] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Search prior conversation history using case-insensitive string matching.
|
||||
Search prior conversation history using hybrid search (text + semantic similarity).
|
||||
|
||||
Args:
|
||||
query (str): String to search for.
|
||||
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
|
||||
query (str): String to search for using both text matching and semantic similarity.
|
||||
roles (Optional[List[Literal["assistant", "user", "tool"]]]): Optional list of message roles to filter by.
|
||||
limit (Optional[int]): Maximum number of results to return. Uses system default if not specified.
|
||||
start_date (Optional[str]): Filter results to messages created after this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30".
|
||||
end_date (Optional[str]): Filter results to messages created before this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00".
|
||||
|
||||
Returns:
|
||||
str: Query result string
|
||||
str: Query result string containing matching messages with timestamps and content.
|
||||
"""
|
||||
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
|
||||
page = 0
|
||||
try:
|
||||
page = int(page)
|
||||
except:
|
||||
raise ValueError("'page' argument must be an integer")
|
||||
# Parse datetime parameters if provided
|
||||
start_datetime = None
|
||||
end_datetime = None
|
||||
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
messages = await MessageManager().list_user_messages_for_agent_async(
|
||||
agent_id=agent_state.id,
|
||||
actor=actor,
|
||||
query_text=query,
|
||||
limit=count,
|
||||
)
|
||||
if start_date:
|
||||
try:
|
||||
# Try parsing as full datetime first (with time)
|
||||
start_datetime = datetime.fromisoformat(start_date)
|
||||
except ValueError:
|
||||
try:
|
||||
# Fall back to date-only format
|
||||
start_datetime = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
# Set to beginning of day
|
||||
start_datetime = start_datetime.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid start_date format: {start_date}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)")
|
||||
|
||||
total = len(messages)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
# Apply agent's timezone if datetime is naive
|
||||
if start_datetime.tzinfo is None and agent_state.timezone:
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
start_datetime = start_datetime.replace(tzinfo=tz)
|
||||
|
||||
if len(messages) == 0:
|
||||
results_str = "No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [message.content[0].text for message in messages]
|
||||
results_str = f"{results_pref} {json_dumps(results_formatted)}"
|
||||
if end_date:
|
||||
try:
|
||||
# Try parsing as full datetime first (with time)
|
||||
end_datetime = datetime.fromisoformat(end_date)
|
||||
except ValueError:
|
||||
try:
|
||||
# Fall back to date-only format
|
||||
end_datetime = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
# Set to end of day for end dates
|
||||
end_datetime = end_datetime.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid end_date format: {end_date}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)")
|
||||
|
||||
return results_str
|
||||
# Apply agent's timezone if datetime is naive
|
||||
if end_datetime.tzinfo is None and agent_state.timezone:
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
end_datetime = end_datetime.replace(tzinfo=tz)
|
||||
|
||||
# Convert string roles to MessageRole enum if provided
|
||||
message_roles = None
|
||||
if roles:
|
||||
message_roles = [MessageRole(role) for role in roles]
|
||||
|
||||
# Use provided limit or default
|
||||
search_limit = limit if limit is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
|
||||
# Search using the message manager's search_messages_async method
|
||||
messages = await self.message_manager.search_messages_async(
|
||||
agent_id=agent_state.id,
|
||||
actor=actor,
|
||||
query_text=query,
|
||||
roles=message_roles,
|
||||
limit=search_limit,
|
||||
start_date=start_datetime,
|
||||
end_date=end_datetime,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
)
|
||||
|
||||
if len(messages) == 0:
|
||||
results_str = "No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(messages)} results:"
|
||||
results_formatted = []
|
||||
for message in messages:
|
||||
# Format timestamp in agent's timezone if available
|
||||
timestamp = message.created_at
|
||||
if timestamp and agent_state.timezone:
|
||||
try:
|
||||
# Convert to agent's timezone
|
||||
tz = ZoneInfo(agent_state.timezone)
|
||||
local_time = timestamp.astimezone(tz)
|
||||
# Format as ISO string with timezone
|
||||
formatted_timestamp = local_time.isoformat()
|
||||
except Exception:
|
||||
# Fallback to ISO format if timezone conversion fails
|
||||
formatted_timestamp = str(timestamp)
|
||||
else:
|
||||
# Use ISO format if no timezone is set
|
||||
formatted_timestamp = str(timestamp) if timestamp else "Unknown"
|
||||
|
||||
results_formatted.append(
|
||||
{
|
||||
"timestamp": formatted_timestamp,
|
||||
"role": message.role,
|
||||
"content": message.content[0].text if message.content else "",
|
||||
}
|
||||
)
|
||||
|
||||
results_str = f"{results_pref} {json_dumps(results_formatted)}"
|
||||
|
||||
return results_str
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def archival_memory_search(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user