feat(asyncify): migrate run tools (#2496)
This commit is contained in:
@@ -193,7 +193,7 @@ async def upsert_base_tools(
|
||||
|
||||
|
||||
@router.post("/run", response_model=ToolReturnMessage, operation_id="run_tool_from_source")
|
||||
def run_tool_from_source(
|
||||
async def run_tool_from_source(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: ToolRunFromSource = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
@@ -201,10 +201,10 @@ def run_tool_from_source(
|
||||
"""
|
||||
Attempt to build a tool from source, then run it on the provided arguments
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.run_tool_from_source(
|
||||
return await server.run_tool_from_source(
|
||||
tool_source=request.source_code,
|
||||
tool_source_type=request.source_type,
|
||||
tool_args=request.args,
|
||||
|
||||
@@ -96,7 +96,7 @@ from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.step_manager import StepManager
|
||||
from letta.services.telemetry_manager import TelemetryManager
|
||||
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
@@ -1865,7 +1865,7 @@ class SyncServer(Server):
|
||||
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
|
||||
"""Add a new embedding model"""
|
||||
|
||||
def run_tool_from_source(
|
||||
async def run_tool_from_source(
|
||||
self,
|
||||
actor: User,
|
||||
tool_args: Dict[str, str],
|
||||
@@ -1898,8 +1898,20 @@ class SyncServer(Server):
|
||||
|
||||
# Next, attempt to run the tool with the sandbox
|
||||
try:
|
||||
tool_execution_result = ToolExecutionSandbox(tool.name, tool_args, actor, tool_object=tool).run(
|
||||
agent_state=agent_state, additional_env_vars=tool_env_vars
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=actor,
|
||||
sandbox_env_vars=tool_env_vars,
|
||||
)
|
||||
# TODO: Integrate sandbox result
|
||||
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
||||
function_name=tool_name,
|
||||
function_args=tool_args,
|
||||
tool=tool,
|
||||
)
|
||||
return ToolReturnMessage(
|
||||
id="null",
|
||||
|
||||
@@ -69,8 +69,8 @@ class ToolExecutionManager:
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
passage_manager: PassageManager,
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
@@ -98,7 +98,9 @@ class ToolExecutionManager:
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
result = await executor.execute(
|
||||
function_name, function_args, tool, self.actor, self.agent_state, self.sandbox_config, self.sandbox_env_vars
|
||||
)
|
||||
|
||||
# trim result
|
||||
return_str = str(result.func_return)
|
||||
|
||||
@@ -68,9 +68,9 @@ class ToolExecutor(ABC):
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
@@ -84,13 +84,14 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
# Map function names to method calls
|
||||
assert agent_state is not None, "Agent state is required for core tools"
|
||||
function_map = {
|
||||
"send_message": self.send_message,
|
||||
"conversation_search": self.conversation_search,
|
||||
@@ -537,12 +538,13 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
assert agent_state is not None, "Agent state is required for multi-agent tools"
|
||||
function_map = {
|
||||
"send_message_to_agent_and_wait_for_reply": self.send_message_to_agent_and_wait_for_reply,
|
||||
"send_message_to_agent_async": self.send_message_to_agent_async,
|
||||
@@ -644,12 +646,13 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
assert agent_state is not None, "Agent state is required for external Composio tools"
|
||||
action_name = generate_composio_action_from_func_name(tool.name)
|
||||
|
||||
# Get entity ID from the agent_state
|
||||
@@ -684,9 +687,9 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
@@ -718,21 +721,21 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: JsonDict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
|
||||
# Store original memory state
|
||||
orig_memory_str = agent_state.memory.compile()
|
||||
orig_memory_str = agent_state.memory.compile() if agent_state else None
|
||||
|
||||
try:
|
||||
# Prepare function arguments
|
||||
function_args = self._prepare_function_args(function_args, tool, function_name)
|
||||
|
||||
agent_state_copy = self._create_agent_state_copy(agent_state)
|
||||
agent_state_copy = self._create_agent_state_copy(agent_state) if agent_state else None
|
||||
|
||||
# Execute in sandbox depending on API key
|
||||
if tool_settings.e2b_api_key:
|
||||
@@ -747,7 +750,8 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
tool_execution_result = await sandbox.run(agent_state=agent_state_copy)
|
||||
|
||||
# Verify memory integrity
|
||||
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
if agent_state:
|
||||
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
|
||||
# Update agent memory if needed
|
||||
if tool_execution_result.agent_state is not None:
|
||||
@@ -805,9 +809,9 @@ class LettaBuiltinToolExecutor(ToolExecutor):
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
agent_state: Optional[AgentState] = None,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
|
||||
@@ -771,9 +771,10 @@ def ingest(message: str):
|
||||
import pytest
|
||||
|
||||
|
||||
def test_tool_run_basic(server, disable_e2b_api_key, user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run_basic(server, disable_e2b_api_key, user):
|
||||
"""Test running a simple tool from source"""
|
||||
result = server.run_tool_from_source(
|
||||
result = await server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
@@ -785,9 +786,10 @@ def test_tool_run_basic(server, disable_e2b_api_key, user):
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_tool_run_with_env_var(server, disable_e2b_api_key, user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run_with_env_var(server, disable_e2b_api_key, user):
|
||||
"""Test running a tool that uses an environment variable"""
|
||||
result = server.run_tool_from_source(
|
||||
result = await server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_ENV_VAR,
|
||||
tool_source_type="python",
|
||||
@@ -800,9 +802,10 @@ def test_tool_run_with_env_var(server, disable_e2b_api_key, user):
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_tool_run_invalid_args(server, disable_e2b_api_key, user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run_invalid_args(server, disable_e2b_api_key, user):
|
||||
"""Test running a tool with incorrect arguments"""
|
||||
result = server.run_tool_from_source(
|
||||
result = await server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
@@ -816,9 +819,10 @@ def test_tool_run_invalid_args(server, disable_e2b_api_key, user):
|
||||
assert "missing 1 required positional argument" in result.stderr[0]
|
||||
|
||||
|
||||
def test_tool_run_with_distractor(server, disable_e2b_api_key, user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run_with_distractor(server, disable_e2b_api_key, user):
|
||||
"""Test running a tool with a distractor function in the source"""
|
||||
result = server.run_tool_from_source(
|
||||
result = await server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
tool_source_type="python",
|
||||
@@ -831,9 +835,10 @@ def test_tool_run_with_distractor(server, disable_e2b_api_key, user):
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user):
|
||||
"""Test selecting a tool by name when multiple tools exist in the source"""
|
||||
result = server.run_tool_from_source(
|
||||
result = await server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
tool_source_type="python",
|
||||
@@ -847,9 +852,10 @@ def test_tool_run_explicit_tool_name(server, disable_e2b_api_key, user):
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_tool_run_util_function(server, disable_e2b_api_key, user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run_util_function(server, disable_e2b_api_key, user):
|
||||
"""Test selecting a utility function that does not return anything meaningful"""
|
||||
result = server.run_tool_from_source(
|
||||
result = await server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
|
||||
tool_source_type="python",
|
||||
@@ -863,7 +869,8 @@ def test_tool_run_util_function(server, disable_e2b_api_key, user):
|
||||
assert not result.stderr
|
||||
|
||||
|
||||
def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user):
|
||||
"""Test overriding the autogenerated JSON schema with an explicit one"""
|
||||
explicit_json_schema = {
|
||||
"name": "ingest",
|
||||
@@ -881,7 +888,7 @@ def test_tool_run_with_explicit_json_schema(server, disable_e2b_api_key, user):
|
||||
},
|
||||
}
|
||||
|
||||
result = server.run_tool_from_source(
|
||||
result = await server.run_tool_from_source(
|
||||
actor=user,
|
||||
tool_source=EXAMPLE_TOOL_SOURCE,
|
||||
tool_source_type="python",
|
||||
|
||||
Reference in New Issue
Block a user