From fc950ecddf1e0e598590ece0cd63d9b2e3fd51e4 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 17 Oct 2025 17:33:50 -0700 Subject: [PATCH] feat: Change execution pattern depending on `enable_parallel_execution` (#5550) * Change execution pattern depending on * Increase efficiency --- letta/agents/letta_agent_v2.py | 11 +++-- letta/agents/letta_agent_v3.py | 78 +++++++++++++++++----------------- 2 files changed, 46 insertions(+), 43 deletions(-) diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index eb5d15db..fea764f3 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -40,6 +40,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, UsageStatistics from letta.schemas.step import Step, StepProgression from letta.schemas.step_metrics import StepMetrics +from letta.schemas.tool import Tool from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -954,8 +955,10 @@ class LettaAgentV2(BaseAgentV2): else: # Track tool execution time tool_start_time = get_utc_timestamp_ns() + target_tool = next((x for x in agent_state.tools if x.name == tool_call_name), None) + tool_execution_result = await self._execute_tool( - tool_name=tool_call_name, + target_tool=target_tool, tool_args=tool_args, agent_state=agent_state, agent_step_span=agent_step_span, @@ -1076,20 +1079,20 @@ class LettaAgentV2(BaseAgentV2): @trace_method async def _execute_tool( self, - tool_name: str, + target_tool: Tool, tool_args: JsonDict, agent_state: AgentState, agent_step_span: Span | None = None, step_id: str | None = None, - run_id: str = None, ) -> "ToolExecutionResult": """ Executes a tool and returns the ToolExecutionResult. """ from letta.schemas.tool_execution_result import ToolExecutionResult + tool_name = target_tool.name + # Special memory case - target_tool = next((x for x in agent_state.tools if x.name == tool_name), None) if not target_tool: # TODO: fix this error message return ToolExecutionResult( diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 5d5f7208..81a54e46 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -1,7 +1,7 @@ import asyncio import json import uuid -from typing import AsyncGenerator, Optional +from typing import Any, AsyncGenerator, Dict, Optional from opentelemetry.trace import Span @@ -796,47 +796,47 @@ class LettaAgentV3(LettaAgentV2): ) # 5c. Execute tools (sequentially for single, parallel for multiple) - if len(exec_specs) == 1: - # Single tool - execute directly without asyncio.gather overhead - spec = exec_specs[0] + async def _run_one(spec: Dict[str, Any]): if spec.get("error"): - # Prefill arg validation error - result = ToolExecutionResult(status="error", func_return=spec["error"]) - exec_time = 0 - elif spec["violated"]: + return ToolExecutionResult(status="error", func_return=spec["error"]), 0 + if spec["violated"]: result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver) - exec_time = 0 - else: - t0 = get_utc_timestamp_ns() - result = await self._execute_tool( - tool_name=spec["name"], - tool_args=spec["args"], - agent_state=agent_state, - agent_step_span=agent_step_span, - step_id=step_id, - ) - exec_time = get_utc_timestamp_ns() - t0 - results = [(result, exec_time)] - else: - # Multiple tools - execute in parallel - async def _run_one(spec): - if spec.get("error"): - return ToolExecutionResult(status="error", func_return=spec["error"]), 0 - if spec["violated"]: - result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver) - return result, 0 - t0 = get_utc_timestamp_ns() - res = await self._execute_tool( - tool_name=spec["name"], - tool_args=spec["args"], - agent_state=agent_state, - agent_step_span=agent_step_span, - step_id=step_id, - ) - dt = get_utc_timestamp_ns() - t0 - return res, dt + return result, 0 + t0 = get_utc_timestamp_ns() + target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None) + res = await self._execute_tool( + target_tool=target_tool, + tool_args=spec["args"], + agent_state=agent_state, + agent_step_span=agent_step_span, + step_id=step_id, + ) + dt = get_utc_timestamp_ns() - t0 + return res, dt - results = await asyncio.gather(*[_run_one(s) for s in exec_specs]) + if len(exec_specs) == 1: + results = [await _run_one(exec_specs[0])] + else: + # separate tools by parallel execution capability + parallel_items = [] + serial_items = [] + + for idx, spec in enumerate(exec_specs): + target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None) + if target_tool and target_tool.enable_parallel_execution: + parallel_items.append((idx, spec)) + else: + serial_items.append((idx, spec)) + + # execute all parallel tools concurrently and all serial tools sequentially + results = [None] * len(exec_specs) + + parallel_results = await asyncio.gather(*[_run_one(spec) for _, spec in parallel_items]) if parallel_items else [] + for (idx, _), result in zip(parallel_items, parallel_results): + results[idx] = result + + for idx, spec in serial_items: + results[idx] = await _run_one(spec) # 5d. Update metrics with execution time if step_metrics is not None and results: