feat: Add errors when tool call violates tool rules (#2766)

This commit is contained in:
Matthew Zhou
2025-06-11 17:12:39 -07:00
committed by GitHub
parent 03f4867cbe
commit d0c2ef89ea
5 changed files with 137 additions and 101 deletions

View File

@@ -33,6 +33,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
@@ -170,12 +171,14 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm(
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
)
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -220,6 +223,7 @@ class LettaAgent(BaseAgent):
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
@@ -323,8 +327,10 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm(
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm(
current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver
)
)
in_context_messages = current_in_context_messages + new_in_context_messages
@@ -363,6 +369,7 @@ class LettaAgent(BaseAgent):
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
response.usage,
@@ -460,15 +467,17 @@ class LettaAgent(BaseAgent):
agent_step_span = tracer.start_span("agent_step", start_time=step_start)
agent_step_span.set_attributes({"step_id": step_id})
request_data, stream, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm_streaming(
first_chunk,
agent_step_span,
request_start_timestamp_ns,
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
request_data, stream, current_in_context_messages, new_in_context_messages, valid_tool_names = (
await self._build_and_request_from_llm_streaming(
first_chunk,
agent_step_span,
request_start_timestamp_ns,
current_in_context_messages,
new_in_context_messages,
agent_state,
llm_client,
tool_rules_solver,
)
)
log_event("agent.stream.llm_response.received") # [3^]
@@ -523,6 +532,7 @@ class LettaAgent(BaseAgent):
reasoning_content = interface.get_reasoning_content()
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
valid_tool_names,
agent_state,
tool_rules_solver,
UsageStatistics(
@@ -611,12 +621,12 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
) -> Tuple[Dict, Dict, List[Message], List[Message]]:
) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]]:
for attempt in range(self.max_summarization_retries + 1):
try:
log_event("agent.stream_no_tokens.messages.refreshed")
# Create LLM request data
request_data = await self._create_llm_request_data_async(
request_data, valid_tool_names = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
@@ -631,12 +641,7 @@ class LettaAgent(BaseAgent):
dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}),
)
# Attempt LLM request
return (
request_data,
response,
current_in_context_messages,
new_in_context_messages,
)
return (request_data, response, current_in_context_messages, new_in_context_messages, valid_tool_names)
except Exception as e:
if attempt == self.max_summarization_retries:
@@ -664,12 +669,12 @@ class LettaAgent(BaseAgent):
agent_state: AgentState,
llm_client: LLMClientBase,
tool_rules_solver: ToolRulesSolver,
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message]]:
) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str]]:
for attempt in range(self.max_summarization_retries + 1):
try:
log_event("agent.stream_no_tokens.messages.refreshed")
# Create LLM request data
request_data = await self._create_llm_request_data_async(
request_data, valid_tool_names = await self._create_llm_request_data_async(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
@@ -688,6 +693,7 @@ class LettaAgent(BaseAgent):
await llm_client.stream_async(request_data, agent_state.llm_config),
current_in_context_messages,
new_in_context_messages,
valid_tool_names,
)
except Exception as e:
@@ -770,7 +776,7 @@ class LettaAgent(BaseAgent):
in_context_messages: List[Message],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
) -> dict:
) -> Tuple[dict, List[str]]:
self.num_messages, self.num_archival_memories = await asyncio.gather(
(
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
@@ -827,17 +833,21 @@ class LettaAgent(BaseAgent):
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
)
return llm_client.build_request_data(
in_context_messages,
agent_state.llm_config,
allowed_tools,
force_tool_call,
return (
llm_client.build_request_data(
in_context_messages,
agent_state.llm_config,
allowed_tools,
force_tool_call,
),
valid_tool_names,
)
@trace_method
async def _handle_ai_response(
self,
tool_call: ToolCall,
valid_tool_names: List[str],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
usage: UsageStatistics,
@@ -853,8 +863,10 @@ class LettaAgent(BaseAgent):
This might yield additional SSE tokens if we do stalling.
At the end, set self._continue_execution accordingly.
"""
# Check if the called tool is allowed by tool name:
tool_call_name = tool_call.function.name
tool_call_args_str = tool_call.function.arguments
# Temp hack to gracefully handle parallel tool calling attempt, only take first one
if "}{" in tool_call_args_str:
tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}"
@@ -893,14 +905,21 @@ class LettaAgent(BaseAgent):
tool_call_id=tool_call_id,
request_heartbeat=request_heartbeat,
)
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
tool_args=tool_args,
agent_state=agent_state,
agent_step_span=agent_step_span,
step_id=step_id,
)
if tool_call_name not in valid_tool_names:
base_error_message = f"[ToolConstraintError] Cannot call {tool_call_name}, valid tools to call include: {valid_tool_names}."
violated_rule_messages = tool_rules_solver.guess_rule_violation(tool_call_name)
if violated_rule_messages:
bullet_points = "\n".join(f"\t- {msg}" for msg in violated_rule_messages)
base_error_message += f"\n** Hint: Possible rules that were violated:\n{bullet_points}"
tool_execution_result = ToolExecutionResult(status="error", func_return=base_error_message)
else:
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
tool_args=tool_args,
agent_state=agent_state,
agent_step_span=agent_step_span,
step_id=step_id,
)
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id
)

View File

@@ -172,6 +172,40 @@ class ToolRulesSolver(BaseModel):
)
return None
def guess_rule_violation(self, tool_name: str) -> List[str]:
    """Collect rendered prompts from every rule anchored on the called tool or its predecessor.

    Args:
        tool_name: Name of the tool whose call may have violated a rule.

    Returns:
        Rendered prompt strings for each matching rule (empty list when nothing matches).
    """
    # A rule may be keyed on the tool just called OR on the tool called immediately before it.
    last_called = self.tool_call_history[-1] if self.tool_call_history else None

    hints: List[str] = []
    for rule_group in (
        self.init_tool_rules,
        self.continue_tool_rules,
        self.child_based_tool_rules,
        self.parent_tool_rules,
        self.terminal_tool_rules,
    ):
        for rule in rule_group:
            matches_current = rule.tool_name == tool_name
            matches_previous = bool(last_called) and rule.tool_name == last_called
            if matches_current or matches_previous:
                rendered = rule.render_prompt()
                if rendered:
                    hints.append(rendered)
    return hints
@staticmethod
def validate_conditional_tool(rule: ConditionalToolRule):
"""

View File

@@ -25,19 +25,24 @@ class BaseToolRule(LettaBase):
def render_prompt(self) -> Optional[str]:
    """Render the rule's prompt template (explicit or type default) with this rule's attributes.

    Returns:
        The rendered prompt string, or None when no template exists or rendering fails.
    """
    # Fall back to the per-rule-type default when no explicit template was provided.
    template_to_use = self.prompt_template or self._get_default_template()
    if not template_to_use:
        return None
    try:
        template = Template(template_to_use)
        return template.render(**self.model_dump())
    except Exception as e:
        # Rendering errors are non-fatal: warn and degrade to "no prompt".
        logger.warning(
            f"Failed to render prompt template for tool rule '{self.tool_name}' (type: {self.type}). "
            f"Template: '{template_to_use}'. Error: {e}"
        )
        return None

def _get_default_template(self) -> Optional[str]:
    """Get the default template for this rule type. Override in subclasses."""
    return None
class ChildToolRule(BaseToolRule):
"""
@@ -46,11 +51,18 @@ class ChildToolRule(BaseToolRule):
type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools
children: List[str] = Field(..., description="The children tools that can be invoked.")
prompt_template: Optional[str] = Field(
default="<tool_constraint>After using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}</tool_constraint>",
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
)
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
last_tool = tool_call_history[-1] if tool_call_history else None
return set(self.children) if last_tool == self.tool_name else available_tools
def _get_default_template(self) -> Optional[str]:
return "<tool_constraint>After using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}</tool_constraint>"
class ParentToolRule(BaseToolRule):
"""
@@ -59,11 +71,18 @@ class ParentToolRule(BaseToolRule):
type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool
children: List[str] = Field(..., description="The children tools that can be invoked.")
prompt_template: Optional[str] = Field(
default="<tool_constraint>{{ children | join(', ') }} can only be used after {{ tool_name }}</tool_constraint>",
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
)
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
last_tool = tool_call_history[-1] if tool_call_history else None
return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children)
def _get_default_template(self) -> Optional[str]:
return "<tool_constraint>{{ children | join(', ') }} can only be used after {{ tool_name }}</tool_constraint>"
class ConditionalToolRule(BaseToolRule):
"""
@@ -74,6 +93,10 @@ class ConditionalToolRule(BaseToolRule):
default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.")
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")
prompt_template: Optional[str] = Field(
default="<tool_constraint>{{ tool_name }} will determine which tool to use next based on its output</tool_constraint>",
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
)
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
"""Determine valid tools based on function output mapping."""
@@ -119,6 +142,9 @@ class ConditionalToolRule(BaseToolRule):
else: # Assume string
return str(function_output) == str(key)
def _get_default_template(self) -> Optional[str]:
    """Default prompt telling the agent this rule picks the next tool from this tool's output."""
    return "<tool_constraint>{{ tool_name }} will determine which tool to use next based on its output</tool_constraint>"
class InitToolRule(BaseToolRule):
"""
@@ -134,6 +160,13 @@ class TerminalToolRule(BaseToolRule):
"""
type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop
prompt_template: Optional[str] = Field(
default="<tool_constraint>{{ tool_name }} ends the conversation when called</tool_constraint>",
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
)
def _get_default_template(self) -> Optional[str]:
    """Default prompt telling the agent this tool ends the agent loop when called."""
    return "<tool_constraint>{{ tool_name }} ends the conversation when called</tool_constraint>"
class ContinueToolRule(BaseToolRule):
@@ -142,6 +175,10 @@ class ContinueToolRule(BaseToolRule):
"""
type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop
prompt_template: Optional[str] = Field(
default="<tool_constraint>{{ tool_name }} requires continuing the conversation when called</tool_constraint>",
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
)
class MaxCountPerStepToolRule(BaseToolRule):
@@ -166,6 +203,9 @@ class MaxCountPerStepToolRule(BaseToolRule):
return available_tools
def _get_default_template(self) -> Optional[str]:
    """Default prompt stating the per-turn usage cap for this tool."""
    return "<tool_constraint>{{ tool_name }}: max {{ max_count_limit }} use(s) per turn</tool_constraint>"
ToolRule = Annotated[
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule],

View File

@@ -520,51 +520,6 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState):
# assert len(responses) == len(messages), "Not all messages were processed"
def test_send_message_async(client: LettaSDKClient, agent: AgentState):
    """Send a message via the async endpoint, wait for the run to finish, then verify its messages."""
    test_message = "This is a test message, respond to the user with a sentence."
    run = client.agents.messages.create_async(
        agent_id=agent.id,
        messages=[MessageCreate(role="user", content=test_message)],
        use_assistant_message=False,
    )
    assert run.id is not None
    assert run.status == "created"

    # Poll the run until it leaves the "created" state; give up after 10 seconds.
    start_time = time.time()
    deadline = start_time + 10
    while run.status == "created":
        time.sleep(1)
        run = client.runs.retrieve(run_id=run.id)
        print(f"Run status: {run.status}")
        if time.time() > deadline:
            pytest.fail("Run took too long to complete")
    print(f"Run completed in {time.time() - start_time} seconds, run={run}")
    assert run.status == "completed"

    # The completed run should have produced at least a tool call and an assistant reply.
    run_messages = client.runs.messages.list(run_id=run.id)
    assert len(run_messages) >= 2  # At least assistant response

    # Role filters should each return at least one message.
    assistant_messages = client.runs.messages.list(run_id=run.id, role="assistant")
    assert len(assistant_messages) > 0
    tool_messages = client.runs.messages.list(run_id=run.id, role="tool")
    assert len(tool_messages) > 0

    # specific_tool_messages = [message for message in client.runs.list_run_messages(run_id=run.id) if isinstance(message, ToolCallMessage)]
    # assert specific_tool_messages[0].tool_call.name == "send_message"
    # assert len(specific_tool_messages) > 0
def test_agent_creation(client: LettaSDKClient):
"""Test that block IDs are properly attached when creating an agent."""
sleeptime_agent_system = """

View File

@@ -901,18 +901,6 @@ async def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, u
assert not result.stderr
def test_composio_client_simple(server):
    """Smoke test: the server lists Composio apps and the first app exposes actions."""
    available_apps = server.get_composio_apps()
    # There should be at least one registered app.
    assert len(available_apps) > 0

    first_app = available_apps[0]
    app_actions = server.get_composio_actions_from_app_name(composio_app_name=first_app.name)
    # The first app should expose at least one action.
    assert len(app_actions) > 0
async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = user