feat(asyncify): migrate run tools (#2496)

This commit is contained in:
cthomas
2025-05-28 14:25:17 -07:00
committed by GitHub
parent 682b997c65
commit 43cc87b3b7
5 changed files with 58 additions and 33 deletions

View File

@@ -193,7 +193,7 @@ async def upsert_base_tools(
@router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source")
def run_tool_from_source(
async def run_tool_from_source(
server: SyncServer = Depends(get_letta_server),
request: ToolRunFromSource = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -201,10 +201,10 @@ def run_tool_from_source(
"""
Attempt to build a tool from source, then run it on the provided arguments
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
return server.run_tool_from_source(
return await server.run_tool_from_source(
tool_source=request.source_code,
tool_source_type=request.source_type,
tool_args=request.args,

View File

@@ -96,7 +96,7 @@ from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.step_manager import StepManager
from letta.services.telemetry_manager import TelemetryManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings, settings, tool_settings
@@ -1865,7 +1865,7 @@ class SyncServer(Server):
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
"""Add a new embedding model"""
def run_tool_from_source(
async def run_tool_from_source(
self,
actor: User,
tool_args: Dict[str, str],
@@ -1898,8 +1898,20 @@ class SyncServer(Server):
# Next, attempt to run the tool with the sandbox
try:
tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
agent_state=agent_state, additional_env_vars=tool_env_vars
tool_execution_manager = ToolExecutionManager(
agent_state=agent_state,
message_manager=self.message_manager,
agent_manager=self.agent_manager,
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=actor,
sandbox_env_vars=tool_env_vars,
)
# TODO: Integrate sandbox result
tool_execution_result = await tool_execution_manager.execute_tool_async(
function_name=tool_name,
function_args=tool_args,
tool=tool,
)
return ToolReturnMessage(
id="null",

View File

@@ -69,8 +69,8 @@ class ToolExecutionManager:
agent_manager: AgentManager,
block_manager: BlockManager,
passage_manager: PassageManager,
agent_state: AgentState,
actor: User,
agent_state: Optional[AgentState] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
):
@@ -98,7 +98,9 @@ class ToolExecutionManager:
passage_manager=self.passage_manager,
actor=self.actor,
)
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
result = await executor.execute(
function_name, function_args, tool, self.actor, self.agent_state, self.sandbox_config, self.sandbox_env_vars
)
# trim result
return_str = str(result.func_return)

View File

@@ -68,9 +68,9 @@ class ToolExecutor(ABC):
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
agent_state: Optional[AgentState] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
@@ -84,13 +84,14 @@ class LettaCoreToolExecutor(ToolExecutor):
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
agent_state: Optional[AgentState] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
# Map function names to method calls
assert agent_state is not None, "Agent state is required for core tools"
function_map = {
"send_message": self.send_message,
"conversation_search": self.conversation_search,
@@ -537,12 +538,13 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
agent_state: Optional[AgentState] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
assert agent_state is not None, "Agent state is required for multi-agent tools"
function_map = {
"send_message_to_agent_and_wait_for_reply": self.send_message_to_agent_and_wait_for_reply,
"send_message_to_agent_async": self.send_message_to_agent_async,
@@ -644,12 +646,13 @@ class ExternalComposioToolExecutor(ToolExecutor):
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
agent_state: Optional[AgentState] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
assert agent_state is not None, "Agent state is required for external Composio tools"
action_name = generate_composio_action_from_func_name(tool.name)
# Get entity ID from the agent_state
@@ -684,9 +687,9 @@ class ExternalMCPToolExecutor(ToolExecutor):
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
agent_state: Optional[AgentState] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
@@ -718,21 +721,21 @@ class SandboxToolExecutor(ToolExecutor):
self,
function_name: str,
function_args: JsonDict,
agent_state: AgentState,
tool: Tool,
actor: User,
agent_state: Optional[AgentState] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
# Store original memory state
orig_memory_str = agent_state.memory.compile()
orig_memory_str = agent_state.memory.compile() if agent_state else None
try:
# Prepare function arguments
function_args = self._prepare_function_args(function_args, tool, function_name)
agent_state_copy = self._create_agent_state_copy(agent_state)
agent_state_copy = self._create_agent_state_copy(agent_state) if agent_state else None
# Execute in sandbox depending on API key
if tool_settings.e2b_api_key:
@@ -747,7 +750,8 @@ class SandboxToolExecutor(ToolExecutor):
tool_execution_result = await sandbox.run(agent_state=agent_state_copy)
# Verify memory integrity
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
if agent_state:
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
# Update agent memory if needed
if tool_execution_result.agent_state is not None:
@@ -805,9 +809,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
self,
function_name: str,
function_args: dict,
agent_state: AgentState,
tool: Tool,
actor: User,
agent_state: Optional[AgentState] = None,
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:

View File

@@ -771,9 +771,10 @@ def ingest(message: str):
import pytest
def test_tool_run_basic(server, disable_e2b_api_key, user):
@pytest.mark.asyncio
async def test_tool_run_basic(server, disable_e2b_api_key, user):
"""Test running a simple tool from source"""
result = server.run_tool_from_source(
result = await server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
@@ -785,9 +786,10 @@ def test_tool_run_basic(server, disable_e2b_api_key, user):
assert not result.stderr
def test_tool_run_with_env_var(server, disable_e2b_api_key, user):
@pytest.mark.asyncio
async def test_tool_run_with_env_var(server, disable_e2b_api_key, user):
"""Test running a tool that uses an environment variable"""
result = server.run_tool_from_source(
result = await server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR,
tool_source_type="python",
@@ -800,9 +802,10 @@ def test_tool_run_with_env_var(server, disable_e2b_api_key, user):
assert not result.stderr
def test_tool_run_invalid_args(server, disable_e2b_api_key, user):
@pytest.mark.asyncio
async def test_tool_run_invalid_args(server, disable_e2b_api_key, user):
"""Test running a tool with incorrect arguments"""
result = server.run_tool_from_source(
result = await server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
@@ -816,9 +819,10 @@ def test_tool_run_invalid_args(server, disable_e2b_api_key, user):
assert "missing 1 required positional argument" in result.stderr[0]
def test_tool_run_with_distractor(server, disable_e2b_api_key, user):
@pytest.mark.asyncio
async def test_tool_run_with_distractor(server, disable_e2b_api_key, user):
"""Test running a tool with a distractor function in the source"""
result = server.run_tool_from_source(
result = await server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
@@ -831,9 +835,10 @@ def test_tool_run_with_distractor(server, disable_e2b_api_key, user):
assert not result.stderr
def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user):
@pytest.mark.asyncio
async def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user):
"""Test selecting a tool by name when multiple tools exist in the source"""
result = server.run_tool_from_source(
result = await server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
@@ -847,9 +852,10 @@ def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user):
assert not result.stderr
def test_tool_run_util_function(server, disable_e2b_api_key, user):
@pytest.mark.asyncio
async def test_tool_run_util_function(server, disable_e2b_api_key, user):
"""Test selecting a utility function that does not return anything meaningful"""
result = server.run_tool_from_source(
result = await server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
@@ -863,7 +869,8 @@ def test_tool_run_util_function(server, disable_e2b_api_key, user):
assert not result.stderr
def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user):
@pytest.mark.asyncio
async def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user):
"""Test overriding the autogenerated JSON schema with an explicit one"""
explicit_json_schema = {
"name": "ingest",
@@ -881,7 +888,7 @@ def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user):
},
}
result = server.run_tool_from_source(
result = await server.run_tool_from_source(
actor=user,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",