fix: refactor sandbox run logic to add status field (#2248)

Co-authored-by: Caren Thomas <caren@caren-mac.local>
This commit is contained in:
cthomas
2024-12-17 15:44:41 -08:00
committed by GitHub
parent bb06ab0fcb
commit 6fb2968006
5 changed files with 74 additions and 78 deletions

View File

@@ -64,6 +64,7 @@ from letta.system import (
)
from letta.utils import (
count_tokens,
get_friendly_error_msg,
get_local_time,
get_tool_call_id,
get_utc_time,
@@ -456,12 +457,9 @@ class Agent(BaseAgent):
except Exception as e:
# Need to catch the error here, or else truncation won't happen
# TODO: modify to function execution error
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
error_msg = f"Error executing tool {function_name}: {e}"
if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
raise ValueError(error_msg)
function_response = get_friendly_error_msg(
function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
)
return function_response

View File

@@ -1,7 +1,7 @@
import hashlib
import json
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, model_validator
@@ -21,6 +21,7 @@ class SandboxRunResult(BaseModel):
agent_state: Optional[AgentState] = Field(None, description="The agent state")
stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the function invocation")
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox")

View File

@@ -74,7 +74,7 @@ from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import get_utc_time, json_dumps, json_loads
from letta.utils import get_friendly_error_msg, get_utc_time, json_dumps, json_loads
logger = get_logger(__name__)
@@ -1395,55 +1395,30 @@ class SyncServer(Server):
# Next, attempt to run the tool with the sandbox
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state)
function_response = str(sandbox_run_result.func_return)
stdout = [s for s in sandbox_run_result.stdout if s.strip()]
stderr = [s for s in sandbox_run_result.stderr if s.strip()]
# expected error
if stderr:
error_msg = self.get_error_msg_for_func_return(tool.name, stderr[-1])
return FunctionReturn(
id="null",
function_call_id="null",
date=get_utc_time(),
status="error",
function_return=error_msg,
stdout=stdout,
stderr=stderr,
)
return FunctionReturn(
id="null",
function_call_id="null",
date=get_utc_time(),
status="success",
function_return=function_response,
stdout=stdout,
stderr=stderr,
status=sandbox_run_result.status,
function_return=str(sandbox_run_result.func_return),
stdout=sandbox_run_result.stdout,
stderr=sandbox_run_result.stderr,
)
# unexpected error TODO(@cthomas): consolidate error handling
except Exception as e:
error_msg = self.get_error_msg_for_func_return(tool.name, e)
func_return = get_friendly_error_msg(
function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e)
)
return FunctionReturn(
id="null",
function_call_id="null",
date=get_utc_time(),
status="error",
function_return=error_msg,
stdout=[""],
function_return=func_return,
stdout=[],
stderr=[traceback.format_exc()],
)
def get_error_msg_for_func_return(self, tool_name, exception_message):
    """Build the error text reported as a tool's function return.

    Args:
        tool_name: Name of the tool whose execution failed.
        exception_message: Description of the failure. It is interpolated
            with an f-string, so either a string or an exception object
            is accepted.

    Returns:
        The text "Error executing tool <tool_name>: <exception_message>",
        cut down to at most MAX_ERROR_MESSAGE_CHAR_LIMIT characters.
    """
    # Imported at call time, exactly as the original did.
    from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

    full_text = f"Error executing tool {tool_name}: {exception_message}"
    # Slicing leaves a string that already fits untouched, so no explicit
    # length comparison is required before truncating.
    return full_text[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
# Composio wrappers
def get_composio_client(self, api_key: Optional[str] = None):
if api_key:

View File

@@ -10,7 +10,7 @@ import tempfile
import traceback
import uuid
import venv
from typing import Any, Dict, Optional, TextIO
from typing import Any, Dict, Optional
from letta.log import get_logger
from letta.schemas.agent import AgentState
@@ -20,6 +20,7 @@ from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import tool_settings
from letta.utils import get_friendly_error_msg
logger = get_logger(__name__)
@@ -79,11 +80,11 @@ class ToolExecutionSandbox:
logger.debug(f"Using local sandbox to execute {self.tool_name}")
result = self.run_local_dir_sandbox(agent_state=agent_state)
# Log out any stdout from the tool run
logger.debug(f"Executed tool '{self.tool_name}', logging stdout from tool run: \n")
for log_line in result.stdout:
# Log out any stdout/stderr from the tool run
logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n")
for log_line in (result.stdout or []) + (result.stderr or []):
logger.debug(f"{log_line}")
logger.debug(f"Ending stdout log from tool run.")
logger.debug(f"Ending output log from tool run.")
# Return result
return result
@@ -126,30 +127,24 @@ class ToolExecutionSandbox:
temp_file.flush()
temp_file_path = temp_file.name
# Save the old stdout
old_stdout = sys.stdout
old_stderr = sys.stderr
try:
if local_configs.use_venv:
return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path)
else:
return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path, old_stdout, old_stderr)
return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path)
except Exception as e:
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}")
raise e
finally:
# Clean up the temp file and restore stdout
sys.stdout = old_stdout
sys.stderr = old_stderr
# Clean up the temp file
os.remove(temp_file_path)
def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
local_configs = sbx_config.get_local_config()
venv_path = os.path.join(local_configs.sandbox_dir, local_configs.venv_name)
# Safety checks for the venv
# Verify that the venv path exists and is a directory
# Safety checks for the venv: verify that the venv path exists and is a directory
if not os.path.isdir(venv_path):
logger.warning(f"Virtual environment directory does not exist at: {venv_path}, creating one now...")
self.create_venv_for_local_sandbox(sandbox_dir_path=local_configs.sandbox_dir, venv_path=venv_path, env=env)
@@ -180,27 +175,42 @@ class ToolExecutionSandbox:
return SandboxRunResult(
func_return=func_return,
agent_state=agent_state,
stdout=[stdout],
stderr=[result.stderr],
stdout=[stdout] if stdout else [],
stderr=[result.stderr] if result.stderr else [],
status="success",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
except subprocess.CalledProcessError as e:
logger.error(f"Executing tool {self.tool_name} has process error: {e}")
func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e),
)
return SandboxRunResult(
func_return=func_return,
agent_state=None,
stdout=[e.stdout] if e.stdout else [],
stderr=[e.stderr] if e.stderr else [],
status="error",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
except subprocess.TimeoutExpired:
raise TimeoutError(f"Executing tool {self.tool_name} has timed out.")
except subprocess.CalledProcessError as e:
logger.error(f"Executing tool {self.tool_name} has process error: {e}")
raise e
except Exception as e:
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
raise e
def run_local_dir_sandbox_runpy(
self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str, old_stdout: TextIO, old_stderr: TextIO
self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str
) -> SandboxRunResult:
func_return, agent_state, error_msg = None, None, None
status = "success"
agent_state, stderr = None, None
# Redirect stdout and stderr to capture script output
old_stdout = sys.stdout
old_stderr = sys.stderr
captured_stdout, captured_stderr = io.StringIO(), io.StringIO()
sys.stdout = captured_stdout
sys.stderr = captured_stderr
@@ -215,21 +225,24 @@ class ToolExecutionSandbox:
func_return, agent_state = self.parse_best_effort(func_result)
except Exception as e:
func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e)
)
traceback.print_exc(file=sys.stderr)
error_msg = f"{type(e).__name__}: {str(e)}"
status = "error"
# Restore stdout and stderr and collect captured output
sys.stdout = old_stdout
sys.stderr = old_stderr
stdout_output = [captured_stdout.getvalue()]
stderr_output = [captured_stderr.getvalue()]
stderr_output.append(error_msg if error_msg else '')
stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else []
stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else []
return SandboxRunResult(
func_return=func_return,
agent_state=agent_state,
stdout=stdout_output,
stderr=stderr_output,
status=status,
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
@@ -280,20 +293,23 @@ class ToolExecutionSandbox:
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
code = self.generate_execution_script(agent_state=agent_state)
execution = sbx.run_code(code, envs=env_vars)
func_return, agent_state = None, None
if execution.error is not None:
logger.error(f"Executing tool {self.tool_name} failed with {execution.error}")
execution.logs.stderr.append(execution.error.traceback)
execution.logs.stderr.append(f"{execution.error.name}: {execution.error.value}")
elif len(execution.results) == 0:
raise ValueError(f"Tool {self.tool_name} returned execution with None")
else:
if execution.results:
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
elif execution.error:
logger.error(f"Executing tool {self.tool_name} failed with {execution.error}")
func_return = get_friendly_error_msg(
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
)
execution.logs.stderr.append(execution.error.traceback)
else:
raise ValueError(f"Tool {self.tool_name} returned execution with None")
return SandboxRunResult(
func_return=func_return,
agent_state=agent_state,
stdout=execution.logs.stdout,
stderr=execution.logs.stderr,
status="error" if execution.error else "success",
sandbox_config_fingerprint=sbx_config.fingerprint(),
)
@@ -481,5 +497,3 @@ class ToolExecutionSandbox:
func_call_str = self.tool.name + "(" + params + ")"
return func_call_str
#

View File

@@ -1118,3 +1118,11 @@ def sanitize_filename(filename: str) -> str:
# Return the sanitized filename
return sanitized_filename
def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str):
    """Format a function-execution failure as a single bounded string.

    Args:
        function_name: Name of the function that failed.
        exception_name: Name of the exception type that was raised.
        exception_message: Text describing the exception.

    Returns:
        "Error executing function <function_name>: <exception_name>:
        <exception_message>", truncated to MAX_ERROR_MESSAGE_CHAR_LIMIT
        characters when it would otherwise be longer.
    """
    # Imported at call time, exactly as the original did.
    from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

    limit = MAX_ERROR_MESSAGE_CHAR_LIMIT
    friendly = f"Error executing function {function_name}: {exception_name}: {exception_message}"
    if len(friendly) <= limit:
        return friendly
    return friendly[:limit]