From 002018345940e876cebce221a0ea1f8139a5f046 Mon Sep 17 00:00:00 2001
From: cthomas
Date: Thu, 29 May 2025 16:23:13 -0700
Subject: [PATCH] feat: consolidate message persistence (#2518)

Co-authored-by: Matt Zhou
---
 letta/agents/helpers.py     | 34 ++++++++++++++++++++++++++++++++++
 letta/agents/letta_agent.py | 34 ++++++++++++++++++++++++++++++----
 2 files changed, 64 insertions(+), 4 deletions(-)

diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py
index 071d3029..0d653813 100644
--- a/letta/agents/helpers.py
+++ b/letta/agents/helpers.py
@@ -98,6 +98,40 @@ async def _prepare_in_context_messages_async(
     return current_in_context_messages, new_in_context_messages
 
 
+async def _prepare_in_context_messages_no_persist_async(
+    input_messages: List[MessageCreate],
+    agent_state: AgentState,
+    message_manager: MessageManager,
+    actor: User,
+) -> Tuple[List[Message], List[Message]]:
+    """
+    Prepares in-context messages for an agent, based on the current state and a new user input.
+
+    Args:
+        input_messages (List[MessageCreate]): The new user input messages to process.
+        agent_state (AgentState): The current state of the agent, including message buffer config.
+        message_manager (MessageManager): The manager used to retrieve and create messages.
+        actor (User): The user performing the action, used for access control and attribution.
+
+    Returns:
+        Tuple[List[Message], List[Message]]: A tuple containing:
+            - The current in-context messages (existing context for the agent).
+            - The new in-context messages (messages created from the new input).
+    """
+
+    if agent_state.message_buffer_autoclear:
+        # If autoclear is enabled, only include the most recent system message (usually at index 0)
+        current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)]
+    else:
+        # Otherwise, include the full list of messages by ID for context
+        current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
+
+    # Create a new user message from the input but don't store it yet
+    new_in_context_messages = create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=actor)
+
+    return current_in_context_messages, new_in_context_messages
+
+
 def serialize_message_history(messages: List[str], context: str) -> str:
     """
     Produce an XML document like:

diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py
index bd5d5821..7b86afc5 100644
--- a/letta/agents/letta_agent.py
+++ b/letta/agents/letta_agent.py
@@ -8,7 +8,12 @@ from openai.types.chat import ChatCompletionChunk
 
 from letta.agents.base_agent import BaseAgent
 from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
-from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id
+from letta.agents.helpers import (
+    _create_letta_response,
+    _prepare_in_context_messages_async,
+    _prepare_in_context_messages_no_persist_async,
+    generate_step_id,
+)
 from letta.errors import LLMContextWindowExceededError
 from letta.helpers import ToolRulesSolver
 from letta.helpers.datetime_helpers import get_utc_timestamp_ns
@@ -173,7 +178,11 @@ class LettaAgent(BaseAgent):
                 reasoning = None
 
             persisted_messages, should_continue = await self._handle_ai_response(
-                tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning
+                tool_call,
+                agent_state,
+                tool_rules_solver,
+                response.usage,
+                reasoning_content=reasoning,
             )
             self.response_messages.extend(persisted_messages)
             new_in_context_messages.extend(persisted_messages)
@@ -272,7 +281,12 @@ class LettaAgent(BaseAgent):
                 reasoning = None
 
             persisted_messages, should_continue = await self._handle_ai_response(
-                tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning, step_id=step_id
+                tool_call,
+                agent_state,
+                tool_rules_solver,
+                response.usage,
+                reasoning_content=reasoning,
+                step_id=step_id,
             )
             self.response_messages.extend(persisted_messages)
             new_in_context_messages.extend(persisted_messages)
@@ -323,10 +337,15 @@ class LettaAgent(BaseAgent):
         agent_state = await self.agent_manager.get_agent_by_id_async(
             agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
         )
-        current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
+        current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
             input_messages, agent_state, self.message_manager, self.actor
         )
 
+        # Special strategy to lower TTFT
+        # Delay persistence of the initial input message as much as possible
+        persisted_input_messages = False
+        initial_messages = new_in_context_messages
+
         tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
         llm_client = LLMClient.create(
             provider_type=agent_state.llm_config.model_endpoint_type,
@@ -388,6 +407,12 @@ class LettaAgent(BaseAgent):
             usage.prompt_tokens += interface.input_tokens
             usage.total_tokens += interface.input_tokens + interface.output_tokens
 
+            # Persist input messages if not already
+            # Special strategy to lower TTFT
+            if not persisted_input_messages:
+                await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
+                persisted_input_messages = True
+
             # Process resulting stream content
             tool_call = interface.get_tool_call_object()
             reasoning_content = interface.get_reasoning_content()
@@ -698,6 +723,7 @@ class LettaAgent(BaseAgent):
         pre_computed_assistant_message_id: Optional[str] = None,
         pre_computed_tool_message_id: Optional[str] = None,
         step_id: str | None = None,
+        new_in_context_messages: Optional[List[Message]] = None,
     ) -> Tuple[List[Message], bool]:
         """
         Now that streaming is done, handle the final AI response.
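
The functional core of this patch is the deferred-persistence strategy in the
streaming path: _prepare_in_context_messages_no_persist_async builds the input
messages in memory, streaming starts immediately, and only after tokens are
already flowing does create_many_messages_async write the inputs to the
database, keeping the write off the TTFT-critical path. Below is a minimal,
self-contained sketch of that pattern under simplified assumptions;
InMemoryMessageManager and run_steps are hypothetical stand-ins rather than
Letta APIs (only the create_many_messages_async name mirrors the call the
patch makes).

    import asyncio
    from typing import List


    class InMemoryMessageManager:
        """Illustrative stand-in for Letta's MessageManager; records writes."""

        def __init__(self) -> None:
            self.persisted: List[str] = []

        async def create_many_messages_async(self, messages: List[str]) -> None:
            await asyncio.sleep(0.01)  # mimic database write latency
            self.persisted.extend(messages)


    async def run_steps(manager: InMemoryMessageManager, input_messages: List[str], max_steps: int = 2) -> None:
        # Mirror the patch: build input messages in memory, stream first,
        # and persist only after output is already flowing to the client.
        persisted_input_messages = False
        initial_messages = input_messages

        for _ in range(max_steps):
            for token in ("str", "eam", "ed "):  # simulated model output
                print(token, end="", flush=True)

            # Persist input messages if not already done; the client has
            # already seen output, so TTFT is unaffected by this write.
            if not persisted_input_messages:
                await manager.create_many_messages_async(initial_messages)
                persisted_input_messages = True

        print(f"\npersisted {len(manager.persisted)} input message(s)")


    asyncio.run(run_steps(InMemoryMessageManager(), ["hi agent"]))

The persisted_input_messages flag matters because the step loop can run more
than once: the inputs must be written exactly once, on the first pass, no
matter how many steps the agent takes.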