From 6e633bd8f9ddde270081e70e54dc2cab89e98d48 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 3 Sep 2025 10:00:19 -0700 Subject: [PATCH] feat: Change namespace to be org scoped and filter on agent_id [LET-4163] (#4368) * Change to org scoped and agent_id filtering * Finish modifying conversation search tool * Fix failing tests * Get rid of bad imports --- letta/agents/helpers.py | 1 + letta/agents/letta_agent.py | 20 +- letta/constants.py | 3 + letta/helpers/tpuf_client.py | 42 ++-- letta/services/agent_manager.py | 28 ++- letta/services/message_manager.py | 227 ++++++++++++++++-- .../tool_executor/core_tool_executor.py | 74 +++++- tests/integration_test_turbopuffer.py | 27 ++- 8 files changed, 354 insertions(+), 68 deletions(-) diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 608199e5..f4a58b65 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -117,6 +117,7 @@ async def _prepare_in_context_messages_async( new_in_context_messages = await message_manager.create_many_messages_async( create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor), actor=actor, + embedding_config=agent_state.embedding_config, ) return current_in_context_messages, new_in_context_messages diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index a5b4d4ec..76183c44 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -494,7 +494,9 @@ class LettaAgent(BaseAgent): for message in initial_messages: message.is_err = True message.step_id = effective_step_id - await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + await self.message_manager.create_many_messages_async( + initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config + ) elif step_progression <= StepProgression.LOGGED_TRACE: if stop_reason is None: self.logger.error("Error in step after logging step") @@ -820,7 +822,9 @@ class LettaAgent(BaseAgent): for message in initial_messages: message.is_err = True message.step_id = effective_step_id - await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + await self.message_manager.create_many_messages_async( + initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config + ) elif step_progression <= StepProgression.LOGGED_TRACE: if stop_reason is None: self.logger.error("Error in step after logging step") @@ -1254,7 +1258,9 @@ class LettaAgent(BaseAgent): for message in initial_messages: message.is_err = True message.step_id = effective_step_id - await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor) + await self.message_manager.create_many_messages_async( + initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config + ) elif step_progression <= StepProgression.LOGGED_TRACE: if stop_reason is None: self.logger.error("Error in step after logging step") @@ -1660,7 +1666,9 @@ class LettaAgent(BaseAgent): is_approval_response=True, ) messages_to_persist = (initial_messages or []) + tool_call_messages - persisted_messages = await self.message_manager.create_many_messages_async(messages_to_persist, actor=self.actor) + persisted_messages = await self.message_manager.create_many_messages_async( + messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config + ) return persisted_messages, continue_stepping, stop_reason # 1. Parse and validate the tool-call envelope @@ -1770,7 +1778,9 @@ class LettaAgent(BaseAgent): ) messages_to_persist = (initial_messages or []) + tool_call_messages - persisted_messages = await self.message_manager.create_many_messages_async(messages_to_persist, actor=self.actor) + persisted_messages = await self.message_manager.create_many_messages_async( + messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config + ) if run_id: await self.job_manager.add_messages_to_job_async( diff --git a/letta/constants.py b/letta/constants.py index 9325e4d3..531d8074 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -164,6 +164,9 @@ def FUNCTION_RETURN_VALUE_TRUNCATED(return_str, return_char: int, return_char_li DEFAULT_MESSAGE_TOOL = SEND_MESSAGE_TOOL_NAME DEFAULT_MESSAGE_TOOL_KWARG = "message" +# The name of the conversation search tool - messages with this tool should not be indexed +CONVERSATION_SEARCH_TOOL_NAME = "conversation_search" + PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg" REQUEST_HEARTBEAT_PARAM = "request_heartbeat" diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 7897b042..78eb0572 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -44,9 +44,17 @@ class TurbopufferClient: return await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id) @trace_method - async def _get_message_namespace_name(self, agent_id: str) -> str: - """Get namespace name for a specific agent's messages.""" - return await self.agent_manager.get_or_set_vector_db_namespace_async(agent_id) + async def _get_message_namespace_name(self, agent_id: str, organization_id: str) -> str: + """Get namespace name for messages (org-scoped). + + Args: + agent_id: Agent ID (stored for future sharding) + organization_id: Organization ID for namespace generation + + Returns: + The org-scoped namespace name for messages + """ + return await self.agent_manager.get_or_set_vector_db_namespace_async(agent_id, organization_id) @trace_method async def insert_archival_memories( @@ -191,7 +199,7 @@ class TurbopufferClient: """ from turbopuffer import AsyncTurbopuffer - namespace_name = await self._get_message_namespace_name(agent_id) + namespace_name = await self._get_message_namespace_name(agent_id, organization_id) # validation checks if not message_ids: @@ -481,6 +489,7 @@ class TurbopufferClient: async def query_messages( self, agent_id: str, + organization_id: str, query_embedding: Optional[List[float]] = None, query_text: Optional[str] = None, search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp" @@ -494,7 +503,8 @@ class TurbopufferClient: """Query messages from Turbopuffer using vector search, full-text search, or hybrid search. Args: - agent_id: ID of the agent + agent_id: ID of the agent (used for filtering results) + organization_id: Organization ID for namespace lookup query_embedding: Embedding vector for vector search (required for "vector" and "hybrid" modes) query_text: Text query for full-text search (required for "fts" and "hybrid" modes) search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector") @@ -513,7 +523,10 @@ class TurbopufferClient: # Fallback to retrieving most recent messages when no search query is provided search_mode = "timestamp" - namespace_name = await self._get_message_namespace_name(agent_id) + namespace_name = await self._get_message_namespace_name(agent_id, organization_id) + + # build agent_id filter + agent_filter = ("agent_id", "Eq", agent_id) # build role filter conditions role_filter = None @@ -532,7 +545,7 @@ class TurbopufferClient: date_filters.append(("created_at", "Lte", end_date)) # combine all filters - all_filters = [] + all_filters = [agent_filter] # always include agent_id filter if role_filter: all_filters.append(role_filter) if date_filters: @@ -776,14 +789,14 @@ class TurbopufferClient: raise @trace_method - async def delete_messages(self, agent_id: str, message_ids: List[str]) -> bool: + async def delete_messages(self, agent_id: str, organization_id: str, message_ids: List[str]) -> bool: """Delete multiple messages from Turbopuffer.""" from turbopuffer import AsyncTurbopuffer if not message_ids: return True - namespace_name = await self._get_message_namespace_name(agent_id) + namespace_name = await self._get_message_namespace_name(agent_id, organization_id) try: async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: @@ -797,18 +810,19 @@ class TurbopufferClient: raise @trace_method - async def delete_all_messages(self, agent_id: str) -> bool: + async def delete_all_messages(self, agent_id: str, organization_id: str) -> bool: """Delete all messages for an agent from Turbopuffer.""" from turbopuffer import AsyncTurbopuffer - namespace_name = await self._get_message_namespace_name(agent_id) + namespace_name = await self._get_message_namespace_name(agent_id, organization_id) try: async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: namespace = client.namespace(namespace_name) - # Turbopuffer has a delete_all() method on namespace - await namespace.delete_all() - logger.info(f"Successfully deleted all messages for agent {agent_id}") + # Use delete_by_filter to only delete messages for this agent + # since namespace is now org-scoped + result = await namespace.write(delete_by_filter=("agent_id", "Eq", agent_id)) + logger.info(f"Successfully deleted all messages for agent {agent_id} (deleted {result.rows_affected} rows)") return True except Exception as e: logger.error(f"Failed to delete all messages from Turbopuffer: {e}") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 90a9ef02..01214505 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -717,7 +717,9 @@ class AgentManager: # Only create messages if we initialized with messages if not _init_with_no_messages: - await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor) + await self.message_manager.create_many_messages_async( + pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config + ) return result @enforce_types @@ -1882,8 +1884,8 @@ class AgentManager: async def append_to_in_context_messages_async( self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser ) -> PydanticAgentState: - messages = await self.message_manager.create_many_messages_async(messages, actor=actor) agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + messages = await self.message_manager.create_many_messages_async(messages, actor=actor, embedding_config=agent.embedding_config) message_ids = agent.message_ids or [] message_ids += [m.id for m in messages] return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor) @@ -3577,8 +3579,17 @@ class AgentManager: async def get_or_set_vector_db_namespace_async( self, agent_id: str, + organization_id: str, ) -> str: - """Get the vector database namespace for an agent, creating it if it doesn't exist.""" + """Get the vector database namespace for an agent, creating it if it doesn't exist. + + Args: + agent_id: Agent ID to check/store namespace + organization_id: Organization ID for namespace generation + + Returns: + The org-scoped namespace name + """ from sqlalchemy import update from letta.settings import settings @@ -3591,14 +3602,17 @@ class AgentManager: if row and row[0]: return row[0] - # generate namespace name using same logic as tpuf_client + # TODO: In the future, we might use agent_id for sharding the namespace + # For now, all messages in an org share the same namespace + + # generate org-scoped namespace name environment = settings.environment if environment: - namespace_name = f"messages_{agent_id}_{environment.lower()}" + namespace_name = f"messages_{organization_id}_{environment.lower()}" else: - namespace_name = f"messages_{agent_id}" + namespace_name = f"messages_{organization_id}" - # update the agent with the namespace + # update the agent with the namespace (keeps agent-level tracking for future sharding) await session.execute(update(AgentModel).where(AgentModel.id == agent_id).values(_vector_db_namespace=namespace_name)) await session.commit() diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index dd927e42..51c96558 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -5,6 +5,7 @@ from typing import List, Optional, Sequence from sqlalchemy import delete, exists, func, select, text +from letta.constants import CONVERSATION_SEARCH_TOOL_NAME, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.log import get_logger from letta.orm.agent import Agent as AgentModel from letta.orm.errors import NoResultFound @@ -36,32 +37,183 @@ class MessageManager: """Extract text content from a message's complex content structure. Only extracts text from searchable message roles (assistant, user, tool). + Returns JSON format for all message types for consistency. Args: message: The message to extract text from Returns: - Concatenated text content from the message, or empty string for non-searchable roles + JSON string with message content, or empty string for non-searchable roles """ # only extract text from searchable roles if message.role not in [MessageRole.assistant, MessageRole.user, MessageRole.tool]: return "" + # skip tool messages related to send_message and conversation_search entirely + if message.role == MessageRole.tool and message.name in [DEFAULT_MESSAGE_TOOL, CONVERSATION_SEARCH_TOOL_NAME]: + return "" + if not message.content: return "" - # handle string content (legacy) + # extract raw content text if isinstance(message.content, str): - return message.content + content_str = message.content + else: + text_parts = [] + for content_item in message.content: + text = content_item.to_text() + if text: + text_parts.append(text) + content_str = " ".join(text_parts) - # handle list of content items using the to_text() method - text_parts = [] - for content_item in message.content: - text = content_item.to_text() - if text: # only add non-None text - text_parts.append(text) + # skip heartbeat messages entirely + try: + if content_str.strip().startswith("{"): + parsed_content = json.loads(content_str) + if isinstance(parsed_content, dict) and parsed_content.get("type") == "heartbeat": + return "" + except (json.JSONDecodeError, ValueError): + pass - return " ".join(text_parts) + # format everything as JSON + if message.role == MessageRole.user: + # check if content_str is already valid JSON to avoid double nesting + try: + # if it's already valid JSON, return as-is + json.loads(content_str) + return content_str + except (json.JSONDecodeError, ValueError): + # if not valid JSON, wrap it + return json.dumps({"content": content_str}) + + elif message.role == MessageRole.assistant and message.tool_calls: + # skip assistant messages that call conversation_search + for tool_call in message.tool_calls: + if tool_call.function.name == CONVERSATION_SEARCH_TOOL_NAME: + return "" + + # check if any tool call is send_message + for tool_call in message.tool_calls: + if tool_call.function.name == DEFAULT_MESSAGE_TOOL: + # extract the actual message from tool call arguments + try: + args = json.loads(tool_call.function.arguments) + actual_message = args.get(DEFAULT_MESSAGE_TOOL_KWARG, "") + + return json.dumps({"thinking": content_str, "message": actual_message}) + except (json.JSONDecodeError, KeyError): + # fallback if parsing fails + pass + + # default for other messages (tool responses, assistant without send_message) + # check if content_str is already valid JSON to avoid double nesting + if message.role == MessageRole.assistant: + try: + # if it's already valid JSON, return as-is + json.loads(content_str) + return content_str + except (json.JSONDecodeError, ValueError): + # if not valid JSON, wrap it + return json.dumps({"content": content_str}) + else: + # for tool messages and others, wrap in content + return json.dumps({"content": content_str}) + + def _combine_assistant_tool_messages(self, messages: List[PydanticMessage]) -> List[PydanticMessage]: + """Combine assistant messages with their corresponding tool results when IDs match. + + Args: + messages: List of messages to process + + Returns: + List of messages with assistant+tool combinations merged + """ + from letta.constants import DEFAULT_MESSAGE_TOOL + + combined_messages = [] + i = 0 + + while i < len(messages): + current_msg = messages[i] + + # skip heartbeat messages + if self._extract_message_text(current_msg) == "": + i += 1 + continue + + # if this is an assistant message with tool calls, look for matching tool response + if current_msg.role == MessageRole.assistant and current_msg.tool_calls and i + 1 < len(messages): + next_msg = messages[i + 1] + + # check if next message is a tool response that matches + if ( + next_msg.role == MessageRole.tool + and next_msg.tool_call_id + and any(tc.id == next_msg.tool_call_id for tc in current_msg.tool_calls) + ): + # combine the messages - get raw content to avoid double-processing + assistant_text = current_msg.content[0].text if current_msg.content else "" + + # for non-send_message tools, include tool result + if next_msg.name != DEFAULT_MESSAGE_TOOL: + tool_result_text = next_msg.content[0].text if next_msg.content else "" + + # get the tool call that matches this result (we know it exists from the condition above) + matching_tool_call = next((tc for tc in current_msg.tool_calls if tc.id == next_msg.tool_call_id), None) + + # format tool call with parameters + try: + args = json.loads(matching_tool_call.function.arguments) + if args: + # format parameters nicely + param_strs = [f"{k}={repr(v)}" for k, v in args.items()] + tool_call_str = f"{matching_tool_call.function.name}({', '.join(param_strs)})" + else: + tool_call_str = f"{matching_tool_call.function.name}()" + except (json.JSONDecodeError, KeyError): + tool_call_str = f"{matching_tool_call.function.name}()" + + # format tool result cleanly + try: + if tool_result_text.strip().startswith("{"): + parsed_result = json.loads(tool_result_text) + if isinstance(parsed_result, dict): + # extract key information from tool result + if "message" in parsed_result: + tool_result_summary = parsed_result["message"] + elif "status" in parsed_result: + tool_result_summary = f"Status: {parsed_result['status']}" + else: + tool_result_summary = tool_result_text + else: + tool_result_summary = tool_result_text + else: + tool_result_summary = tool_result_text + except (json.JSONDecodeError, ValueError): + tool_result_summary = tool_result_text + + combined_data = {"thinking": assistant_text, "tool_call": tool_call_str, "tool_result": tool_result_summary} + combined_text = json.dumps(combined_data) + else: + combined_text = assistant_text + + # create a new combined message + from letta.schemas.letta_message_content import TextContent + + combined_message = current_msg.model_copy() + combined_message.content = [TextContent(text=combined_text)] + combined_messages.append(combined_message) + + # skip the tool message since we combined it + i += 2 + continue + + # if no combination, add the message as-is + combined_messages.append(current_msg) + i += 1 + + return combined_messages @enforce_types @trace_method @@ -223,9 +375,11 @@ class MessageManager: message_ids = [] roles = [] created_ats = [] + # combine assistant+tool messages before embedding + combined_messages = self._combine_assistant_tool_messages(result) - for msg in result: - text = self._extract_message_text(msg) + for msg in combined_messages: + text = self._extract_message_text(msg).strip() if text: # only embed messages with text content (role filtering is handled in _extract_message_text) message_texts.append(text) message_ids.append(msg.id) @@ -256,6 +410,7 @@ class MessageManager: logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}") except Exception as e: logger.error(f"Failed to embed messages in Turbopuffer: {e}") + if strict_mode: raise # Re-raise the exception in strict mode @@ -397,7 +552,9 @@ class MessageManager: tpuf_client = TurbopufferClient() # delete old message from turbopuffer - await tpuf_client.delete_messages(agent_id=pydantic_message.agent_id, message_ids=[message_id]) + await tpuf_client.delete_messages( + agent_id=pydantic_message.agent_id, organization_id=actor.organization_id, message_ids=[message_id] + ) # generate new embedding from letta.llm_api.llm_client import LLMClient @@ -487,7 +644,9 @@ class MessageManager: if should_use_tpuf_for_messages() and agent_id: try: tpuf_client = TurbopufferClient() - await tpuf_client.delete_messages(agent_id=agent_id, message_ids=[message_id]) + await tpuf_client.delete_messages( + agent_id=agent_id, organization_id=actor.organization_id, message_ids=[message_id] + ) logger.info(f"Successfully deleted message {message_id} from Turbopuffer") except Exception as e: logger.error(f"Failed to delete message from Turbopuffer: {e}") @@ -834,7 +993,7 @@ class MessageManager: # for now, log a warning logger.warning(f"Turbopuffer deletion with exclude_ids not fully supported, using delete_all for agent {agent_id}") # delete all messages for the agent from turbopuffer - await tpuf_client.delete_all_messages(agent_id) + await tpuf_client.delete_all_messages(agent_id, actor.organization_id) logger.info(f"Successfully deleted all messages for agent {agent_id} from Turbopuffer") except Exception as e: logger.error(f"Failed to delete messages from Turbopuffer: {e}") @@ -882,7 +1041,7 @@ class MessageManager: tpuf_client = TurbopufferClient() # delete from each affected agent's namespace for agent_id in agent_ids: - await tpuf_client.delete_messages(agent_id=agent_id, message_ids=message_ids) + await tpuf_client.delete_messages(agent_id=agent_id, organization_id=actor.organization_id, message_ids=message_ids) logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer") except Exception as e: logger.error(f"Failed to delete messages from Turbopuffer: {e}") @@ -958,6 +1117,7 @@ class MessageManager: tpuf_client = TurbopufferClient() results = await tpuf_client.query_messages( agent_id=agent_id, + organization_id=actor.organization_id, query_embedding=query_embedding, query_text=query_text, search_mode=search_mode, @@ -967,20 +1127,35 @@ class MessageManager: end_date=end_date, ) - # fetch full message objects from database using the IDs - message_ids = [msg_dict["id"] for msg_dict, _ in results] - if message_ids: - messages = await self.get_messages_by_ids_async(message_ids, actor) - # maintain the order from turbopuffer results - message_dict = {msg.id: msg for msg in messages} - return [message_dict[msg_id] for msg_id in message_ids if msg_id in message_dict] + # create message-like objects using turbopuffer data (which already has properly extracted text) + if results: + # create simplified message objects from turbopuffer data + from letta.schemas.letta_message_content import TextContent + from letta.schemas.message import Message as PydanticMessage + + turbopuffer_messages = [] + for msg_dict, score in results: + # create a message object with the properly extracted text from turbopuffer + message = PydanticMessage( + id=msg_dict["id"], + agent_id=agent_id, + role=MessageRole(msg_dict["role"]), + content=[TextContent(text=msg_dict["text"])], + created_at=msg_dict["created_at"], + updated_at=msg_dict["created_at"], # use created_at as fallback + created_by_id=actor.id, + last_updated_by_id=actor.id, + ) + turbopuffer_messages.append(message) + + return turbopuffer_messages else: return [] except Exception as e: logger.error(f"Failed to search messages with Turbopuffer, falling back to SQL: {e}") # fall back to SQL search - return await self.list_messages_for_agent_async( + messages = await self.list_messages_for_agent_async( agent_id=agent_id, actor=actor, query_text=query_text, @@ -988,9 +1163,10 @@ class MessageManager: limit=limit, ascending=False, ) + return self._combine_assistant_tool_messages(messages) else: # use sql-based search - return await self.list_messages_for_agent_async( + messages = await self.list_messages_for_agent_async( agent_id=agent_id, actor=actor, query_text=query_text, @@ -998,3 +1174,4 @@ class MessageManager: limit=limit, ascending=False, ) + return self._combine_assistant_tool_messages(messages) diff --git a/letta/services/tool_executor/core_tool_executor.py b/letta/services/tool_executor/core_tool_executor.py index 7e709ae1..ac28cf31 100644 --- a/letta/services/tool_executor/core_tool_executor.py +++ b/letta/services/tool_executor/core_tool_executor.py @@ -1,4 +1,3 @@ -import math from datetime import datetime from typing import Any, Dict, List, Literal, Optional from zoneinfo import ZoneInfo @@ -10,16 +9,18 @@ from letta.constants import ( RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE, ) from letta.helpers.json_helpers import json_dumps +from letta.log import get_logger from letta.schemas.agent import AgentState from letta.schemas.enums import MessageRole, TagMatchMode from letta.schemas.sandbox_config import SandboxConfig from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User -from letta.services.message_manager import MessageManager from letta.services.tool_executor.tool_executor_base import ToolExecutor from letta.utils import get_friendly_error_msg +logger = get_logger(__name__) + class LettaCoreToolExecutor(ToolExecutor): """Executor for LETTA core tools with direct implementation of functions.""" @@ -170,9 +171,24 @@ class LettaCoreToolExecutor(ToolExecutor): else: results_pref = f"Showing {len(messages)} results:" results_formatted = [] + # get current time in UTC, then convert to agent timezone for consistent comparison + from datetime import timezone + + now_utc = datetime.now(timezone.utc) + if agent_state.timezone: + try: + tz = ZoneInfo(agent_state.timezone) + now = now_utc.astimezone(tz) + except Exception: + now = now_utc + else: + now = now_utc + for message in messages: # Format timestamp in agent's timezone if available timestamp = message.created_at + time_delta_str = "" + if timestamp and agent_state.timezone: try: # Convert to agent's timezone @@ -180,6 +196,23 @@ class LettaCoreToolExecutor(ToolExecutor): local_time = timestamp.astimezone(tz) # Format as ISO string with timezone formatted_timestamp = local_time.isoformat() + + # Calculate time delta + delta = now - local_time + total_seconds = int(delta.total_seconds()) + + if total_seconds < 60: + time_delta_str = f"{total_seconds}s ago" + elif total_seconds < 3600: + minutes = total_seconds // 60 + time_delta_str = f"{minutes}m ago" + elif total_seconds < 86400: + hours = total_seconds // 3600 + time_delta_str = f"{hours}h ago" + else: + days = total_seconds // 86400 + time_delta_str = f"{days}d ago" + except Exception: # Fallback to ISO format if timezone conversion fails formatted_timestamp = str(timestamp) @@ -187,14 +220,37 @@ class LettaCoreToolExecutor(ToolExecutor): # Use ISO format if no timezone is set formatted_timestamp = str(timestamp) if timestamp else "Unknown" - results_formatted.append( - { - "timestamp": formatted_timestamp, - "role": message.role, - "content": message.content[0].text if message.content else "", - } - ) + content = self.message_manager._extract_message_text(message) + # Create the base result dict + result_dict = { + "timestamp": formatted_timestamp, + "time_ago": time_delta_str, + "role": message.role, + } + + # _extract_message_text returns already JSON-encoded strings + # We need to parse them to get the actual content structure + if content: + try: + import json + + parsed_content = json.loads(content) + + # Add the parsed content directly to avoid double JSON encoding + if isinstance(parsed_content, dict): + # Merge the parsed content into result_dict + result_dict.update(parsed_content) + else: + # If it's not a dict, add as content + result_dict["content"] = parsed_content + except (json.JSONDecodeError, ValueError): + # if not valid JSON, add as plain content + result_dict["content"] = content + + results_formatted.append(result_dict) + + # Don't double-encode - results_formatted already has the parsed content results_str = f"{results_pref} {json_dumps(results_formatted)}" return results_str diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index 68d137de..d8f1aacc 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -850,7 +850,7 @@ class TestTurbopufferMessagesIntegration: agent_id="test-agent", ) text1 = manager._extract_message_text(msg1) - assert text1 == "Simple text content" + assert text1 == '{"content": "Simple text content"}' # Test 2: List with single TextContent msg2 = PydanticMessage( @@ -859,7 +859,7 @@ class TestTurbopufferMessagesIntegration: agent_id="test-agent", ) text2 = manager._extract_message_text(msg2) - assert text2 == "Single text content" + assert text2 == '{"content": "Single text content"}' # Test 3: List with multiple TextContent items msg3 = PydanticMessage( @@ -872,7 +872,7 @@ class TestTurbopufferMessagesIntegration: agent_id="test-agent", ) text3 = manager._extract_message_text(msg3) - assert text3 == "First part Second part Third part" + assert text3 == '{"content": "First part Second part Third part"}' # Test 4: Empty content msg4 = PydanticMessage( @@ -910,7 +910,10 @@ class TestTurbopufferMessagesIntegration: "Tool result: Found 5 results", "I should help the user", ] - assert text6 == " ".join(expected_parts) + assert ( + text6 + == '{"content": "User said: Tool call: search({\\n \\"query\\": \\"test\\"\\n}) Tool result: Found 5 results I should help the user"}' + ) @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") @@ -1094,6 +1097,7 @@ class TestTurbopufferMessagesIntegration: # Verify we can query the messages results = await client.query_messages( agent_id=agent_id, + organization_id=org_id, search_mode="timestamp", top_k=10, ) @@ -1160,6 +1164,7 @@ class TestTurbopufferMessagesIntegration: query_embedding = [0.9, 0.0, 0.1] # Similar to Python messages results = await client.query_messages( agent_id=agent_id, + organization_id=org_id, query_embedding=query_embedding, search_mode="vector", top_k=2, @@ -1226,6 +1231,7 @@ class TestTurbopufferMessagesIntegration: # Hybrid search - vector similar to ML but text contains "quick" results = await client.query_messages( agent_id=agent_id, + organization_id=org_id, query_embedding=[0.7, 0.3, 0.0], # Similar to ML messages query_text="quick", # Text search for "quick" search_mode="hybrid", @@ -1291,6 +1297,7 @@ class TestTurbopufferMessagesIntegration: # Query only user messages user_results = await client.query_messages( agent_id=agent_id, + organization_id=org_id, search_mode="timestamp", top_k=10, roles=[MessageRole.user], @@ -1304,6 +1311,7 @@ class TestTurbopufferMessagesIntegration: # Query assistant and system messages non_user_results = await client.query_messages( agent_id=agent_id, + organization_id=org_id, search_mode="timestamp", top_k=10, roles=[MessageRole.assistant, MessageRole.system], @@ -1773,6 +1781,7 @@ class TestTurbopufferMessagesIntegration: three_days_ago = now - timedelta(days=3) recent_results = await client.query_messages( agent_id=agent_id, + organization_id=org_id, search_mode="timestamp", top_k=10, start_date=three_days_ago, @@ -1788,6 +1797,7 @@ class TestTurbopufferMessagesIntegration: two_weeks_ago = now - timedelta(days=14) week_results = await client.query_messages( agent_id=agent_id, + organization_id=org_id, search_mode="timestamp", top_k=10, start_date=two_weeks_ago, @@ -1801,6 +1811,7 @@ class TestTurbopufferMessagesIntegration: # Query with vector search and date filtering filtered_vector_results = await client.query_messages( agent_id=agent_id, + organization_id=org_id, query_embedding=[1.0, 2.0, 3.0], search_mode="vector", top_k=10, @@ -1847,17 +1858,17 @@ class TestNamespaceTracking: async def test_agent_namespace_tracking(self, server, default_user, sarah_agent, enable_message_embedding): """Test that agent message namespaces are properly tracked in database""" # Get namespace - should be generated and stored - namespace = await server.agent_manager.get_or_set_vector_db_namespace_async(sarah_agent.id) + namespace = await server.agent_manager.get_or_set_vector_db_namespace_async(sarah_agent.id, default_user.organization_id) - # Should have messages_ prefix and environment suffix + # Should have messages_org_ prefix and environment suffix expected_prefix = "messages_" assert namespace.startswith(expected_prefix) - assert sarah_agent.id in namespace + assert default_user.organization_id in namespace if settings.environment: assert settings.environment.lower() in namespace # Call again - should return same namespace from database - namespace2 = await server.agent_manager.get_or_set_vector_db_namespace_async(sarah_agent.id) + namespace2 = await server.agent_manager.get_or_set_vector_db_namespace_async(sarah_agent.id, default_user.organization_id) assert namespace == namespace2 @pytest.mark.asyncio