fix: don't let message ids diverge in memory vs db (#6537)
This commit is contained in:
@@ -636,7 +636,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
|
||||
tool_calls=tool_calls,
|
||||
valid_tool_names=[tool["name"] for tool in valid_tools],
|
||||
agent_state=self.agent_state,
|
||||
tool_rules_solver=self.tool_rules_solver,
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=self.usage.completion_tokens,
|
||||
@@ -782,7 +781,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
async def _handle_ai_response(
|
||||
self,
|
||||
valid_tool_names: list[str],
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
usage: UsageStatistics,
|
||||
content: list[TextContent | ReasoningContent | RedactedReasoningContent | OmittedReasoningContent] | None = None,
|
||||
@@ -809,7 +807,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# Case 1a: No tool call, no content (LLM no-op)
|
||||
if content is None or len(content) == 0:
|
||||
# Check if there are required-before-exit tools that haven't been called
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in self.agent_state.tools]))
|
||||
if uncalled:
|
||||
heartbeat_reason = (
|
||||
f"{NON_USER_MSG_PREFIX}ToolRuleViolated: You must call {', '.join(uncalled)} at least once to exit the loop."
|
||||
@@ -817,10 +815,10 @@ class LettaAgentV3(LettaAgentV2):
|
||||
from letta.server.rest_api.utils import create_heartbeat_system_message
|
||||
|
||||
heartbeat_msg = create_heartbeat_system_message(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
agent_id=self.agent_state.id,
|
||||
model=self.agent_state.llm_config.model,
|
||||
function_call_success=True,
|
||||
timezone=agent_state.timezone,
|
||||
timezone=self.agent_state.timezone,
|
||||
heartbeat_reason=heartbeat_reason,
|
||||
run_id=run_id,
|
||||
)
|
||||
@@ -835,21 +833,21 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# Case 1b: No tool call but has content
|
||||
else:
|
||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||
agent_state=agent_state,
|
||||
agent_state=self.agent_state,
|
||||
tool_call_name=None,
|
||||
tool_rule_violated=False,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
is_final_step=is_final_step,
|
||||
)
|
||||
assistant_message = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
agent_id=self.agent_state.id,
|
||||
model=self.agent_state.llm_config.model,
|
||||
function_name=None,
|
||||
function_arguments=None,
|
||||
tool_execution_result=None,
|
||||
tool_call_id=None,
|
||||
function_response=None,
|
||||
timezone=agent_state.timezone,
|
||||
timezone=self.agent_state.timezone,
|
||||
continue_stepping=continue_stepping,
|
||||
heartbeat_reason=heartbeat_reason,
|
||||
reasoning_content=content,
|
||||
@@ -870,7 +868,11 @@ class LettaAgentV3(LettaAgentV2):
|
||||
message.step_id = step_id
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist, actor=self.actor, run_id=run_id, project_id=agent_state.project_id, template_id=agent_state.template_id
|
||||
messages_to_persist,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
|
||||
@@ -880,8 +882,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
allowed_tool_calls = [t for t in tool_calls if not tool_rules_solver.is_requires_approval_tool(t.function.name)]
|
||||
if requested_tool_calls:
|
||||
approval_messages = create_approval_request_message_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
agent_id=self.agent_state.id,
|
||||
model=self.agent_state.llm_config.model,
|
||||
requested_tool_calls=requested_tool_calls,
|
||||
allowed_tool_calls=allowed_tool_calls,
|
||||
reasoning_content=content,
|
||||
@@ -901,8 +903,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
messages_to_persist,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
project_id=agent_state.project_id,
|
||||
template_id=agent_state.template_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
return persisted_messages, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
|
||||
|
||||
@@ -955,7 +957,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
denial_returns = create_tool_returns_for_denials(
|
||||
tool_calls=[tool_call_denial],
|
||||
denial_reason=tool_call_denial.reason,
|
||||
timezone=agent_state.timezone,
|
||||
timezone=self.agent_state.timezone,
|
||||
)
|
||||
result_tool_returns.extend(denial_returns)
|
||||
|
||||
@@ -964,7 +966,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# 5a. Validate parallel tool calling constraints
|
||||
if len(tool_calls) > 1:
|
||||
# No parallel tool calls with tool rules
|
||||
if agent_state.tool_rules and len([r for r in agent_state.tool_rules if r.type != "requires_approval"]) > 0:
|
||||
if self.agent_state.tool_rules and len([r for r in self.agent_state.tool_rules if r.type != "requires_approval"]) > 0:
|
||||
raise ValueError(
|
||||
"Parallel tool calling is not allowed when tool rules are present. Disable tool rules to use parallel tool calls."
|
||||
)
|
||||
@@ -985,7 +987,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if not tool_rule_violated:
|
||||
prefill_args = tool_rules_solver.last_prefilled_args_by_tool.get(name)
|
||||
if prefill_args:
|
||||
target_tool = next((t for t in agent_state.tools if t.name == name), None)
|
||||
target_tool = next((t for t in self.agent_state.tools if t.name == name), None)
|
||||
provenance = tool_rules_solver.last_prefilled_args_provenance.get(name)
|
||||
try:
|
||||
args = merge_and_validate_prefilled_args(
|
||||
@@ -1028,11 +1030,11 @@ class LettaAgentV3(LettaAgentV2):
|
||||
result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver)
|
||||
return result, 0
|
||||
t0 = get_utc_timestamp_ns()
|
||||
target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None)
|
||||
target_tool = next((x for x in self.agent_state.tools if x.name == spec["name"]), None)
|
||||
res = await self._execute_tool(
|
||||
target_tool=target_tool,
|
||||
tool_args=spec["args"],
|
||||
agent_state=agent_state,
|
||||
agent_state=self.agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
step_id=step_id,
|
||||
)
|
||||
@@ -1047,7 +1049,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
serial_items = []
|
||||
|
||||
for idx, spec in enumerate(exec_specs):
|
||||
target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None)
|
||||
target_tool = next((x for x in self.agent_state.tools if x.name == spec["name"]), None)
|
||||
if target_tool and target_tool.enable_parallel_execution:
|
||||
parallel_items.append((idx, spec))
|
||||
else:
|
||||
@@ -1078,7 +1080,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
# Validate and format function response
|
||||
truncate = spec["name"] not in {"conversation_search", "conversation_search_date", "archival_memory_search"}
|
||||
return_char_limit = next((t.return_char_limit for t in agent_state.tools if t.name == spec["name"]), None)
|
||||
return_char_limit = next((t.return_char_limit for t in self.agent_state.tools if t.name == spec["name"]), None)
|
||||
function_response_string = validate_function_response(
|
||||
tool_execution_result.func_return,
|
||||
return_char_limit=return_char_limit,
|
||||
@@ -1090,7 +1092,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
self.last_function_response = package_function_response(
|
||||
was_success=tool_execution_result.success_flag,
|
||||
response_string=function_response_string,
|
||||
timezone=agent_state.timezone,
|
||||
timezone=self.agent_state.timezone,
|
||||
)
|
||||
|
||||
# Register successful tool call with solver
|
||||
@@ -1104,7 +1106,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
sr = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
|
||||
else:
|
||||
cont, hb_reason, sr = self._decide_continuation(
|
||||
agent_state=agent_state,
|
||||
agent_state=self.agent_state,
|
||||
tool_call_name=spec["name"],
|
||||
tool_rule_violated=spec["violated"],
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
@@ -1119,12 +1121,12 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
# Use the parallel message creation function for both single and multiple tools
|
||||
parallel_messages = create_parallel_tool_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
agent_id=self.agent_state.id,
|
||||
model=self.agent_state.llm_config.model,
|
||||
tool_call_specs=tool_call_specs,
|
||||
tool_execution_results=tool_execution_results,
|
||||
function_responses=function_responses,
|
||||
timezone=agent_state.timezone,
|
||||
timezone=self.agent_state.timezone,
|
||||
run_id=run_id,
|
||||
step_id=step_id,
|
||||
reasoning_content=content,
|
||||
@@ -1147,8 +1149,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
messages_to_persist,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
project_id=agent_state.project_id,
|
||||
template_id=agent_state.template_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
|
||||
# Update message_ids immediately after persistence to prevent desync
|
||||
@@ -1162,9 +1164,9 @@ class LettaAgentV3(LettaAgentV2):
|
||||
and persisted_messages[0].role == "approval"
|
||||
and persisted_messages[1].role == "tool"
|
||||
):
|
||||
agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages[:2]]
|
||||
self.agent_state.message_ids = self.agent_state.message_ids + [m.id for m in persisted_messages[:2]]
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=self.actor
|
||||
agent_id=self.agent_state.id, message_ids=self.agent_state.message_ids, actor=self.actor
|
||||
)
|
||||
|
||||
# 5g. Aggregate continuation decisions
|
||||
|
||||
Reference in New Issue
Block a user