fix: dont let message ids diverge in memory vs db (#6537)

This commit is contained in:
cthomas
2025-12-05 22:54:58 -08:00
committed by Caren Thomas
parent 74e0172efe
commit 4916d281ce

View File

@@ -636,7 +636,6 @@ class LettaAgentV3(LettaAgentV2):
persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
tool_calls=tool_calls,
valid_tool_names=[tool["name"] for tool in valid_tools],
agent_state=self.agent_state,
tool_rules_solver=self.tool_rules_solver,
usage=UsageStatistics(
completion_tokens=self.usage.completion_tokens,
@@ -782,7 +781,6 @@ class LettaAgentV3(LettaAgentV2):
async def _handle_ai_response(
self,
valid_tool_names: list[str],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
usage: UsageStatistics,
content: list[TextContent | ReasoningContent | RedactedReasoningContent | OmittedReasoningContent] | None = None,
@@ -809,7 +807,7 @@ class LettaAgentV3(LettaAgentV2):
# Case 1a: No tool call, no content (LLM no-op)
if content is None or len(content) == 0:
# Check if there are required-before-exit tools that haven't been called
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in self.agent_state.tools]))
if uncalled:
heartbeat_reason = (
f"{NON_USER_MSG_PREFIX}ToolRuleViolated: You must call {', '.join(uncalled)} at least once to exit the loop."
@@ -817,10 +815,10 @@ class LettaAgentV3(LettaAgentV2):
from letta.server.rest_api.utils import create_heartbeat_system_message
heartbeat_msg = create_heartbeat_system_message(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
agent_id=self.agent_state.id,
model=self.agent_state.llm_config.model,
function_call_success=True,
timezone=agent_state.timezone,
timezone=self.agent_state.timezone,
heartbeat_reason=heartbeat_reason,
run_id=run_id,
)
@@ -835,21 +833,21 @@ class LettaAgentV3(LettaAgentV2):
# Case 1b: No tool call but has content
else:
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
agent_state=agent_state,
agent_state=self.agent_state,
tool_call_name=None,
tool_rule_violated=False,
tool_rules_solver=tool_rules_solver,
is_final_step=is_final_step,
)
assistant_message = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
agent_id=self.agent_state.id,
model=self.agent_state.llm_config.model,
function_name=None,
function_arguments=None,
tool_execution_result=None,
tool_call_id=None,
function_response=None,
timezone=agent_state.timezone,
timezone=self.agent_state.timezone,
continue_stepping=continue_stepping,
heartbeat_reason=heartbeat_reason,
reasoning_content=content,
@@ -870,7 +868,11 @@ class LettaAgentV3(LettaAgentV2):
message.step_id = step_id
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
messages_to_persist,
actor=self.actor,
run_id=run_id,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
)
return persisted_messages, continue_stepping, stop_reason
@@ -880,8 +882,8 @@ class LettaAgentV3(LettaAgentV2):
allowed_tool_calls = [t for t in tool_calls if not tool_rules_solver.is_requires_approval_tool(t.function.name)]
if requested_tool_calls:
approval_messages = create_approval_request_message_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
agent_id=self.agent_state.id,
model=self.agent_state.llm_config.model,
requested_tool_calls=requested_tool_calls,
allowed_tool_calls=allowed_tool_calls,
reasoning_content=content,
@@ -901,8 +903,8 @@ class LettaAgentV3(LettaAgentV2):
messages_to_persist,
actor=self.actor,
run_id=run_id,
project_id=agent_state.project_id,
template_id=agent_state.template_id,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
)
return persisted_messages, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
@@ -955,7 +957,7 @@ class LettaAgentV3(LettaAgentV2):
denial_returns = create_tool_returns_for_denials(
tool_calls=[tool_call_denial],
denial_reason=tool_call_denial.reason,
timezone=agent_state.timezone,
timezone=self.agent_state.timezone,
)
result_tool_returns.extend(denial_returns)
@@ -964,7 +966,7 @@ class LettaAgentV3(LettaAgentV2):
# 5a. Validate parallel tool calling constraints
if len(tool_calls) > 1:
# No parallel tool calls with tool rules
if agent_state.tool_rules and len([r for r in agent_state.tool_rules if r.type != "requires_approval"]) > 0:
if self.agent_state.tool_rules and len([r for r in self.agent_state.tool_rules if r.type != "requires_approval"]) > 0:
raise ValueError(
"Parallel tool calling is not allowed when tool rules are present. Disable tool rules to use parallel tool calls."
)
@@ -985,7 +987,7 @@ class LettaAgentV3(LettaAgentV2):
if not tool_rule_violated:
prefill_args = tool_rules_solver.last_prefilled_args_by_tool.get(name)
if prefill_args:
target_tool = next((t for t in agent_state.tools if t.name == name), None)
target_tool = next((t for t in self.agent_state.tools if t.name == name), None)
provenance = tool_rules_solver.last_prefilled_args_provenance.get(name)
try:
args = merge_and_validate_prefilled_args(
@@ -1028,11 +1030,11 @@ class LettaAgentV3(LettaAgentV2):
result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver)
return result, 0
t0 = get_utc_timestamp_ns()
target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None)
target_tool = next((x for x in self.agent_state.tools if x.name == spec["name"]), None)
res = await self._execute_tool(
target_tool=target_tool,
tool_args=spec["args"],
agent_state=agent_state,
agent_state=self.agent_state,
agent_step_span=agent_step_span,
step_id=step_id,
)
@@ -1047,7 +1049,7 @@ class LettaAgentV3(LettaAgentV2):
serial_items = []
for idx, spec in enumerate(exec_specs):
target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None)
target_tool = next((x for x in self.agent_state.tools if x.name == spec["name"]), None)
if target_tool and target_tool.enable_parallel_execution:
parallel_items.append((idx, spec))
else:
@@ -1078,7 +1080,7 @@ class LettaAgentV3(LettaAgentV2):
# Validate and format function response
truncate = spec["name"] not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
return_char_limit = next((t.return_char_limit for t in agent_state.tools if t.name == spec["name"]), None)
return_char_limit = next((t.return_char_limit for t in self.agent_state.tools if t.name == spec["name"]), None)
function_response_string = validate_function_response(
tool_execution_result.func_return,
return_char_limit=return_char_limit,
@@ -1090,7 +1092,7 @@ class LettaAgentV3(LettaAgentV2):
self.last_function_response = package_function_response(
was_success=tool_execution_result.success_flag,
response_string=function_response_string,
timezone=agent_state.timezone,
timezone=self.agent_state.timezone,
)
# Register successful tool call with solver
@@ -1104,7 +1106,7 @@ class LettaAgentV3(LettaAgentV2):
sr = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
else:
cont, hb_reason, sr = self._decide_continuation(
agent_state=agent_state,
agent_state=self.agent_state,
tool_call_name=spec["name"],
tool_rule_violated=spec["violated"],
tool_rules_solver=tool_rules_solver,
@@ -1119,12 +1121,12 @@ class LettaAgentV3(LettaAgentV2):
# Use the parallel message creation function for both single and multiple tools
parallel_messages = create_parallel_tool_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
agent_id=self.agent_state.id,
model=self.agent_state.llm_config.model,
tool_call_specs=tool_call_specs,
tool_execution_results=tool_execution_results,
function_responses=function_responses,
timezone=agent_state.timezone,
timezone=self.agent_state.timezone,
run_id=run_id,
step_id=step_id,
reasoning_content=content,
@@ -1147,8 +1149,8 @@ class LettaAgentV3(LettaAgentV2):
messages_to_persist,
actor=self.actor,
run_id=run_id,
project_id=agent_state.project_id,
template_id=agent_state.template_id,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
)
# Update message_ids immediately after persistence to prevent desync
@@ -1162,9 +1164,9 @@ class LettaAgentV3(LettaAgentV2):
and persisted_messages[0].role == "approval"
and persisted_messages[1].role == "tool"
):
agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages[:2]]
self.agent_state.message_ids = self.agent_state.message_ids + [m.id for m in persisted_messages[:2]]
await self.agent_manager.update_message_ids_async(
agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=self.actor
agent_id=self.agent_state.id, message_ids=self.agent_state.message_ids, actor=self.actor
)
# 5g. Aggregate continuation decisions