fix: refactor sandbox run logic to add status field (#2248)
Co-authored-by: Caren Thomas <caren@caren-mac.local>
This commit is contained in:
@@ -64,6 +64,7 @@ from letta.system import (
|
||||
)
|
||||
from letta.utils import (
|
||||
count_tokens,
|
||||
get_friendly_error_msg,
|
||||
get_local_time,
|
||||
get_tool_call_id,
|
||||
get_utc_time,
|
||||
@@ -456,12 +457,9 @@ class Agent(BaseAgent):
|
||||
except Exception as e:
|
||||
# Need to catch the error here, or else truncation won't happen
|
||||
# TODO: modify to function execution error
|
||||
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT
|
||||
|
||||
error_msg = f"Error executing tool {function_name}: {e}"
|
||||
if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
|
||||
error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
|
||||
raise ValueError(error_msg)
|
||||
function_response = get_friendly_error_msg(
|
||||
function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
|
||||
)
|
||||
|
||||
return function_response
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@@ -21,6 +21,7 @@ class SandboxRunResult(BaseModel):
|
||||
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
||||
stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the function invocation")
|
||||
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
|
||||
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
||||
sandbox_config_fingerprint: str = Field(None, description="The fingerprint of the config for the sandbox")
|
||||
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.utils import get_utc_time, json_dumps, json_loads
|
||||
from letta.utils import get_friendly_error_msg, get_utc_time, json_dumps, json_loads
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1395,55 +1395,30 @@ class SyncServer(Server):
|
||||
# Next, attempt to run the tool with the sandbox
|
||||
try:
|
||||
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state)
|
||||
function_response = str(sandbox_run_result.func_return)
|
||||
stdout = [s for s in sandbox_run_result.stdout if s.strip()]
|
||||
stderr = [s for s in sandbox_run_result.stderr if s.strip()]
|
||||
|
||||
# expected error
|
||||
if stderr:
|
||||
error_msg = self.get_error_msg_for_func_return(tool.name, stderr[-1])
|
||||
return FunctionReturn(
|
||||
id="null",
|
||||
function_call_id="null",
|
||||
date=get_utc_time(),
|
||||
status="error",
|
||||
function_return=error_msg,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
)
|
||||
|
||||
return FunctionReturn(
|
||||
id="null",
|
||||
function_call_id="null",
|
||||
date=get_utc_time(),
|
||||
status="success",
|
||||
function_return=function_response,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
status=sandbox_run_result.status,
|
||||
function_return=str(sandbox_run_result.func_return),
|
||||
stdout=sandbox_run_result.stdout,
|
||||
stderr=sandbox_run_result.stderr,
|
||||
)
|
||||
|
||||
# unexpected error TODO(@cthomas): consolidate error handling
|
||||
except Exception as e:
|
||||
error_msg = self.get_error_msg_for_func_return(tool.name, e)
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=tool.name, exception_name=type(e).__name__, exception_message=str(e)
|
||||
)
|
||||
return FunctionReturn(
|
||||
id="null",
|
||||
function_call_id="null",
|
||||
date=get_utc_time(),
|
||||
status="error",
|
||||
function_return=error_msg,
|
||||
stdout=[""],
|
||||
function_return=func_return,
|
||||
stdout=[],
|
||||
stderr=[traceback.format_exc()],
|
||||
)
|
||||
|
||||
def get_error_msg_for_func_return(self, tool_name, exception_message):
    """Format a tool-execution error string, capped at the configured length.

    The message has the form ``Error executing tool <tool_name>: <exception_message>``
    and is cut to at most ``MAX_ERROR_MESSAGE_CHAR_LIMIT`` characters.
    """
    # same as agent.py
    from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

    full_message = f"Error executing tool {tool_name}: {exception_message}"
    # Short messages are returned untouched; long ones keep only the leading characters.
    if len(full_message) <= MAX_ERROR_MESSAGE_CHAR_LIMIT:
        return full_message
    return full_message[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
|
||||
|
||||
# Composio wrappers
|
||||
def get_composio_client(self, api_key: Optional[str] = None):
|
||||
if api_key:
|
||||
|
||||
@@ -10,7 +10,7 @@ import tempfile
|
||||
import traceback
|
||||
import uuid
|
||||
import venv
|
||||
from typing import Any, Dict, Optional, TextIO
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -20,6 +20,7 @@ from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import tool_settings
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -79,11 +80,11 @@ class ToolExecutionSandbox:
|
||||
logger.debug(f"Using local sandbox to execute {self.tool_name}")
|
||||
result = self.run_local_dir_sandbox(agent_state=agent_state)
|
||||
|
||||
# Log out any stdout from the tool run
|
||||
logger.debug(f"Executed tool '{self.tool_name}', logging stdout from tool run: \n")
|
||||
for log_line in result.stdout:
|
||||
# Log out any stdout/stderr from the tool run
|
||||
logger.debug(f"Executed tool '{self.tool_name}', logging output from tool run: \n")
|
||||
for log_line in (result.stdout or []) + (result.stderr or []):
|
||||
logger.debug(f"{log_line}")
|
||||
logger.debug(f"Ending stdout log from tool run.")
|
||||
logger.debug(f"Ending output log from tool run.")
|
||||
|
||||
# Return result
|
||||
return result
|
||||
@@ -126,30 +127,24 @@ class ToolExecutionSandbox:
|
||||
temp_file.flush()
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
# Save the old stdout
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
try:
|
||||
if local_configs.use_venv:
|
||||
return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path)
|
||||
else:
|
||||
return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path, old_stdout, old_stderr)
|
||||
return self.run_local_dir_sandbox_runpy(sbx_config, env_vars, temp_file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
|
||||
logger.error(f"Logging out tool {self.tool_name} auto-generated code for debugging: \n\n{code}")
|
||||
raise e
|
||||
finally:
|
||||
# Clean up the temp file and restore stdout
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
# Clean up the temp file
|
||||
os.remove(temp_file_path)
|
||||
|
||||
def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
|
||||
local_configs = sbx_config.get_local_config()
|
||||
venv_path = os.path.join(local_configs.sandbox_dir, local_configs.venv_name)
|
||||
|
||||
# Safety checks for the venv
|
||||
# Verify that the venv path exists and is a directory
|
||||
# Safety checks for the venv: verify that the venv path exists and is a directory
|
||||
if not os.path.isdir(venv_path):
|
||||
logger.warning(f"Virtual environment directory does not exist at: {venv_path}, creating one now...")
|
||||
self.create_venv_for_local_sandbox(sandbox_dir_path=local_configs.sandbox_dir, venv_path=venv_path, env=env)
|
||||
@@ -180,27 +175,42 @@ class ToolExecutionSandbox:
|
||||
return SandboxRunResult(
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=[stdout],
|
||||
stderr=[result.stderr],
|
||||
stdout=[stdout] if stdout else [],
|
||||
stderr=[result.stderr] if result.stderr else [],
|
||||
status="success",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Executing tool {self.tool_name} has process error: {e}")
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e),
|
||||
)
|
||||
return SandboxRunResult(
|
||||
func_return=func_return,
|
||||
agent_state=None,
|
||||
stdout=[e.stdout] if e.stdout else [],
|
||||
stderr=[e.stderr] if e.stderr else [],
|
||||
status="error",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
raise TimeoutError(f"Executing tool {self.tool_name} has timed out.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Executing tool {self.tool_name} has process error: {e}")
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Executing tool {self.tool_name} has an unexpected error: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def run_local_dir_sandbox_runpy(
|
||||
self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str, old_stdout: TextIO, old_stderr: TextIO
|
||||
self, sbx_config: SandboxConfig, env_vars: Dict[str, str], temp_file_path: str
|
||||
) -> SandboxRunResult:
|
||||
func_return, agent_state, error_msg = None, None, None
|
||||
status = "success"
|
||||
agent_state, stderr = None, None
|
||||
|
||||
# Redirect stdout and stderr to capture script output
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
captured_stdout, captured_stderr = io.StringIO(), io.StringIO()
|
||||
sys.stdout = captured_stdout
|
||||
sys.stderr = captured_stderr
|
||||
@@ -215,21 +225,24 @@ class ToolExecutionSandbox:
|
||||
func_return, agent_state = self.parse_best_effort(func_result)
|
||||
|
||||
except Exception as e:
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name, exception_name=type(e).__name__, exception_message=str(e)
|
||||
)
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||
status = "error"
|
||||
|
||||
# Restore stdout and stderr and collect captured output
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
stdout_output = [captured_stdout.getvalue()]
|
||||
stderr_output = [captured_stderr.getvalue()]
|
||||
stderr_output.append(error_msg if error_msg else '')
|
||||
stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else []
|
||||
stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else []
|
||||
|
||||
return SandboxRunResult(
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=stdout_output,
|
||||
stderr=stderr_output,
|
||||
status=status,
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@@ -280,20 +293,23 @@ class ToolExecutionSandbox:
|
||||
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
|
||||
code = self.generate_execution_script(agent_state=agent_state)
|
||||
execution = sbx.run_code(code, envs=env_vars)
|
||||
func_return, agent_state = None, None
|
||||
if execution.error is not None:
|
||||
logger.error(f"Executing tool {self.tool_name} failed with {execution.error}")
|
||||
execution.logs.stderr.append(execution.error.traceback)
|
||||
execution.logs.stderr.append(f"{execution.error.name}: {execution.error.value}")
|
||||
elif len(execution.results) == 0:
|
||||
raise ValueError(f"Tool {self.tool_name} returned execution with None")
|
||||
else:
|
||||
if execution.results:
|
||||
func_return, agent_state = self.parse_best_effort(execution.results[0].text)
|
||||
elif execution.error:
|
||||
logger.error(f"Executing tool {self.tool_name} failed with {execution.error}")
|
||||
func_return = get_friendly_error_msg(
|
||||
function_name=self.tool_name, exception_name=execution.error.name, exception_message=execution.error.value
|
||||
)
|
||||
execution.logs.stderr.append(execution.error.traceback)
|
||||
else:
|
||||
raise ValueError(f"Tool {self.tool_name} returned execution with None")
|
||||
|
||||
return SandboxRunResult(
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=execution.logs.stdout,
|
||||
stderr=execution.logs.stderr,
|
||||
status="error" if execution.error else "success",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@@ -481,5 +497,3 @@ class ToolExecutionSandbox:
|
||||
|
||||
func_call_str = self.tool.name + "(" + params + ")"
|
||||
return func_call_str
|
||||
|
||||
#
|
||||
|
||||
@@ -1118,3 +1118,11 @@ def sanitize_filename(filename: str) -> str:
|
||||
|
||||
# Return the sanitized filename
|
||||
return sanitized_filename
|
||||
|
||||
def get_friendly_error_msg(function_name: str, exception_name: str, exception_message: str):
    """Build a readable, length-capped error message for a failed function call.

    Args:
        function_name: Name of the function that was being executed.
        exception_name: Name of the exception type that was raised.
        exception_message: Text of the exception.

    Returns:
        ``Error executing function <function_name>: <exception_name>: <exception_message>``,
        cut to at most ``MAX_ERROR_MESSAGE_CHAR_LIMIT`` characters.
    """
    # Imported at call time, as in the original implementation.
    from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

    friendly = f"Error executing function {function_name}: {exception_name}: {exception_message}"
    # Only over-long messages are shortened; everything else passes through unchanged.
    if len(friendly) <= MAX_ERROR_MESSAGE_CHAR_LIMIT:
        return friendly
    return friendly[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
|
||||
|
||||
Reference in New Issue
Block a user