From 6fb2968006500527ba25d2fb0b78c2c1c09e760b Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 17 Dec 2024 15:44:41 -0800 Subject: [PATCH] fix: refactor sandbox run logic to add status field (#2248) Co-authored-by: Caren Thomas --- letta/agent.py | 10 ++- letta/schemas/sandbox_config.py | 3 +- letta/server/server.py | 45 +++---------- letta/services/tool_execution_sandbox.py | 86 ++++++++++++++---------- letta/utils.py | 8 +++ 5 files changed, 74 insertions(+), 78 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index a7448ac4..689f9eb0 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -64,6 +64,7 @@ from letta.system import ( ) from letta.utils import ( count_tokens, + get_friendly_error_msg, get_local_time, get_tool_call_id, get_utc_time, @@ -456,12 +457,9 @@ class Agent(BaseAgent): except Exception as e: # Need to catch error here, or else trunction wont happen # TODO: modify to function execution error - from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT - - error_msg = f"Error executing tool {function_name}: {e}" - if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT: - error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT] - raise ValueError(error_msg) + function_response = get_friendly_error_msg( + function_name=function_name, exception_name=type(e).__name__, exception_message=str(e) + ) return function_response diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index 9b118cf6..f86233fa 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -1,7 +1,7 @@ import hashlib import json from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, model_validator @@ -21,6 +21,7 @@ class SandboxRunResult(BaseModel): agent_state: Optional[AgentState] = Field(None, description="The agent state") stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the function invocation") stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation") + status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object") sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox") diff --git a/letta/server/server.py b/letta/server/server.py index 24d70ef3..fa9ca8ed 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -74,7 +74,7 @@ from letta.services.source_manager import SourceManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager -from letta.utils import get_utc_time, json_dumps, json_loads +from letta.utils import get_friendly_error_msg, get_utc_time, json_dumps, json_loads logger = get_logger(__name__) @@ -1395,55 +1395,30 @@ class SyncServer(Server): # Next, attempt to run the tool with the sandbox try: sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state) - function_response = str(sandbox_run_result.func_return) - stdout = [s for s in sandbox_run_result.stdout if s.strip()] - stderr = [s for s in sandbox_run_result.stderr if s.strip()] - - # expected error - if stderr: - error_msg = self.get_error_msg_for_func_return(tool.name, stderr[-1]) - return FunctionReturn( - id="null", - function_call_id="null", - date=get_utc_time(), - status="error", - function_return=error_msg, - stdout=stdout, - stderr=stderr, - ) - return FunctionReturn( id="null", function_call_id="null", date=get_utc_time(), - status="success", - function_return=function_response, - stdout=stdout, - stderr=stderr, + status=sandbox_run_result.status, + function_return=str(sandbox_run_result.func_return), + stdout=sandbox_run_result.stdout, + stderr=sandbox_run_result.stderr, ) - # unexpected error TODO(@cthomas): consolidate error handling except Exception as e: - error_msg = self.get_error_msg_for_func_return(tool.name, e) + func_return = get_friendly_error_msg( + function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e) + ) return FunctionReturn( id="null", function_call_id="null", date=get_utc_time(), status="error", - function_return=error_msg, - stdout=[""], + function_return=func_return, + stdout=[], stderr=[traceback.format_exc()], ) - def get_error_msg_for_func_return(self, tool_name, exception_message): - # same as agent.py - from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT - - error_msg = f"Error executing tool {tool_name}: {exception_message}" - if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT: - error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT] - return error_msg - # Composio wrappers def get_composio_client(self, api_key: Optional[str] = None): if api_key: diff --git a/letta/services/tool_execution_sandbox.py b/letta/services/tool_execution_sandbox.py index 3aac64b5..b6004c3c 100644 --- a/letta/services/tool_execution_sandbox.py +++ b/letta/services/tool_execution_sandbox.py @@ -10,7 +10,7 @@ import tempfile import traceback import uuid import venv -from typing import Any, Dict, Optional, TextIO +from typing import Any, Dict, Optional from letta.log import get_logger from letta.schemas.agent import AgentState @@ -20,6 +20,7 @@ from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import tool_settings +from letta.utils import get_friendly_error_msg logger = get_logger(__name__) @@ -79,11 +80,11 @@ class ToolExecutionSandbox: logger.debug(f"Using local sandbox to execute {self.tool_name}") result = self.run_local_dir_sandbox(agent_state=agent_state) - # Log out any stdout from the tool run - logger.debug(f"Executed tool '{self.tool_name}', logging stdout from tool run: \n") - for log_line in result.stdout: + # Log out any stdout/stderr from the tool run + logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n") + for log_line in (result.stdout or []) + (result.stderr or []): logger.debug(f"{log_line}") - logger.debug(f"Ending stdout log from tool run.") + logger.debug(f"Ending output log from tool run.") # Return result return result @@ -126,30 +127,24 @@ class ToolExecutionSandbox: temp_file.flush() temp_file_path = temp_file.name - # Save the old stdout - old_stdout = sys.stdout - old_stderr = sys.stderr try: if local_configs.use_venv: return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path) else: - return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path, old_stdout, old_stderr) + return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path) except Exception as e: logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}") logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}") raise e finally: - # Clean up the temp file and restore stdout - sys.stdout = old_stdout - sys.stderr = old_stderr + # Clean up the temp file os.remove(temp_file_path) def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult: local_configs = sbx_config.get_local_config() venv_path = os.path.join(local_configs.sandbox_dir, local_configs.venv_name) - # Safety checks for the venv - # Verify that the venv path exists and is a directory + # Safety checks for the venv: verify that the venv path exists and is a directory if not os.path.isdir(venv_path): logger.warning(f"Virtual environment directory does not exist at: {venv_path}, creating one now...") self.create_venv_for_local_sandbox(sandbox_dir_path=local_configs.sandbox_dir, venv_path=venv_path, env=env) @@ -180,27 +175,42 @@ class ToolExecutionSandbox: return SandboxRunResult( func_return=func_return, agent_state=agent_state, - stdout=[stdout], - stderr=[result.stderr], + stdout=[stdout] if stdout else [], + stderr=[result.stderr] if result.stderr else [], + status="success", + sandbox_config_fingerprint=sbx_config.fingerprint(), + ) + + except subprocess.CalledProcessError as e: + logger.error(f"Executing tool {self.tool_name} has process error: {e}") + func_return = get_friendly_error_msg( + function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e), + ) + return SandboxRunResult( + func_return=func_return, + agent_state=None, + stdout=[e.stdout] if e.stdout else [], + stderr=[e.stderr] if e.stderr else [], + status="error", sandbox_config_fingerprint=sbx_config.fingerprint(), ) except subprocess.TimeoutExpired: raise TimeoutError(f"Executing tool {self.tool_name} has timed out.") - except subprocess.CalledProcessError as e: - logger.error(f"Executing tool {self.tool_name} has process error: {e}") - raise e + except Exception as e: logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}") raise e - def run_local_dir_sandbox_runpy( - self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str, old_stdout: TextIO, old_stderr: TextIO + self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str ) -> SandboxRunResult: - func_return, agent_state, error_msg = None, None, None + status = "success" + agent_state, stderr = None, None # Redirect stdout and stderr to capture script output + old_stdout = sys.stdout + old_stderr = sys.stderr captured_stdout, captured_stderr = io.StringIO(), io.StringIO() sys.stdout = captured_stdout sys.stderr = captured_stderr @@ -215,21 +225,24 @@ class ToolExecutionSandbox: func_return, agent_state = self.parse_best_effort(func_result) except Exception as e: + func_return = get_friendly_error_msg( + function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e) + ) traceback.print_exc(file=sys.stderr) - error_msg = f"{type(e).__name__}: {str(e)}" + status = "error" # Restore stdout and stderr and collect captured output sys.stdout = old_stdout sys.stderr = old_stderr - stdout_output = [captured_stdout.getvalue()] - stderr_output = [captured_stderr.getvalue()] - stderr_output.append(error_msg if error_msg else '') + stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else [] + stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else [] return SandboxRunResult( func_return=func_return, agent_state=agent_state, stdout=stdout_output, stderr=stderr_output, + status=status, sandbox_config_fingerprint=sbx_config.fingerprint(), ) @@ -280,20 +293,23 @@ class ToolExecutionSandbox: env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) code = self.generate_execution_script(agent_state=agent_state) execution = sbx.run_code(code, envs=env_vars) - func_return, agent_state = None, None - if execution.error is not None: - logger.error(f"Executing tool {self.tool_name} failed with {execution.error}") - execution.logs.stderr.append(execution.error.traceback) - execution.logs.stderr.append(f"{execution.error.name}: {execution.error.value}") - elif len(execution.results) == 0: - raise ValueError(f"Tool {self.tool_name} returned execution with None") - else: + if execution.results: func_return, agent_state = self.parse_best_effort(execution.results[0].text) + elif execution.error: + logger.error(f"Executing tool {self.tool_name} failed with {execution.error}") + func_return = get_friendly_error_msg( + function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value + ) + execution.logs.stderr.append(execution.error.traceback) + else: + raise ValueError(f"Tool {self.tool_name} returned execution with None") + return SandboxRunResult( func_return=func_return, agent_state=agent_state, stdout=execution.logs.stdout, stderr=execution.logs.stderr, + status="error" if execution.error else "success", sandbox_config_fingerprint=sbx_config.fingerprint(), ) @@ -481,5 +497,3 @@ class ToolExecutionSandbox: func_call_str = self.tool.name + "(" + params + ")" return func_call_str - - # diff --git a/letta/utils.py b/letta/utils.py index 7184a0e3..f8b3778b 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -1118,3 +1118,11 @@ def sanitize_filename(filename: str) -> str: # Return the sanitized filename return sanitized_filename + +def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str): + from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT + + error_msg = f"Error executing function {function_name}: {exception_name}: {exception_message}" + if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT: + error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT] + return error_msg