fix: propagate error on tool failure (#2281)
Co-authored-by: Caren Thomas <caren@caren-mac.local>
This commit is contained in:
@@ -9,6 +9,7 @@ from typing import List, Optional, Tuple, Union
|
||||
from letta.constants import (
|
||||
BASE_TOOLS,
|
||||
CLI_WARNING_PREFIX,
|
||||
ERROR_MESSAGE_PREFIX,
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
LLM_MAX_TOKENS,
|
||||
@@ -498,7 +499,7 @@ class Agent(BaseAgent):
|
||||
function_args.pop("self", None)
|
||||
# error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}"
|
||||
# Less detailed - don't provide full args, idea is that it should be in recent context so no need (just adds noise)
|
||||
error_msg = f"Error calling function {function_name}: {str(e)}"
|
||||
error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
|
||||
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
|
||||
printd(error_msg_user)
|
||||
function_response = package_function_response(False, error_msg)
|
||||
@@ -521,8 +522,29 @@ class Agent(BaseAgent):
|
||||
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
# Step 4: check if function response is an error
|
||||
if function_response_string.startswith(ERROR_MESSAGE_PREFIX):
|
||||
function_response = package_function_response(False, function_response_string)
|
||||
# TODO: truncate error message somehow
|
||||
messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": function_response,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
)
|
||||
) # extend conversation with function response
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
|
||||
self.interface.function_message(f"Error: {function_response_string}", msg_obj=messages[-1])
|
||||
return messages, False, True # force a heartbeat to allow agent to handle error
|
||||
|
||||
# If no failures happened along the way: ...
|
||||
# Step 4: send the info on the function call and function response to GPT
|
||||
# Step 5: send the info on the function call and function response to GPT
|
||||
messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
|
||||
@@ -69,6 +69,8 @@ INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG = STARTUP_QUOTES[2]
|
||||
|
||||
CLI_WARNING_PREFIX = "Warning: "
|
||||
|
||||
ERROR_MESSAGE_PREFIX = "Error"
|
||||
|
||||
NON_USER_MSG_PREFIX = "[This is an automated system message hidden from the user] "
|
||||
|
||||
# Constants to do with summarization / conversation length window
|
||||
|
||||
@@ -238,7 +238,7 @@ class QueuingInterface(AgentInterface):
|
||||
new_message = {"function_return": msg, "status": "success"}
|
||||
|
||||
elif msg.startswith("Error: "):
|
||||
msg = msg.replace("Error: ", "")
|
||||
msg = msg.replace("Error: ", "", 1)
|
||||
new_message = {"function_return": msg, "status": "error"}
|
||||
|
||||
else:
|
||||
@@ -951,7 +951,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
)
|
||||
|
||||
elif msg.startswith("Error: "):
|
||||
msg = msg.replace("Error: ", "")
|
||||
msg = msg.replace("Error: ", "", 1)
|
||||
# new_message = {"function_return": msg, "status": "error"}
|
||||
assert msg_obj.tool_call_id is not None
|
||||
new_message = ToolReturnMessage(
|
||||
|
||||
@@ -28,6 +28,7 @@ from letta.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT,
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
ERROR_MESSAGE_PREFIX,
|
||||
LETTA_DIR,
|
||||
MAX_FILENAME_LENGTH,
|
||||
TOOL_CALL_ID_MAX_LEN,
|
||||
@@ -1122,7 +1123,7 @@ def sanitize_filename(filename: str) -> str:
|
||||
def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str):
|
||||
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
|
||||
|
||||
error_msg = f"Error executing function {function_name}: {exception_name}: {exception_message}"
|
||||
error_msg = f"{ERROR_MESSAGE_PREFIX} executing function {function_name}: {exception_name}: {exception_message}"
|
||||
if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
|
||||
error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
|
||||
return error_msg
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -382,6 +383,39 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
|
||||
|
||||
def test_function_always_error(client: Union[LocalClient, RESTClient]):
|
||||
"""Test to see if function that errors works correctly"""
|
||||
|
||||
def always_error():
|
||||
"""
|
||||
Always throw an error.
|
||||
"""
|
||||
return 5/0
|
||||
|
||||
tool = client.create_or_update_tool(func=always_error)
|
||||
agent = client.create_agent(tool_ids=[tool.id])
|
||||
# get function response
|
||||
response = client.send_message(agent_id=agent.id, message="call the always_error function", role="user")
|
||||
print(response.messages)
|
||||
|
||||
response_message = None
|
||||
for message in response.messages:
|
||||
if isinstance(message, FunctionReturn):
|
||||
response_message = message
|
||||
break
|
||||
|
||||
assert response_message, "FunctionReturn message not found in response"
|
||||
assert response_message.status == "error"
|
||||
if isinstance(client, RESTClient):
|
||||
assert response_message.function_return == "Error executing function always_error: ZeroDivisionError: division by zero"
|
||||
else:
|
||||
response_json = json.loads(response_message.function_return)
|
||||
assert response_json['status'] == "Failed"
|
||||
assert response_json['message'] == "Error executing function always_error: ZeroDivisionError: division by zero"
|
||||
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user