From c3e425c012fbd4961fd03566652135d1e05a7a03 Mon Sep 17 00:00:00 2001 From: mlong93 <35275280+mlong93@users.noreply.github.com> Date: Sat, 1 Feb 2025 16:41:37 -0800 Subject: [PATCH] fix: failed tool calls will not be called in the subsequent step (#868) Co-authored-by: Mindy Long Co-authored-by: Shubham Naik --- letta/agent.py | 146 +++++++++++---------- tests/integration_test_agent_tool_graph.py | 61 +++++++++ tests/test_client.py | 7 +- 3 files changed, 144 insertions(+), 70 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 8f1760f8..14815a50 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -38,6 +38,7 @@ from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.sandbox_config import SandboxRunResult from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics @@ -198,7 +199,9 @@ class Agent(BaseAgent): return True return False - def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool): + def execute_tool_and_persist_state( + self, function_name: str, function_args: dict, target_letta_tool: Tool + ) -> tuple[str, Optional[SandboxRunResult]]: """ Execute tool modifications and persist the state of the agent. Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data @@ -242,6 +245,7 @@ class Agent(BaseAgent): assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" if updated_agent_state is not None: self.update_memory_if_changed(updated_agent_state.memory) + return function_response, sandbox_run_result except Exception as e: # Need to catch error here, or else trunction wont happen # TODO: modify to function execution error @@ -249,7 +253,44 @@ class Agent(BaseAgent): function_name=function_name, exception_name=type(e).__name__, exception_message=str(e) ) - return function_response + return function_response, None + + def _handle_function_error_response( + self, + error_msg: str, + tool_call_id: str, + function_name: str, + function_response: str, + messages: List[Message], + include_function_failed_message: bool = False, + ) -> List[Message]: + """ + Handle error from function call response + """ + # Update tool rules + self.last_function_response = function_response + self.tool_rules_solver.update_tool_usage(function_name) + + # Extend conversation with function response + function_response = package_function_response(False, error_msg) + new_message = Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.created_by_id, + model=self.model, + openai_message_dict={ + "role": "tool", + "name": function_name, + "content": function_response, + "tool_call_id": tool_call_id, + }, + ) + messages.append(new_message) + self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message) + if include_function_failed_message: + self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message) + + # Return updated messages + return messages def _get_ai_reply( self, @@ -261,6 +302,7 @@ class Agent(BaseAgent): backoff_factor: float = 0.5, # delay multiplier for exponential backoff max_delay: float = 10.0, # max delay between retries step_count: Optional[int] = None, + last_function_failed: bool = False, ) -> ChatCompletionResponse: """Get response from LLM API with robust retry mechanism.""" @@ -273,6 +315,12 @@ class Agent(BaseAgent): else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names] ) + # Don't allow a tool to be called if it failed last time + if last_function_failed and self.tool_rules_solver.last_tool_name: + allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.last_tool_name] + if not allowed_functions: + return None + # For the first message, force the initial tool if one is specified force_tool_call = None if ( @@ -285,6 +333,7 @@ class Agent(BaseAgent): # Force a tool call if exactly one tool is specified elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1: force_tool_call = allowed_tool_names[0] + for attempt in range(1, empty_response_retry_limit + 1): try: response = create( @@ -409,21 +458,7 @@ class Agent(BaseAgent): if not target_letta_tool: error_msg = f"No function named {function_name}" - function_response = package_function_response(False, error_msg) - messages.append( - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={ - "role": "tool", - "name": function_name, - "content": function_response, - "tool_call_id": tool_call_id, - }, - ) - ) # extend conversation with function response - self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) + messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages) return messages, False, True # force a heartbeat to allow agent to handle error # Failure case 2: function name is OK, but function args are bad JSON @@ -432,21 +467,7 @@ class Agent(BaseAgent): function_args = parse_json(raw_function_args) except Exception: error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}" - function_response = package_function_response(False, error_msg) - messages.append( - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={ - "role": "tool", - "name": function_name, - "content": function_response, - "tool_call_id": tool_call_id, - }, - ) - ) # extend conversation with function response - self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) + messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages) return messages, False, True # force a heartbeat to allow agent to handle error # Check if inner thoughts is in the function call arguments (possible apparently if you are using Azure) @@ -479,7 +500,12 @@ class Agent(BaseAgent): self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1]) try: # handle tool execution (sandbox) and state updates - function_response = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) + function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) + + if sandbox_run_result and sandbox_run_result.status == "error": + error_msg = f"Error calling function {function_name} with args {function_args}: {sandbox_run_result.stderr}" + messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages) + return messages, False, True # force a heartbeat to allow agent to handle error # handle trunction if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: @@ -505,45 +531,16 @@ class Agent(BaseAgent): error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)) error_msg_user = f"{error_msg}\n{traceback.format_exc()}" self.logger.error(error_msg_user) - function_response = package_function_response(False, error_msg) - self.last_function_response = function_response - # TODO: truncate error message somehow - messages.append( - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={ - "role": "tool", - "name": function_name, - "content": function_response, - "tool_call_id": tool_call_id, - }, - ) - ) # extend conversation with function response - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) - self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) + messages = self._handle_function_error_response( + error_msg, tool_call_id, function_name, function_response, messages, include_function_failed_message=True + ) return messages, False, True # force a heartbeat to allow agent to handle error # Step 4: check if function response is an error if function_response_string.startswith(ERROR_MESSAGE_PREFIX): - function_response = package_function_response(False, function_response_string) - # TODO: truncate error message somehow - messages.append( - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={ - "role": "tool", - "name": function_name, - "content": function_response, - "tool_call_id": tool_call_id, - }, - ) - ) # extend conversation with function response - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) - self.interface.function_message(f"Error: {function_response_string}", msg_obj=messages[-1]) + messages = self._handle_function_error_response( + error_msg, tool_call_id, function_name, function_response, messages, include_function_failed_message=True + ) return messages, False, True # force a heartbeat to allow agent to handle error # If no failures happened along the way: ... @@ -607,9 +604,11 @@ class Agent(BaseAgent): counter = 0 total_usage = UsageStatistics() step_count = 0 + function_failed = False while True: kwargs["first_message"] = False kwargs["step_count"] = step_count + kwargs["last_function_failed"] = function_failed step_response = self.inner_step( messages=next_input_message, **kwargs, @@ -689,6 +688,7 @@ class Agent(BaseAgent): step_count: Optional[int] = None, metadata: Optional[dict] = None, summarize_attempt_count: int = 0, + last_function_failed: bool = False, ) -> AgentStepResponse: """Runs a single step in the agent loop (generates at most one LLM call)""" @@ -723,7 +723,17 @@ class Agent(BaseAgent): first_message=first_message, stream=stream, step_count=step_count, + last_function_failed=last_function_failed, ) + if not response: + # EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early + return AgentStepResponse( + messages=input_message_sequence, + heartbeat_request=False, + function_failed=False, # NOTE: this is different from other function fails. We force to return early + in_context_memory_warning=False, + usage=UsageStatistics(), + ) # Step 3: check if LLM wanted to call a function # (if yes) Step 4: call the function diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 3a24e29b..025f751b 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -659,3 +659,64 @@ def test_simple_tool_rule(mock_e2b_api_key_none): assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin" cleanup(client, agent_uuid=agent_state.id) + + +def test_init_tool_rule_always_fails_one_tool(): + """ + Test an init tool rule that always fails when called. The agent has only one tool available. + + Once that tool fails and the agent removes that tool, the agent should have 0 tools available. + + This means that the agent should return from `step` early. + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + # Create tools + bad_tool = client.create_or_update_tool(auto_error) + + # Create tool rule: InitToolRule + tool_rule = InitToolRule( + tool_name=bad_tool.name, + ) + + # Set up agent with the tool rule + claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=False) + + # Start conversation + response = client.user_message(agent_id=agent_state.id, message="blah blah blah") + + # Verify the tool calls + tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] + assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls + assert_invoked_function_call(response.messages, bad_tool.name) + + +def test_init_tool_rule_always_fails_multiple_tools(): + """ + Test an init tool rule that always fails when called. The agent has only 1+ tools available. + Once that tool fails and the agent removes that tool, the agent should have other tools available. + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + # Create tools + bad_tool = client.create_or_update_tool(auto_error) + + # Create tool rule: InitToolRule + tool_rule = InitToolRule( + tool_name=bad_tool.name, + ) + + # Set up agent with the tool rule + claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=True) + + # Start conversation + response = client.user_message(agent_id=agent_state.id, message="blah blah blah") + + # Verify the tool calls + tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] + assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls + assert_invoked_function_call(response.messages, bad_tool.name) diff --git a/tests/test_client.py b/tests/test_client.py index 721a293f..c9cfae4a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -480,11 +480,14 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]): assert response_message.status == "error" if isinstance(client, RESTClient): - assert response_message.tool_return == "Error executing function always_error: ZeroDivisionError: division by zero" + assert ( + response_message.tool_return.startswith("Error calling function always_error") + and "ZeroDivisionError" in response_message.tool_return + ) else: response_json = json.loads(response_message.tool_return) assert response_json["status"] == "Failed" - assert response_json["message"] == "Error executing function always_error: ZeroDivisionError: division by zero" + assert "Error calling function always_error" in response_json["message"] and "ZeroDivisionError" in response_json["message"] client.delete_agent(agent_id=agent.id)