feat: add tool execution result object (#1837)
This commit is contained in:
@@ -3,7 +3,7 @@ import time
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from openai.types.beta.function_tool import FunctionTool as OpenAITool
|
||||
|
||||
@@ -49,8 +49,8 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.response_format import ResponseFormatType
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -557,22 +557,23 @@ class Agent(BaseAgent):
|
||||
},
|
||||
)
|
||||
|
||||
function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
|
||||
tool_execution_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
|
||||
function_response = tool_execution_result.func_return
|
||||
|
||||
log_event(
|
||||
"tool_call_ended",
|
||||
attributes={
|
||||
"function_response": function_response,
|
||||
"sandbox_run_result": sandbox_run_result.model_dump() if sandbox_run_result else None,
|
||||
"tool_execution_result": tool_execution_result.model_dump(),
|
||||
},
|
||||
)
|
||||
log_telemetry(
|
||||
self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args
|
||||
)
|
||||
|
||||
if sandbox_run_result and sandbox_run_result.status == "error":
|
||||
if tool_execution_result and tool_execution_result.status == "error":
|
||||
tool_return = ToolReturn(
|
||||
status=sandbox_run_result.status, stdout=sandbox_run_result.stdout, stderr=sandbox_run_result.stderr
|
||||
status=tool_execution_result.status, stdout=tool_execution_result.stdout, stderr=tool_execution_result.stderr
|
||||
)
|
||||
messages = self._handle_function_error_response(
|
||||
function_response,
|
||||
@@ -626,14 +627,10 @@ class Agent(BaseAgent):
|
||||
# Step 4: check if function response is an error
|
||||
if function_response_string.startswith(ERROR_MESSAGE_PREFIX):
|
||||
error_msg = function_response_string
|
||||
tool_return = (
|
||||
ToolReturn(
|
||||
status=sandbox_run_result.status,
|
||||
stdout=sandbox_run_result.stdout,
|
||||
stderr=sandbox_run_result.stderr,
|
||||
)
|
||||
if sandbox_run_result
|
||||
else None
|
||||
tool_return = ToolReturn(
|
||||
status=tool_execution_result.status,
|
||||
stdout=tool_execution_result.stdout,
|
||||
stderr=tool_execution_result.stderr,
|
||||
)
|
||||
messages = self._handle_function_error_response(
|
||||
error_msg,
|
||||
@@ -650,14 +647,10 @@ class Agent(BaseAgent):
|
||||
|
||||
# If no failures happened along the way: ...
|
||||
# Step 5: send the info on the function call and function response to GPT
|
||||
tool_return = (
|
||||
ToolReturn(
|
||||
status=sandbox_run_result.status,
|
||||
stdout=sandbox_run_result.stdout,
|
||||
stderr=sandbox_run_result.stderr,
|
||||
)
|
||||
if sandbox_run_result
|
||||
else None
|
||||
tool_return = ToolReturn(
|
||||
status=tool_execution_result.status,
|
||||
stdout=tool_execution_result.stdout,
|
||||
stderr=tool_execution_result.stderr,
|
||||
)
|
||||
messages.append(
|
||||
Message(
|
||||
@@ -669,7 +662,7 @@ class Agent(BaseAgent):
|
||||
content=[TextContent(text=function_response)],
|
||||
tool_call_id=tool_call_id,
|
||||
# Letta extras
|
||||
tool_returns=[tool_return] if sandbox_run_result else None,
|
||||
tool_returns=[tool_return],
|
||||
group_id=group_id,
|
||||
)
|
||||
) # extend conversation with function response
|
||||
@@ -1262,9 +1255,7 @@ class Agent(BaseAgent):
|
||||
return context_window_breakdown.context_window_size_current
|
||||
|
||||
# TODO: Refactor into a separate class vs. large if/elses here
|
||||
def execute_tool_and_persist_state(
|
||||
self, function_name: str, function_args: dict, target_letta_tool: Tool
|
||||
) -> tuple[Any, Optional[SandboxRunResult]]:
|
||||
def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool) -> ToolExecutionResult:
|
||||
"""
|
||||
Execute tool modifications and persist the state of the agent.
|
||||
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
|
||||
@@ -1326,8 +1317,10 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
|
||||
sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
|
||||
return function_response, sandbox_run_result
|
||||
return ToolExecutionResult(
|
||||
status="error" if is_error else "success",
|
||||
func_return=function_response,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Parse the source code to extract function annotations
|
||||
@@ -1344,23 +1337,29 @@ class Agent(BaseAgent):
|
||||
agent_state_copy.tools = []
|
||||
agent_state_copy.tool_rules = []
|
||||
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
|
||||
tool_execution_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run(
|
||||
agent_state=agent_state_copy
|
||||
)
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
if updated_agent_state is not None:
|
||||
self.update_memory_if_changed(updated_agent_state.memory)
|
||||
return function_response, sandbox_run_result
|
||||
if tool_execution_result.agent_state is not None:
|
||||
self.update_memory_if_changed(tool_execution_result.agent_state.memory)
|
||||
return tool_execution_result
|
||||
except Exception as e:
|
||||
# Need to catch error here, or else truncation won't happen
|
||||
# TODO: modify to function execution error
|
||||
function_response = get_friendly_error_msg(
|
||||
function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
|
||||
)
|
||||
return function_response, SandboxRunResult(status="error")
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=function_response,
|
||||
stderr=[traceback.format_exc()],
|
||||
)
|
||||
|
||||
return function_response, None
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
|
||||
def save_agent(agent: Agent):
|
||||
|
||||
@@ -324,11 +324,11 @@ class LettaAgent(BaseAgent):
|
||||
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
|
||||
# TODO: Integrate sandbox result
|
||||
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
|
||||
function_response, _ = await tool_execution_manager.execute_tool_async(
|
||||
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
||||
function_name=tool_name, function_args=tool_args, tool=target_tool
|
||||
)
|
||||
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
|
||||
return function_response, True
|
||||
return tool_execution_result.func_return, True
|
||||
except Exception as e:
|
||||
return f"Failed to call tool. Error: {e}", False
|
||||
|
||||
|
||||
@@ -83,12 +83,12 @@ async def execute_tool_wrapper(params: ToolExecutionParams):
|
||||
sandbox_config=params.sbx_config,
|
||||
sandbox_env_vars=params.sbx_env_vars,
|
||||
)
|
||||
result, _ = await mgr.execute_tool_async(
|
||||
tool_execution_result = await mgr.execute_tool_async(
|
||||
function_name=params.tool_call_name,
|
||||
function_args=params.tool_args,
|
||||
tool=target_tool,
|
||||
)
|
||||
return params.agent_id, (result, True)
|
||||
return params.agent_id, (tool_execution_result.func_return, True)
|
||||
except Exception as e:
|
||||
return params.agent_id, (f"Failed to call tool. Error: {e}", False)
|
||||
|
||||
|
||||
@@ -160,12 +160,12 @@ def execute_external_tool(
|
||||
else:
|
||||
agent_state_copy = None
|
||||
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
tool_execution_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy)
|
||||
function_response, updated_agent_state = tool_execution_result.func_return, tool_execution_result.agent_state
|
||||
# TODO: Bring this back
|
||||
# if allow_agent_state_modifications and updated_agent_state is not None:
|
||||
# self.update_memory_if_changed(updated_agent_state.memory)
|
||||
return function_response, sandbox_run_result
|
||||
return function_response, tool_execution_result
|
||||
except Exception as e:
|
||||
# Need to catch error here, or else truncation won't happen
|
||||
# TODO: modify to function execution error
|
||||
|
||||
14
letta/schemas/tool_execution_result.py
Normal file
14
letta/schemas/tool_execution_result.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
||||
func_return: Optional[Any] = Field(None, description="The function return object")
|
||||
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
||||
stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation")
|
||||
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
|
||||
sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox")
|
||||
@@ -1310,17 +1310,17 @@ class SyncServer(Server):
|
||||
|
||||
# Next, attempt to run the tool with the sandbox
|
||||
try:
|
||||
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
|
||||
tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
|
||||
agent_state=agent_state, additional_env_vars=tool_env_vars
|
||||
)
|
||||
return ToolReturnMessage(
|
||||
id="null",
|
||||
tool_call_id="null",
|
||||
date=get_utc_time(),
|
||||
status=sandbox_run_result.status,
|
||||
tool_return=str(sandbox_run_result.func_return),
|
||||
stdout=sandbox_run_result.stdout,
|
||||
stderr=sandbox_run_result.stderr,
|
||||
status=tool_execution_result.status,
|
||||
tool_return=str(tool_execution_result.func_return),
|
||||
stdout=tool_execution_result.stdout,
|
||||
stderr=tool_execution_result.stderr,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.tool_executor.tool_executor import (
|
||||
ExternalComposioToolExecutor,
|
||||
@@ -58,7 +60,7 @@ class ToolExecutionManager:
|
||||
self.sandbox_config = sandbox_config
|
||||
self.sandbox_env_vars = sandbox_env_vars
|
||||
|
||||
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> ToolExecutionResult:
|
||||
"""
|
||||
Execute a tool and persist any state changes.
|
||||
|
||||
@@ -71,36 +73,17 @@ class ToolExecutionManager:
|
||||
ToolExecutionResult containing the execution status, the function response, and any captured stdout/stderr
|
||||
"""
|
||||
try:
|
||||
# Get the appropriate executor for this tool type
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
|
||||
# Execute the tool
|
||||
return executor.execute(
|
||||
function_name, function_args, self.agent_state, tool, self.actor, self.sandbox_config, self.sandbox_env_vars
|
||||
function_name,
|
||||
function_args,
|
||||
self.agent_state,
|
||||
tool,
|
||||
self.actor,
|
||||
self.sandbox_config,
|
||||
self.sandbox_env_vars,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
error_message = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
|
||||
return error_message, SandboxRunResult(status="error")
|
||||
|
||||
@trace_method
|
||||
async def execute_tool_async(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
"""
|
||||
Execute a tool asynchronously and persist any state changes.
|
||||
"""
|
||||
try:
|
||||
# Get the appropriate executor for this tool type
|
||||
# TODO: Extend this async model to composio
|
||||
|
||||
if tool.tool_type == ToolType.CUSTOM:
|
||||
executor = SandboxToolExecutor()
|
||||
result_tuple = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
else:
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
result_tuple = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
return result_tuple
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
error_message = get_friendly_error_msg(
|
||||
@@ -108,4 +91,35 @@ class ToolExecutionManager:
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return error_message, SandboxRunResult(status="error")
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=error_message,
|
||||
stderr=[traceback.format_exc()],
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def execute_tool_async(self, function_name: str, function_args: dict, tool: Tool) -> ToolExecutionResult:
|
||||
"""
|
||||
Execute a tool asynchronously and persist any state changes.
|
||||
"""
|
||||
try:
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
# TODO: Extend this async model to composio
|
||||
if isinstance(executor, SandboxToolExecutor):
|
||||
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
else:
|
||||
result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
error_message = get_friendly_error_msg(
|
||||
function_name=function_name,
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=error_message,
|
||||
stderr=[traceback.format_exc()],
|
||||
)
|
||||
|
||||
@@ -13,8 +13,9 @@ from typing import Any, Dict, Optional
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.helpers.tool_execution_helper import (
|
||||
add_imports_and_pydantic_schemas_for_args,
|
||||
@@ -72,7 +73,11 @@ class ToolExecutionSandbox:
|
||||
self.force_recreate = force_recreate
|
||||
self.force_recreate_venv = force_recreate_venv
|
||||
|
||||
def run(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
|
||||
def run(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment.
|
||||
|
||||
@@ -81,7 +86,7 @@ class ToolExecutionSandbox:
|
||||
additional_env_vars (Optional[Dict]): Environment variables to inject into the sandbox
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[AgentState]]: Tuple containing (tool_result, agent_state)
|
||||
ToolExecutionResult: Object containing tool execution outcome (e.g. status, response)
|
||||
"""
|
||||
if tool_settings.e2b_api_key and not self.privileged_tools:
|
||||
logger.debug(f"Using e2b sandbox to execute {self.tool_name}")
|
||||
@@ -115,7 +120,7 @@ class ToolExecutionSandbox:
|
||||
@trace_method
|
||||
def run_local_dir_sandbox(
|
||||
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
|
||||
local_configs = sbx_config.get_local_config()
|
||||
|
||||
@@ -162,7 +167,12 @@ class ToolExecutionSandbox:
|
||||
os.remove(temp_file_path)
|
||||
|
||||
@trace_method
|
||||
def run_local_dir_sandbox_venv(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
|
||||
def run_local_dir_sandbox_venv(
|
||||
self,
|
||||
sbx_config: SandboxConfig,
|
||||
env: Dict[str, str],
|
||||
temp_file_path: str,
|
||||
) -> ToolExecutionResult:
|
||||
local_configs = sbx_config.get_local_config()
|
||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde
|
||||
venv_path = os.path.join(sandbox_dir, local_configs.venv_name)
|
||||
@@ -205,12 +215,12 @@ class ToolExecutionSandbox:
|
||||
func_result, stdout = self.parse_out_function_results_markers(result.stdout)
|
||||
func_return, agent_state = self.parse_best_effort(func_result)
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=[stdout] if stdout else [],
|
||||
stderr=[result.stderr] if result.stderr else [],
|
||||
status="success",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@@ -221,12 +231,12 @@ class ToolExecutionSandbox:
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=func_return,
|
||||
agent_state=None,
|
||||
stdout=[e.stdout] if e.stdout else [],
|
||||
stderr=[e.stderr] if e.stderr else [],
|
||||
status="error",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@@ -238,7 +248,12 @@ class ToolExecutionSandbox:
|
||||
raise e
|
||||
|
||||
@trace_method
|
||||
def run_local_dir_sandbox_directly(self, sbx_config: SandboxConfig, env: Dict[str, str], temp_file_path: str) -> SandboxRunResult:
|
||||
def run_local_dir_sandbox_directly(
|
||||
self,
|
||||
sbx_config: SandboxConfig,
|
||||
env: Dict[str, str],
|
||||
temp_file_path: str,
|
||||
) -> ToolExecutionResult:
|
||||
status = "success"
|
||||
func_return, agent_state, stderr = None, None, None
|
||||
|
||||
@@ -288,12 +303,12 @@ class ToolExecutionSandbox:
|
||||
stdout_output = [captured_stdout.getvalue()] if captured_stdout.getvalue() else []
|
||||
stderr_output = [captured_stderr.getvalue()] if captured_stderr.getvalue() else []
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
status=status,
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=stdout_output,
|
||||
stderr=stderr_output,
|
||||
status=status,
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@@ -307,7 +322,11 @@ class ToolExecutionSandbox:
|
||||
|
||||
# e2b sandbox specific functions
|
||||
|
||||
def run_e2b_sandbox(self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None) -> SandboxRunResult:
|
||||
def run_e2b_sandbox(
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> ToolExecutionResult:
|
||||
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
|
||||
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
|
||||
if not sbx or self.force_recreate:
|
||||
@@ -348,12 +367,12 @@ class ToolExecutionSandbox:
|
||||
else:
|
||||
raise ValueError(f"Tool {self.tool_name} returned execution with None")
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
status="error" if execution.error else "success",
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=execution.logs.stdout,
|
||||
stderr=execution.logs.stderr,
|
||||
status="error" if execution.error else "success",
|
||||
sandbox_config_fingerprint=sbx_config.fingerprint(),
|
||||
)
|
||||
|
||||
@@ -535,7 +554,7 @@ class ToolExecutionSandbox:
|
||||
Generate the code string to call the function.
|
||||
|
||||
Args:
|
||||
inject_agent_state (bool): Whether to inject the axgent's state as an input into the tool
|
||||
inject_agent_state (bool): Whether to inject the agent's state as an input into the tool
|
||||
|
||||
Returns:
|
||||
str: Generated code string for calling the tool
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import math
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
@@ -8,8 +9,9 @@ from letta.functions.helpers import execute_composio_action, generate_composio_a
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
@@ -33,7 +35,7 @@ class ToolExecutor(ABC):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
"""Execute the tool and return the result."""
|
||||
|
||||
|
||||
@@ -49,7 +51,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
# Map function names to method calls
|
||||
function_map = {
|
||||
"send_message": self.send_message,
|
||||
@@ -64,7 +66,10 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
# Execute the appropriate function
|
||||
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
|
||||
function_response = function_map[function_name](agent_state, actor, **function_args_copy)
|
||||
return function_response, None
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
def send_message(self, agent_state: AgentState, actor: User, message: str) -> Optional[str]:
|
||||
"""
|
||||
@@ -186,12 +191,11 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA multi-agent core tools."""
|
||||
|
||||
# TODO: Implement
|
||||
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[
|
||||
# Any, Optional[SandboxRunResult]]:
|
||||
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> ToolExecutionResult:
|
||||
# callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
|
||||
# function_args["self"] = agent # need to attach self to arg since it's dynamically linked
|
||||
# function_response = callable_func(**function_args)
|
||||
# return function_response, None
|
||||
# return ToolExecutionResult(status="success", func_return=function_response)
|
||||
|
||||
|
||||
class LettaMemoryToolExecutor(ToolExecutor):
|
||||
@@ -206,7 +210,7 @@ class LettaMemoryToolExecutor(ToolExecutor):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
# Map function names to method calls
|
||||
function_map = {
|
||||
"core_memory_append": self.core_memory_append,
|
||||
@@ -223,7 +227,10 @@ class LettaMemoryToolExecutor(ToolExecutor):
|
||||
# Update memory if changed
|
||||
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
return function_response, None
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
def core_memory_append(self, agent_state: "AgentState", label: str, content: str) -> Optional[str]:
|
||||
"""
|
||||
@@ -273,7 +280,7 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
action_name = generate_composio_action_from_func_name(tool.name)
|
||||
|
||||
# Get entity ID from the agent_state
|
||||
@@ -287,7 +294,10 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
|
||||
)
|
||||
|
||||
return function_response, None
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
def _get_entity_id(self, agent_state: AgentState) -> Optional[str]:
|
||||
"""Extract the entity ID from environment variables."""
|
||||
@@ -302,8 +312,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
|
||||
# TODO: Implement
|
||||
#
|
||||
# def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> Tuple[
|
||||
# Any, Optional[SandboxRunResult]]:
|
||||
# def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> ToolExecutionResult:
|
||||
# # Get the server name from the tool tag
|
||||
# server_name = self._extract_server_name(tool)
|
||||
#
|
||||
@@ -316,8 +325,10 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
# # Execute the tool
|
||||
# function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
|
||||
#
|
||||
# sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
|
||||
# return function_response, sandbox_run_result
|
||||
# return ToolExecutionResult(
|
||||
# status="error" if is_error else "success",
|
||||
# func_return=function_response,
|
||||
# )
|
||||
#
|
||||
# def _extract_server_name(self, tool: Tool) -> str:
|
||||
# """Extract server name from tool tags."""
|
||||
@@ -360,7 +371,7 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
) -> ToolExecutionResult:
|
||||
|
||||
# Store original memory state
|
||||
orig_memory_str = agent_state.memory.compile()
|
||||
@@ -381,21 +392,19 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
function_name, function_args, actor, tool_object=tool, sandbox_config=sandbox_config, sandbox_env_vars=sandbox_env_vars
|
||||
)
|
||||
|
||||
sandbox_run_result = await sandbox.run(agent_state=agent_state_copy)
|
||||
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
tool_execution_result = await sandbox.run(agent_state=agent_state_copy)
|
||||
|
||||
# Verify memory integrity
|
||||
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
|
||||
# Update agent memory if needed
|
||||
if updated_agent_state is not None:
|
||||
AgentManager().update_memory_if_changed(agent_state.id, updated_agent_state.memory, actor)
|
||||
if tool_execution_result.agent_state is not None:
|
||||
AgentManager().update_memory_if_changed(agent_state.id, tool_execution_result.agent_state.memory, actor)
|
||||
|
||||
return function_response, sandbox_run_result
|
||||
return tool_execution_result
|
||||
|
||||
except Exception as e:
|
||||
return self._handle_execution_error(e, function_name)
|
||||
return self._handle_execution_error(e, function_name, traceback.format_exc())
|
||||
|
||||
def _prepare_function_args(self, function_args: dict, tool: Tool, function_name: str) -> dict:
|
||||
"""Prepare function arguments with proper type coercion."""
|
||||
@@ -417,9 +426,18 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
agent_state_copy.tool_rules = []
|
||||
return agent_state_copy
|
||||
|
||||
def _handle_execution_error(self, exception: Exception, function_name: str) -> Tuple[str, SandboxRunResult]:
|
||||
def _handle_execution_error(
|
||||
self,
|
||||
exception: Exception,
|
||||
function_name: str,
|
||||
stderr: str,
|
||||
) -> ToolExecutionResult:
|
||||
"""Handle tool execution errors."""
|
||||
error_message = get_friendly_error_msg(
|
||||
function_name=function_name, exception_name=type(exception).__name__, exception_message=str(exception)
|
||||
)
|
||||
return error_message, SandboxRunResult(status="error")
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return=error_message,
|
||||
stderr=[stderr],
|
||||
)
|
||||
|
||||
@@ -7,8 +7,9 @@ from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult
|
||||
from letta.schemas.sandbox_config import SandboxConfig
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
@@ -64,7 +65,7 @@ class AsyncToolSandboxBase(ABC):
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment asynchronously.
|
||||
Must be implemented by subclasses.
|
||||
|
||||
@@ -2,8 +2,9 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@@ -30,7 +31,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment asynchronously,
|
||||
*always* using a subprocess for execution.
|
||||
@@ -45,7 +46,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
|
||||
async def run_e2b_sandbox(
|
||||
self, agent_state: Optional[AgentState] = None, additional_env_vars: Optional[Dict] = None
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
if self.provided_sandbox_config:
|
||||
sbx_config = self.provided_sandbox_config
|
||||
else:
|
||||
@@ -94,7 +95,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
else:
|
||||
raise ValueError(f"Tool {self.tool_name} returned execution with None")
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=execution.logs.stdout,
|
||||
|
||||
@@ -5,8 +5,9 @@ import tempfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.services.helpers.tool_execution_helper import (
|
||||
create_venv_for_local_sandbox,
|
||||
find_python_executable,
|
||||
@@ -39,7 +40,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
self,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
additional_env_vars: Optional[Dict] = None,
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Run the tool in a sandbox environment asynchronously,
|
||||
*always* using a subprocess for execution.
|
||||
@@ -53,7 +54,11 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
return result
|
||||
|
||||
@trace_method
|
||||
async def run_local_dir_sandbox(self, agent_state: Optional[AgentState], additional_env_vars: Optional[Dict]) -> SandboxRunResult:
|
||||
async def run_local_dir_sandbox(
|
||||
self,
|
||||
agent_state: Optional[AgentState],
|
||||
additional_env_vars: Optional[Dict],
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Unified asynchronous method to run the tool in a local sandbox environment,
|
||||
always via subprocess for multi-core parallelism.
|
||||
@@ -156,7 +161,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
@trace_method
|
||||
async def _execute_tool_subprocess(
|
||||
self, sbx_config, python_executable: str, temp_file_path: str, env: Dict[str, str], cwd: str
|
||||
) -> SandboxRunResult:
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
Execute user code in a subprocess, always capturing stdout and stderr.
|
||||
We parse special markers to extract the pickled result string.
|
||||
@@ -189,7 +194,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
func_result, stdout_text = self.parse_out_function_results_markers(stdout)
|
||||
func_return, agent_state = self.parse_best_effort(func_result)
|
||||
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=agent_state,
|
||||
stdout=[stdout_text] if stdout_text else [],
|
||||
@@ -209,7 +214,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
exception_name=type(e).__name__,
|
||||
exception_message=str(e),
|
||||
)
|
||||
return SandboxRunResult(
|
||||
return ToolExecutionResult(
|
||||
func_return=func_return,
|
||||
agent_state=None,
|
||||
stdout=[],
|
||||
|
||||
497
poetry.lock
generated
497
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -67,9 +67,9 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_get_emojis
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
function_response, sandbox_run_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
|
||||
tool_execution_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
|
||||
function_name=composio_get_emojis.name, function_args={}, tool=composio_get_emojis
|
||||
)
|
||||
|
||||
# Small check, it should return something at least
|
||||
assert len(function_response.keys()) > 10
|
||||
assert len(tool_execution_result.func_return.keys()) > 10
|
||||
|
||||
Reference in New Issue
Block a user