feat: consolidate message persistence (#2518)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -98,6 +98,40 @@ async def _prepare_in_context_messages_async(
|
||||
return current_in_context_messages, new_in_context_messages
|
||||
|
||||
|
||||
async def _prepare_in_context_messages_no_persist_async(
    input_messages: List[MessageCreate],
    agent_state: AgentState,
    message_manager: MessageManager,
    actor: User,
) -> Tuple[List[Message], List[Message]]:
    """
    Build the in-context message lists for an agent without persisting the new input.

    The existing context is loaded through ``message_manager``; the new input is
    converted with ``create_input_messages``. This function never calls a
    create/persist method on ``message_manager`` itself, so callers that need the
    new messages stored must persist them afterwards.

    Args:
        input_messages (List[MessageCreate]): The new user input messages to process.
        agent_state (AgentState): The current agent state; ``message_buffer_autoclear``,
            ``message_ids`` and ``id`` are read from it.
        message_manager (MessageManager): The manager used here only to look up
            existing messages by ID.
        actor (User): The user performing the action, forwarded to every lookup and
            to ``create_input_messages``.

    Returns:
        Tuple[List[Message], List[Message]]: A tuple containing:
            - The current in-context messages (existing context for the agent).
            - The new in-context messages built from the input (not stored by this function).
    """
    if not agent_state.message_buffer_autoclear:
        # Buffer is kept between turns: load every message referenced by the agent.
        existing_messages = await message_manager.get_messages_by_ids_async(
            message_ids=agent_state.message_ids,
            actor=actor,
        )
    else:
        # Autoclear is on: keep only the first message (usually the system message at index 0).
        leading_message = await message_manager.get_message_by_id_async(
            message_id=agent_state.message_ids[0],
            actor=actor,
        )
        existing_messages = [leading_message]

    # Turn the raw input into message objects, but don't store them yet.
    pending_messages = create_input_messages(
        input_messages=input_messages,
        agent_id=agent_state.id,
        actor=actor,
    )

    return existing_messages, pending_messages
|
||||
|
||||
|
||||
def serialize_message_history(messages: List[str], context: str) -> str:
|
||||
"""
|
||||
Produce an XML document like:
|
||||
|
||||
@@ -8,7 +8,12 @@ from openai.types.chat import ChatCompletionChunk
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
|
||||
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async, generate_step_id
|
||||
from letta.agents.helpers import (
|
||||
_create_letta_response,
|
||||
_prepare_in_context_messages_async,
|
||||
_prepare_in_context_messages_no_persist_async,
|
||||
generate_step_id,
|
||||
)
|
||||
from letta.errors import LLMContextWindowExceededError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
@@ -173,7 +178,11 @@ class LettaAgent(BaseAgent):
|
||||
reasoning = None
|
||||
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning
|
||||
tool_call,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
reasoning_content=reasoning,
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -272,7 +281,12 @@ class LettaAgent(BaseAgent):
|
||||
reasoning = None
|
||||
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call, agent_state, tool_rules_solver, response.usage, reasoning_content=reasoning, step_id=step_id
|
||||
tool_call,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
response.usage,
|
||||
reasoning_content=reasoning,
|
||||
step_id=step_id,
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
@@ -323,10 +337,15 @@ class LettaAgent(BaseAgent):
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
agent_id=self.agent_id, include_relationships=["tools", "memory"], actor=self.actor
|
||||
)
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
|
||||
# Special strategy to lower TTFT
|
||||
# Delay persistence of the initial input message as much as possible
|
||||
persisted_input_messages = False
|
||||
initial_messages = new_in_context_messages
|
||||
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
@@ -388,6 +407,12 @@ class LettaAgent(BaseAgent):
|
||||
usage.prompt_tokens += interface.input_tokens
|
||||
usage.total_tokens += interface.input_tokens + interface.output_tokens
|
||||
|
||||
# Persist the input messages if they have not been persisted yet
|
||||
# Special strategy to lower TTFT
|
||||
if not persisted_input_messages:
|
||||
await self.message_manager.create_many_messages_async(initial_messages, actor=self.actor)
|
||||
persisted_input_messages = True
|
||||
|
||||
# Process resulting stream content
|
||||
tool_call = interface.get_tool_call_object()
|
||||
reasoning_content = interface.get_reasoning_content()
|
||||
@@ -698,6 +723,7 @@ class LettaAgent(BaseAgent):
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
pre_computed_tool_message_id: Optional[str] = None,
|
||||
step_id: str | None = None,
|
||||
new_in_context_messages: Optional[List[Message]] = None,
|
||||
) -> Tuple[List[Message], bool]:
|
||||
"""
|
||||
Now that streaming is done, handle the final AI response.
|
||||
|
||||
Reference in New Issue
Block a user