feat: Change execution pattern depending on enable_parallel_execution (#5550)

* Change execution pattern depending on

* Increase efficiency
This commit is contained in:
Matthew Zhou
2025-10-17 17:33:50 -07:00
committed by Caren Thomas
parent f8437d47e2
commit fc950ecddf
2 changed files with 46 additions and 43 deletions

View File

@@ -40,6 +40,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, UsageStatistics
from letta.schemas.step import Step, StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
@@ -954,8 +955,10 @@ class LettaAgentV2(BaseAgentV2):
else:
# Track tool execution time
tool_start_time = get_utc_timestamp_ns()
target_tool = next((x for x in agent_state.tools if x.name == tool_call_name), None)
tool_execution_result = await self._execute_tool(
tool_name=tool_call_name,
target_tool=target_tool,
tool_args=tool_args,
agent_state=agent_state,
agent_step_span=agent_step_span,
@@ -1076,20 +1079,20 @@ class LettaAgentV2(BaseAgentV2):
@trace_method
async def _execute_tool(
self,
tool_name: str,
target_tool: Tool,
tool_args: JsonDict,
agent_state: AgentState,
agent_step_span: Span | None = None,
step_id: str | None = None,
run_id: str = None,
) -> "ToolExecutionResult":
"""
Executes a tool and returns the ToolExecutionResult.
"""
from letta.schemas.tool_execution_result import ToolExecutionResult
tool_name = target_tool.name
# Special memory case
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
if not target_tool:
# TODO: fix this error message
return ToolExecutionResult(

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import uuid
from typing import AsyncGenerator, Optional
from typing import Any, AsyncGenerator, Dict, Optional
from opentelemetry.trace import Span
@@ -796,47 +796,47 @@ class LettaAgentV3(LettaAgentV2):
)
# 5c. Execute tools (sequentially for single, parallel for multiple)
if len(exec_specs) == 1:
# Single tool - execute directly without asyncio.gather overhead
spec = exec_specs[0]
async def _run_one(spec: Dict[str, Any]):
if spec.get("error"):
# Prefill arg validation error
result = ToolExecutionResult(status="error", func_return=spec["error"])
exec_time = 0
elif spec["violated"]:
return ToolExecutionResult(status="error", func_return=spec["error"]), 0
if spec["violated"]:
result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver)
exec_time = 0
else:
t0 = get_utc_timestamp_ns()
result = await self._execute_tool(
tool_name=spec["name"],
tool_args=spec["args"],
agent_state=agent_state,
agent_step_span=agent_step_span,
step_id=step_id,
)
exec_time = get_utc_timestamp_ns() - t0
results = [(result, exec_time)]
else:
# Multiple tools - execute in parallel
async def _run_one(spec):
if spec.get("error"):
return ToolExecutionResult(status="error", func_return=spec["error"]), 0
if spec["violated"]:
result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver)
return result, 0
t0 = get_utc_timestamp_ns()
res = await self._execute_tool(
tool_name=spec["name"],
tool_args=spec["args"],
agent_state=agent_state,
agent_step_span=agent_step_span,
step_id=step_id,
)
dt = get_utc_timestamp_ns() - t0
return res, dt
return result, 0
t0 = get_utc_timestamp_ns()
target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None)
res = await self._execute_tool(
target_tool=target_tool,
tool_args=spec["args"],
agent_state=agent_state,
agent_step_span=agent_step_span,
step_id=step_id,
)
dt = get_utc_timestamp_ns() - t0
return res, dt
results = await asyncio.gather(*[_run_one(s) for s in exec_specs])
if len(exec_specs) == 1:
results = [await _run_one(exec_specs[0])]
else:
# separate tools by parallel execution capability
parallel_items = []
serial_items = []
for idx, spec in enumerate(exec_specs):
target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None)
if target_tool and target_tool.enable_parallel_execution:
parallel_items.append((idx, spec))
else:
serial_items.append((idx, spec))
# execute all parallel tools concurrently and all serial tools sequentially
results = [None] * len(exec_specs)
parallel_results = await asyncio.gather(*[_run_one(spec) for _, spec in parallel_items]) if parallel_items else []
for (idx, _), result in zip(parallel_items, parallel_results):
results[idx] = result
for idx, spec in serial_items:
results[idx] = await _run_one(spec)
# 5d. Update metrics with execution time
if step_metrics is not None and results: