diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 7763d495..0a865539 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -8,7 +8,9 @@ from letta.errors import PendingApprovalError from letta.helpers import ToolRulesSolver from letta.log import get_logger from letta.schemas.agent import AgentState +from letta.schemas.enums import MessageRole from letta.schemas.letta_message import MessageType +from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message, MessageCreate, MessageCreateBase @@ -266,3 +268,17 @@ def _build_rule_violation_result(tool_name: str, valid: list[str], solver: ToolR hint_txt = ("\n** Hint: Possible rules that were violated:\n" + "\n".join(f"\t- {h}" for h in hint_lines)) if hint_lines else "" msg = f"[ToolConstraintError] Cannot call {tool_name}, valid tools include: {valid}.{hint_txt}" return ToolExecutionResult(status="error", func_return=msg) + + +def _load_last_function_response(in_context_messages: list[Message]): + """Load the last function response from message history""" + for msg in reversed(in_context_messages): + if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent): + text_content = msg.content[0].text + try: + response_json = json.loads(text_content) + if response_json.get("message"): + return response_json["message"] + except (json.JSONDecodeError, KeyError): + raise ValueError(f"Invalid JSON format in message: {text_content}") + return None diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 5093059e..0bd8fedc 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -13,6 +13,7 @@ from letta.agents.base_agent_v2 import BaseAgentV2 from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent from letta.agents.helpers import ( _build_rule_violation_result, + _load_last_function_response, _pop_heartbeat, _prepare_in_context_messages_no_persist_async, _safe_load_tool_call_str, @@ -374,7 +375,7 @@ class LettaAgentV2(BaseAgentV2): None, ) try: - self.last_function_response = self._load_last_function_response(messages) + self.last_function_response = _load_last_function_response(messages) valid_tools = await self._get_valid_tools() approval_request, approval_response = await self._maybe_get_approval_messages(messages) if approval_request and approval_response: @@ -740,20 +741,6 @@ class LettaAgentV2(BaseAgentV2): ) return allowed_tools - @trace_method - def _load_last_function_response(self, in_context_messages: list[Message]): - """Load the last function response from message history""" - for msg in reversed(in_context_messages): - if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent): - text_content = msg.content[0].text - try: - response_json = json.loads(text_content) - if response_json.get("message"): - return response_json["message"] - except (json.JSONDecodeError, KeyError): - raise ValueError(f"Invalid JSON format in message: {text_content}") - return None - @trace_method def _request_checkpoint_start(self, request_start_timestamp_ns: int | None) -> Span | None: if request_start_timestamp_ns is not None: