From 8f425aa024060c50119102cbc3ebc568ca27fe77 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 2 Sep 2025 13:45:53 -0700 Subject: [PATCH] feat: Modify conversation search tool to be hybrid (#4362) * Modify conversation search functionality * Gate the roles --- letta/functions/function_sets/base.py | 35 ++++- letta/services/message_manager.py | 31 ++-- .../tool_executor/core_tool_executor.py | 140 ++++++++++++++---- 3 files changed, 153 insertions(+), 53 deletions(-) diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index 5ae5ab9a..0f802a72 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import List, Literal, Optional from letta.agent import Agent from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING @@ -20,16 +20,39 @@ def send_message(self: "Agent", message: str) -> Optional[str]: return None -def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]: +def conversation_search( + self: "Agent", + query: str, + roles: Optional[List[Literal["assistant", "user", "tool"]]] = None, + limit: Optional[int] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, +) -> Optional[str]: """ - Search prior conversation history using case-insensitive string matching. + Search prior conversation history using hybrid search (text + semantic similarity). Args: - query (str): String to search for. - page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). + query (str): String to search for using both text matching and semantic similarity. + roles (Optional[List[Literal["assistant", "user", "tool"]]]): Optional list of message roles to filter by. + limit (Optional[int]): Maximum number of results to return. Uses system default if not specified. + start_date (Optional[str]): Filter results to messages created after this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30". + end_date (Optional[str]): Filter results to messages created before this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00". + + Examples: + # Search all messages + conversation_search(query="project updates") + + # Search only assistant messages + conversation_search(query="error handling", roles=["assistant"]) + + # Search with date range + conversation_search(query="meetings", start_date="2024-01-15", end_date="2024-01-20") + + # Search with limit + conversation_search(query="debugging", limit=10) Returns: - str: Query result string + str: Query result string containing matching messages with timestamps and content. """ import math diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index cc62b10f..32cd7213 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -34,12 +34,18 @@ class MessageManager: def _extract_message_text(self, message: PydanticMessage) -> str: """Extract text content from a message's complex content structure. + Only extracts text from searchable message roles (assistant, user, tool). + Args: message: The message to extract text from Returns: - Concatenated text content from the message + Concatenated text content from the message, or empty string for non-searchable roles """ + # only extract text from searchable roles + if message.role not in [MessageRole.assistant, MessageRole.user, MessageRole.tool]: + return "" + if not message.content: return "" @@ -218,7 +224,7 @@ class MessageManager: for msg in result: text = self._extract_message_text(msg) - if text: # only embed messages with text content + if text: # only embed messages with text content (role filtering is handled in _extract_message_text) message_texts.append(text) message_ids.append(msg.id) roles.append(msg.role) @@ -228,15 +234,8 @@ class MessageManager: # generate embeddings using provided config from letta.llm_api.llm_client import LLMClient - # extract provider info from embedding_config - embedding_provider = embedding_config.get("provider", "openai") - embedding_api_key = embedding_config.get("api_key") - embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1") - - embedding_client = LLMClient( - llm_provider_type=embedding_provider, - api_key=embedding_api_key, - endpoint=embedding_endpoint, + embedding_client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, actor=actor, ) embeddings = await embedding_client.request_embeddings(message_texts, embedding_config) @@ -816,14 +815,8 @@ class MessageManager: # generate embedding from query text from letta.llm_api.llm_client import LLMClient - embedding_provider = embedding_config.get("provider", "openai") - embedding_api_key = embedding_config.get("api_key") - embedding_endpoint = embedding_config.get("endpoint", "https://api.openai.com/v1") - - embedding_client = LLMClient( - llm_provider_type=embedding_provider, - api_key=embedding_api_key, - endpoint=embedding_endpoint, + embedding_client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, actor=actor, ) embeddings = await embedding_client.request_embeddings([query_text], embedding_config) diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index 695041b2..7e709ae1 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -1,5 +1,6 @@ import math -from typing import Any, Dict, Literal, Optional +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional from zoneinfo import ZoneInfo from letta.constants import ( @@ -10,7 +11,7 @@ from letta.constants import ( ) from letta.helpers.json_helpers import json_dumps from letta.schemas.agent import AgentState -from letta.schemas.enums import TagMatchMode +from letta.schemas.enums import MessageRole, TagMatchMode from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult @@ -80,43 +81,126 @@ class LettaCoreToolExecutor(ToolExecutor): """ return "Sent message successfully." - async def conversation_search(self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0) -> Optional[str]: + async def conversation_search( + self, + agent_state: AgentState, + actor: User, + query: str, + roles: Optional[List[Literal["assistant", "user", "tool"]]] = None, + limit: Optional[int] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + ) -> Optional[str]: """ - Search prior conversation history using case-insensitive string matching. + Search prior conversation history using hybrid search (text + semantic similarity). Args: - query (str): String to search for. - page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). + query (str): String to search for using both text matching and semantic similarity. + roles (Optional[List[Literal["assistant", "user", "tool"]]]): Optional list of message roles to filter by. + limit (Optional[int]): Maximum number of results to return. Uses system default if not specified. + start_date (Optional[str]): Filter results to messages created after this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-15", "2024-01-15T14:30". + end_date (Optional[str]): Filter results to messages created before this date. ISO 8601 format: "YYYY-MM-DD" or "YYYY-MM-DDTHH:MM". Examples: "2024-01-20", "2024-01-20T17:00". Returns: - str: Query result string + str: Query result string containing matching messages with timestamps and content. """ - if page is None or (isinstance(page, str) and page.lower().strip() == "none"): - page = 0 try: - page = int(page) - except: - raise ValueError("'page' argument must be an integer") + # Parse datetime parameters if provided + start_datetime = None + end_datetime = None - count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - messages = await MessageManager().list_user_messages_for_agent_async( - agent_id=agent_state.id, - actor=actor, - query_text=query, - limit=count, - ) + if start_date: + try: + # Try parsing as full datetime first (with time) + start_datetime = datetime.fromisoformat(start_date) + except ValueError: + try: + # Fall back to date-only format + start_datetime = datetime.strptime(start_date, "%Y-%m-%d") + # Set to beginning of day + start_datetime = start_datetime.replace(hour=0, minute=0, second=0, microsecond=0) + except ValueError: + raise ValueError(f"Invalid start_date format: {start_date}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)") - total = len(messages) - num_pages = math.ceil(total / count) - 1 # 0 index + # Apply agent's timezone if datetime is naive + if start_datetime.tzinfo is None and agent_state.timezone: + tz = ZoneInfo(agent_state.timezone) + start_datetime = start_datetime.replace(tzinfo=tz) - if len(messages) == 0: - results_str = "No results found." - else: - results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):" - results_formatted = [message.content[0].text for message in messages] - results_str = f"{results_pref} {json_dumps(results_formatted)}" + if end_date: + try: + # Try parsing as full datetime first (with time) + end_datetime = datetime.fromisoformat(end_date) + except ValueError: + try: + # Fall back to date-only format + end_datetime = datetime.strptime(end_date, "%Y-%m-%d") + # Set to end of day for end dates + end_datetime = end_datetime.replace(hour=23, minute=59, second=59, microsecond=999999) + except ValueError: + raise ValueError(f"Invalid end_date format: {end_date}. Use ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM)") - return results_str + # Apply agent's timezone if datetime is naive + if end_datetime.tzinfo is None and agent_state.timezone: + tz = ZoneInfo(agent_state.timezone) + end_datetime = end_datetime.replace(tzinfo=tz) + + # Convert string roles to MessageRole enum if provided + message_roles = None + if roles: + message_roles = [MessageRole(role) for role in roles] + + # Use provided limit or default + search_limit = limit if limit is not None else RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE + + # Search using the message manager's search_messages_async method + messages = await self.message_manager.search_messages_async( + agent_id=agent_state.id, + actor=actor, + query_text=query, + roles=message_roles, + limit=search_limit, + start_date=start_datetime, + end_date=end_datetime, + embedding_config=agent_state.embedding_config, + ) + + if len(messages) == 0: + results_str = "No results found." + else: + results_pref = f"Showing {len(messages)} results:" + results_formatted = [] + for message in messages: + # Format timestamp in agent's timezone if available + timestamp = message.created_at + if timestamp and agent_state.timezone: + try: + # Convert to agent's timezone + tz = ZoneInfo(agent_state.timezone) + local_time = timestamp.astimezone(tz) + # Format as ISO string with timezone + formatted_timestamp = local_time.isoformat() + except Exception: + # Fallback to ISO format if timezone conversion fails + formatted_timestamp = str(timestamp) + else: + # Use ISO format if no timezone is set + formatted_timestamp = str(timestamp) if timestamp else "Unknown" + + results_formatted.append( + { + "timestamp": formatted_timestamp, + "role": message.role, + "content": message.content[0].text if message.content else "", + } + ) + + results_str = f"{results_pref} {json_dumps(results_formatted)}" + + return results_str + + except Exception as e: + raise e async def archival_memory_search( self,