From 4916d281ce0a5e01d1bb2910e52444a67322a476 Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 5 Dec 2025 22:54:58 -0800 Subject: [PATCH] fix: don't let message ids diverge in memory vs db (#6537) --- letta/agents/letta_agent_v3.py | 64 ++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 58960d6b..c43cda09 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -636,7 +636,6 @@ class LettaAgentV3(LettaAgentV2): persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response( tool_calls=tool_calls, valid_tool_names=[tool["name"] for tool in valid_tools], - agent_state=self.agent_state, tool_rules_solver=self.tool_rules_solver, usage=UsageStatistics( completion_tokens=self.usage.completion_tokens, @@ -782,7 +781,6 @@ class LettaAgentV3(LettaAgentV2): async def _handle_ai_response( self, valid_tool_names: list[str], - agent_state: AgentState, tool_rules_solver: ToolRulesSolver, usage: UsageStatistics, content: list[TextContent | ReasoningContent | RedactedReasoningContent | OmittedReasoningContent] | None = None, @@ -809,7 +807,7 @@ class LettaAgentV3(LettaAgentV2): # Case 1a: No tool call, no content (LLM no-op) if content is None or len(content) == 0: # Check if there are required-before-exit tools that haven't been called - uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools])) + uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in self.agent_state.tools])) if uncalled: heartbeat_reason = ( f"{NON_USER_MSG_PREFIX}ToolRuleViolated: You must call {', '.join(uncalled)} at least once to exit the loop." 
@@ -817,10 +815,10 @@ class LettaAgentV3(LettaAgentV2): from letta.server.rest_api.utils import create_heartbeat_system_message heartbeat_msg = create_heartbeat_system_message( - agent_id=agent_state.id, - model=agent_state.llm_config.model, + agent_id=self.agent_state.id, + model=self.agent_state.llm_config.model, function_call_success=True, - timezone=agent_state.timezone, + timezone=self.agent_state.timezone, heartbeat_reason=heartbeat_reason, run_id=run_id, ) @@ -835,21 +833,21 @@ class LettaAgentV3(LettaAgentV2): # Case 1b: No tool call but has content else: continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation( - agent_state=agent_state, + agent_state=self.agent_state, tool_call_name=None, tool_rule_violated=False, tool_rules_solver=tool_rules_solver, is_final_step=is_final_step, ) assistant_message = create_letta_messages_from_llm_response( - agent_id=agent_state.id, - model=agent_state.llm_config.model, + agent_id=self.agent_state.id, + model=self.agent_state.llm_config.model, function_name=None, function_arguments=None, tool_execution_result=None, tool_call_id=None, function_response=None, - timezone=agent_state.timezone, + timezone=self.agent_state.timezone, continue_stepping=continue_stepping, heartbeat_reason=heartbeat_reason, reasoning_content=content, @@ -870,7 +868,11 @@ class LettaAgentV3(LettaAgentV2): message.step_id = step_id persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id + messages_to_persist, + actor=self.actor, + run_id=run_id, + project_id=self.agent_state.project_id, + template_id=self.agent_state.template_id, ) return persisted_messages, continue_stepping, stop_reason @@ -880,8 +882,8 @@ class LettaAgentV3(LettaAgentV2): allowed_tool_calls = [t for t in tool_calls if not tool_rules_solver.is_requires_approval_tool(t.function.name)] if requested_tool_calls: 
approval_messages = create_approval_request_message_from_llm_response( - agent_id=agent_state.id, - model=agent_state.llm_config.model, + agent_id=self.agent_state.id, + model=self.agent_state.llm_config.model, requested_tool_calls=requested_tool_calls, allowed_tool_calls=allowed_tool_calls, reasoning_content=content, @@ -901,8 +903,8 @@ class LettaAgentV3(LettaAgentV2): messages_to_persist, actor=self.actor, run_id=run_id, - project_id=agent_state.project_id, - template_id=agent_state.template_id, + project_id=self.agent_state.project_id, + template_id=self.agent_state.template_id, ) return persisted_messages, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value) @@ -955,7 +957,7 @@ class LettaAgentV3(LettaAgentV2): denial_returns = create_tool_returns_for_denials( tool_calls=[tool_call_denial], denial_reason=tool_call_denial.reason, - timezone=agent_state.timezone, + timezone=self.agent_state.timezone, ) result_tool_returns.extend(denial_returns) @@ -964,7 +966,7 @@ class LettaAgentV3(LettaAgentV2): # 5a. Validate parallel tool calling constraints if len(tool_calls) > 1: # No parallel tool calls with tool rules - if agent_state.tool_rules and len([r for r in agent_state.tool_rules if r.type != "requires_approval"]) > 0: + if self.agent_state.tool_rules and len([r for r in self.agent_state.tool_rules if r.type != "requires_approval"]) > 0: raise ValueError( "Parallel tool calling is not allowed when tool rules are present. Disable tool rules to use parallel tool calls." 
) @@ -985,7 +987,7 @@ class LettaAgentV3(LettaAgentV2): if not tool_rule_violated: prefill_args = tool_rules_solver.last_prefilled_args_by_tool.get(name) if prefill_args: - target_tool = next((t for t in agent_state.tools if t.name == name), None) + target_tool = next((t for t in self.agent_state.tools if t.name == name), None) provenance = tool_rules_solver.last_prefilled_args_provenance.get(name) try: args = merge_and_validate_prefilled_args( @@ -1028,11 +1030,11 @@ class LettaAgentV3(LettaAgentV2): result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver) return result, 0 t0 = get_utc_timestamp_ns() - target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None) + target_tool = next((x for x in self.agent_state.tools if x.name == spec["name"]), None) res = await self._execute_tool( target_tool=target_tool, tool_args=spec["args"], - agent_state=agent_state, + agent_state=self.agent_state, agent_step_span=agent_step_span, step_id=step_id, ) @@ -1047,7 +1049,7 @@ class LettaAgentV3(LettaAgentV2): serial_items = [] for idx, spec in enumerate(exec_specs): - target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None) + target_tool = next((x for x in self.agent_state.tools if x.name == spec["name"]), None) if target_tool and target_tool.enable_parallel_execution: parallel_items.append((idx, spec)) else: @@ -1078,7 +1080,7 @@ class LettaAgentV3(LettaAgentV2): # Validate and format function response truncate = spec["name"] not in {"conversation_search", "conversation_search_date", "archival_memory_search"} - return_char_limit = next((t.return_char_limit for t in agent_state.tools if t.name == spec["name"]), None) + return_char_limit = next((t.return_char_limit for t in self.agent_state.tools if t.name == spec["name"]), None) function_response_string = validate_function_response( tool_execution_result.func_return, return_char_limit=return_char_limit, @@ -1090,7 +1092,7 @@ class 
LettaAgentV3(LettaAgentV2): self.last_function_response = package_function_response( was_success=tool_execution_result.success_flag, response_string=function_response_string, - timezone=agent_state.timezone, + timezone=self.agent_state.timezone, ) # Register successful tool call with solver @@ -1104,7 +1106,7 @@ class LettaAgentV3(LettaAgentV2): sr = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value) else: cont, hb_reason, sr = self._decide_continuation( - agent_state=agent_state, + agent_state=self.agent_state, tool_call_name=spec["name"], tool_rule_violated=spec["violated"], tool_rules_solver=tool_rules_solver, @@ -1119,12 +1121,12 @@ class LettaAgentV3(LettaAgentV2): # Use the parallel message creation function for both single and multiple tools parallel_messages = create_parallel_tool_messages_from_llm_response( - agent_id=agent_state.id, - model=agent_state.llm_config.model, + agent_id=self.agent_state.id, + model=self.agent_state.llm_config.model, tool_call_specs=tool_call_specs, tool_execution_results=tool_execution_results, function_responses=function_responses, - timezone=agent_state.timezone, + timezone=self.agent_state.timezone, run_id=run_id, step_id=step_id, reasoning_content=content, @@ -1147,8 +1149,8 @@ class LettaAgentV3(LettaAgentV2): messages_to_persist, actor=self.actor, run_id=run_id, - project_id=agent_state.project_id, - template_id=agent_state.template_id, + project_id=self.agent_state.project_id, + template_id=self.agent_state.template_id, ) # Update message_ids immediately after persistence to prevent desync @@ -1162,9 +1164,9 @@ class LettaAgentV3(LettaAgentV2): and persisted_messages[0].role == "approval" and persisted_messages[1].role == "tool" ): - agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages[:2]] + self.agent_state.message_ids = self.agent_state.message_ids + [m.id for m in persisted_messages[:2]] await self.agent_manager.update_message_ids_async( - agent_id=agent_state.id, 
message_ids=agent_state.message_ids, actor=self.actor + agent_id=self.agent_state.id, message_ids=self.agent_state.message_ids, actor=self.actor ) # 5g. Aggregate continuation decisions