feat: Add get valid tools helper pure function [LET-4439] (#4771)

Add get valid tools helper function
This commit is contained in:
Matthew Zhou
2025-09-18 12:16:53 -07:00
committed by Caren Thomas
parent 2c89b24021
commit a3925e6a7b
2 changed files with 18 additions and 15 deletions

View File

@@ -8,7 +8,9 @@ from letta.errors import PendingApprovalError
from letta.helpers import ToolRulesSolver
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate, MessageCreateBase
@@ -266,3 +268,17 @@ def _build_rule_violation_result(tool_name: str, valid: list[str], solver: ToolR
hint_txt = ("\n** Hint: Possible rules that were violated:\n" + "\n".join(f"\t- {h}" for h in hint_lines)) if hint_lines else ""
msg = f"[ToolConstraintError] Cannot call {tool_name}, valid tools include: {valid}.{hint_txt}"
return ToolExecutionResult(status="error", func_return=msg)
def _load_last_function_response(in_context_messages: list[Message]):
"""Load the last function response from message history"""
for msg in reversed(in_context_messages):
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
text_content = msg.content[0].text
try:
response_json = json.loads(text_content)
if response_json.get("message"):
return response_json["message"]
except (json.JSONDecodeError, KeyError):
raise ValueError(f"Invalid JSON format in message: {text_content}")
return None

View File

@@ -13,6 +13,7 @@ from letta.agents.base_agent_v2 import BaseAgentV2
from letta.agents.ephemeral_summary_agent import EphemeralSummaryAgent
from letta.agents.helpers import (
_build_rule_violation_result,
_load_last_function_response,
_pop_heartbeat,
_prepare_in_context_messages_no_persist_async,
_safe_load_tool_call_str,
@@ -374,7 +375,7 @@ class LettaAgentV2(BaseAgentV2):
None,
)
try:
self.last_function_response = self._load_last_function_response(messages)
self.last_function_response = _load_last_function_response(messages)
valid_tools = await self._get_valid_tools()
approval_request, approval_response = await self._maybe_get_approval_messages(messages)
if approval_request and approval_response:
@@ -740,20 +741,6 @@ class LettaAgentV2(BaseAgentV2):
)
return allowed_tools
@trace_method
def _load_last_function_response(self, in_context_messages: list[Message]):
"""Load the last function response from message history"""
for msg in reversed(in_context_messages):
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
text_content = msg.content[0].text
try:
response_json = json.loads(text_content)
if response_json.get("message"):
return response_json["message"]
except (json.JSONDecodeError, KeyError):
raise ValueError(f"Invalid JSON format in message: {text_content}")
return None
@trace_method
def _request_checkpoint_start(self, request_start_timestamp_ns: int | None) -> Span | None:
if request_start_timestamp_ns is not None: