From d0c2ef89eaebade4684f7d782f7c54efc57ed39f Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 11 Jun 2025 17:12:39 -0700 Subject: [PATCH] feat: Add errors when tool call violates tool rules (#2766) --- letta/agents/letta_agent.py | 101 ++++++++++++++++++------------ letta/helpers/tool_rule_solver.py | 34 ++++++++++ letta/schemas/tool_rule.py | 46 +++++++++++++- tests/test_sdk_client.py | 45 ------------- tests/test_server.py | 12 ---- 5 files changed, 137 insertions(+), 101 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 7d5149c3..6d40704f 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -33,6 +33,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics from letta.schemas.provider_trace import ProviderTraceCreate +from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.server.rest_api.utils import create_letta_messages_from_llm_response @@ -170,12 +171,14 @@ class LettaAgent(BaseAgent): agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span.set_attributes({"step_id": step_id}) - request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm( - current_in_context_messages, - new_in_context_messages, - agent_state, - llm_client, - tool_rules_solver, + request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( + await self._build_and_request_from_llm( + current_in_context_messages, + new_in_context_messages, + agent_state, + llm_client, + tool_rules_solver, + ) ) in_context_messages = current_in_context_messages + new_in_context_messages @@ -220,6 +223,7 @@ class LettaAgent(BaseAgent): persisted_messages, should_continue = await self._handle_ai_response( tool_call, + valid_tool_names, agent_state, tool_rules_solver, response.usage, @@ -323,8 +327,10 @@ class LettaAgent(BaseAgent): agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span.set_attributes({"step_id": step_id}) - request_data, response_data, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm( - current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver + request_data, response_data, current_in_context_messages, new_in_context_messages, valid_tool_names = ( + await self._build_and_request_from_llm( + current_in_context_messages, new_in_context_messages, agent_state, llm_client, tool_rules_solver + ) ) in_context_messages = current_in_context_messages + new_in_context_messages @@ -363,6 +369,7 @@ class LettaAgent(BaseAgent): persisted_messages, should_continue = await self._handle_ai_response( tool_call, + valid_tool_names, agent_state, tool_rules_solver, response.usage, @@ -460,15 +467,17 @@ class LettaAgent(BaseAgent): agent_step_span = tracer.start_span("agent_step", start_time=step_start) agent_step_span.set_attributes({"step_id": step_id}) - request_data, stream, current_in_context_messages, new_in_context_messages = await self._build_and_request_from_llm_streaming( - first_chunk, - agent_step_span, - request_start_timestamp_ns, - current_in_context_messages, - new_in_context_messages, - agent_state, - llm_client, - tool_rules_solver, + request_data, stream, current_in_context_messages, new_in_context_messages, valid_tool_names = ( + await self._build_and_request_from_llm_streaming( + first_chunk, + agent_step_span, + request_start_timestamp_ns, + current_in_context_messages, + new_in_context_messages, + agent_state, + llm_client, + tool_rules_solver, + ) ) log_event("agent.stream.llm_response.received") # [3^] @@ -523,6 +532,7 @@ class LettaAgent(BaseAgent): reasoning_content = interface.get_reasoning_content() persisted_messages, should_continue = await self._handle_ai_response( tool_call, + valid_tool_names, agent_state, tool_rules_solver, UsageStatistics( @@ -611,12 +621,12 @@ class LettaAgent(BaseAgent): agent_state: AgentState, llm_client: LLMClientBase, tool_rules_solver: ToolRulesSolver, - ) -> Tuple[Dict, Dict, List[Message], List[Message]]: + ) -> Tuple[Dict, Dict, List[Message], List[Message], List[str]]: for attempt in range(self.max_summarization_retries + 1): try: log_event("agent.stream_no_tokens.messages.refreshed") # Create LLM request data - request_data = await self._create_llm_request_data_async( + request_data, valid_tool_names = await self._create_llm_request_data_async( llm_client=llm_client, in_context_messages=current_in_context_messages + new_in_context_messages, agent_state=agent_state, @@ -631,12 +641,7 @@ class LettaAgent(BaseAgent): dict(get_ctx_attributes(), **{"model.name": agent_state.llm_config.model}), ) # Attempt LLM request - return ( - request_data, - response, - current_in_context_messages, - new_in_context_messages, - ) + return (request_data, response, current_in_context_messages, new_in_context_messages, valid_tool_names) except Exception as e: if attempt == self.max_summarization_retries: @@ -664,12 +669,12 @@ class LettaAgent(BaseAgent): agent_state: AgentState, llm_client: LLMClientBase, tool_rules_solver: ToolRulesSolver, - ) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message]]: + ) -> Tuple[Dict, AsyncStream[ChatCompletionChunk], List[Message], List[Message], List[str]]: for attempt in range(self.max_summarization_retries + 1): try: log_event("agent.stream_no_tokens.messages.refreshed") # Create LLM request data - request_data = await self._create_llm_request_data_async( + request_data, valid_tool_names = await self._create_llm_request_data_async( llm_client=llm_client, in_context_messages=current_in_context_messages + new_in_context_messages, agent_state=agent_state, @@ -688,6 +693,7 @@ class LettaAgent(BaseAgent): await llm_client.stream_async(request_data, agent_state.llm_config), current_in_context_messages, new_in_context_messages, + valid_tool_names, ) except Exception as e: @@ -770,7 +776,7 @@ class LettaAgent(BaseAgent): in_context_messages: List[Message], agent_state: AgentState, tool_rules_solver: ToolRulesSolver, - ) -> dict: + ) -> Tuple[dict, List[str]]: self.num_messages, self.num_archival_memories = await asyncio.gather( ( self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) @@ -827,17 +833,21 @@ class LettaAgent(BaseAgent): tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True ) - return llm_client.build_request_data( - in_context_messages, - agent_state.llm_config, - allowed_tools, - force_tool_call, + return ( + llm_client.build_request_data( + in_context_messages, + agent_state.llm_config, + allowed_tools, + force_tool_call, + ), + valid_tool_names, ) @trace_method async def _handle_ai_response( self, tool_call: ToolCall, + valid_tool_names: List[str], agent_state: AgentState, tool_rules_solver: ToolRulesSolver, usage: UsageStatistics, @@ -853,8 +863,10 @@ class LettaAgent(BaseAgent): This might yield additional SSE tokens if we do stalling. At the end, set self._continue_execution accordingly. """ + # Check if the called tool is allowed by tool name: tool_call_name = tool_call.function.name tool_call_args_str = tool_call.function.arguments + # Temp hack to gracefully handle parallel tool calling attempt, only take first one if "}{" in tool_call_args_str: tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}" @@ -893,14 +905,21 @@ class LettaAgent(BaseAgent): tool_call_id=tool_call_id, request_heartbeat=request_heartbeat, ) - - tool_execution_result = await self._execute_tool( - tool_name=tool_call_name, - tool_args=tool_args, - agent_state=agent_state, - agent_step_span=agent_step_span, - step_id=step_id, - ) + if tool_call_name not in valid_tool_names: + base_error_message = f"[ToolConstraintError] Cannot call {tool_call_name}, valid tools to call include: {valid_tool_names}." + violated_rule_messages = tool_rules_solver.guess_rule_violation(tool_call_name) + if violated_rule_messages: + bullet_points = "\n".join(f"\t- {msg}" for msg in violated_rule_messages) + base_error_message += f"\n** Hint: Possible rules that were violated:\n{bullet_points}" + tool_execution_result = ToolExecutionResult(status="error", func_return=base_error_message) + else: + tool_execution_result = await self._execute_tool( + tool_name=tool_call_name, + tool_args=tool_args, + agent_state=agent_state, + agent_step_span=agent_step_span, + step_id=step_id, + ) log_telemetry( self.logger, "_handle_ai_response execute tool finish", tool_execution_result=tool_execution_result, tool_call_id=tool_call_id ) diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index caecd271..acd35c8a 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -172,6 +172,40 @@ class ToolRulesSolver(BaseModel): ) return None + def guess_rule_violation(self, tool_name: str) -> List[str]: + """ + Check if the given tool name or the previous tool in history matches any tool rule, + and return rendered prompt templates for matching rules. + + Args: + tool_name: The name of the tool to check for rule violations + + Returns: + List of rendered prompt templates from matching tool rules + """ + violated_rules = [] + + # Get the previous tool from history if it exists + previous_tool = self.tool_call_history[-1] if self.tool_call_history else None + + # Check all tool rules for matches + all_rules = ( + self.init_tool_rules + + self.continue_tool_rules + + self.child_based_tool_rules + + self.parent_tool_rules + + self.terminal_tool_rules + ) + + for rule in all_rules: + # Check if the current tool name or previous tool matches this rule's tool_name + if rule.tool_name == tool_name or (previous_tool and rule.tool_name == previous_tool): + rendered_prompt = rule.render_prompt() + if rendered_prompt: + violated_rules.append(rendered_prompt) + + return violated_rules + @staticmethod def validate_conditional_tool(rule: ConditionalToolRule): """ diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index 7efa07c0..94dc4978 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -25,19 +25,24 @@ class BaseToolRule(LettaBase): def render_prompt(self) -> Optional[str]: """Render the prompt template with this rule's attributes.""" - if not self.prompt_template: + template_to_use = self.prompt_template or self._get_default_template() + if not template_to_use: return None try: - template = Template(self.prompt_template) + template = Template(template_to_use) return template.render(**self.model_dump()) except Exception as e: logger.warning( f"Failed to render prompt template for tool rule '{self.tool_name}' (type: {self.type}). " - f"Template: '{self.prompt_template}'. Error: {e}" + f"Template: '{template_to_use}'. Error: {e}" ) return None + def _get_default_template(self) -> Optional[str]: + """Get the default template for this rule type. Override in subclasses.""" + return None + class ChildToolRule(BaseToolRule): """ @@ -46,11 +51,18 @@ class ChildToolRule(BaseToolRule): type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools children: List[str] = Field(..., description="The children tools that can be invoked.") + prompt_template: Optional[str] = Field( + default="After using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}", + description="Optional Jinja2 template for generating agent prompt about this tool rule.", + ) def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: last_tool = tool_call_history[-1] if tool_call_history else None return set(self.children) if last_tool == self.tool_name else available_tools + def _get_default_template(self) -> Optional[str]: + return "After using {{ tool_name }}, you can only use these tools: {{ children | join(', ') }}" + class ParentToolRule(BaseToolRule): """ @@ -59,11 +71,18 @@ class ParentToolRule(BaseToolRule): type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool children: List[str] = Field(..., description="The children tools that can be invoked.") + prompt_template: Optional[str] = Field( + default="{{ children | join(', ') }} can only be used after {{ tool_name }}", + description="Optional Jinja2 template for generating agent prompt about this tool rule.", + ) def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: last_tool = tool_call_history[-1] if tool_call_history else None return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children) + def _get_default_template(self) -> Optional[str]: + return "{{ children | join(', ') }} can only be used after {{ tool_name }}" + class ConditionalToolRule(BaseToolRule): """ @@ -74,6 +93,10 @@ class ConditionalToolRule(BaseToolRule): default_child: Optional[str] = Field(None, description="The default child tool to be called. If None, any tool can be called.") child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping") require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case") + prompt_template: Optional[str] = Field( + default="{{ tool_name }} will determine which tool to use next based on its output", + description="Optional Jinja2 template for generating agent prompt about this tool rule.", + ) def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: """Determine valid tools based on function output mapping.""" @@ -119,6 +142,9 @@ class ConditionalToolRule(BaseToolRule): else: # Assume string return str(function_output) == str(key) + def _get_default_template(self) -> Optional[str]: + return "{{ tool_name }} will determine which tool to use next based on its output" + class InitToolRule(BaseToolRule): """ @@ -134,6 +160,13 @@ class TerminalToolRule(BaseToolRule): """ type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop + prompt_template: Optional[str] = Field( + default="{{ tool_name }} ends the conversation when called", + description="Optional Jinja2 template for generating agent prompt about this tool rule.", + ) + + def _get_default_template(self) -> Optional[str]: + return "{{ tool_name }} ends the conversation when called" class ContinueToolRule(BaseToolRule): @@ -142,6 +175,10 @@ class ContinueToolRule(BaseToolRule): """ type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop + prompt_template: Optional[str] = Field( + default="{{ tool_name }} requires continuing the conversation when called", + description="Optional Jinja2 template for generating agent prompt about this tool rule.", + ) class MaxCountPerStepToolRule(BaseToolRule): @@ -166,6 +203,9 @@ class MaxCountPerStepToolRule(BaseToolRule): return available_tools + def _get_default_template(self) -> Optional[str]: + return "{{ tool_name }}: max {{ max_count_limit }} use(s) per turn" + ToolRule = Annotated[ Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule, ParentToolRule], diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 943e779b..916f1863 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -520,51 +520,6 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState): # assert len(responses) == len(messages), "Not all messages were processed" -def test_send_message_async(client: LettaSDKClient, agent: AgentState): - """ - Test that we can send a message asynchronously and retrieve the messages, along with usage statistics - """ - test_message = "This is a test message, respond to the user with a sentence." - run = client.agents.messages.create_async( - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content=test_message, - ), - ], - use_assistant_message=False, - ) - assert run.id is not None - assert run.status == "created" - - # Wait for the job to complete, cancel it if takes over 10 seconds - start_time = time.time() - while run.status == "created": - time.sleep(1) - run = client.runs.retrieve(run_id=run.id) - print(f"Run status: {run.status}") - if time.time() - start_time > 10: - pytest.fail("Run took too long to complete") - - print(f"Run completed in {time.time() - start_time} seconds, run={run}") - assert run.status == "completed" - - # Get messages for the job - messages = client.runs.messages.list(run_id=run.id) - assert len(messages) >= 2 # At least assistant response - - # Check filters - assistant_messages = client.runs.messages.list(run_id=run.id, role="assistant") - assert len(assistant_messages) > 0 - tool_messages = client.runs.messages.list(run_id=run.id, role="tool") - assert len(tool_messages) > 0 - - # specific_tool_messages = [message for message in client.runs.list_run_messages(run_id=run.id) if isinstance(message, ToolCallMessage)] - # assert specific_tool_messages[0].tool_call.name == "send_message" - # assert len(specific_tool_messages) > 0 - - def test_agent_creation(client: LettaSDKClient): """Test that block IDs are properly attached when creating an agent.""" sleeptime_agent_system = """ diff --git a/tests/test_server.py b/tests/test_server.py index 4f6fdc25..6a13f672 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -901,18 +901,6 @@ async def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, u assert not result.stderr -def test_composio_client_simple(server): - apps = server.get_composio_apps() - # Assert there's some amount of apps returned - assert len(apps) > 0 - - app = apps[0] - actions = server.get_composio_actions_from_app_name(composio_app_name=app.name) - - # Assert there's some amount of actions - assert len(actions) > 0 - - async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tools, base_memory_tools): """Test that the memory rebuild is generating the correct number of role=system messages""" actor = user