feat: Change execution pattern depending on enable_parallel_execution (#5550)
* Change execution pattern depending on enable_parallel_execution
* Increase efficiency
This commit is contained in:
committed by
Caren Thomas
parent
f8437d47e2
commit
fc950ecddf
@@ -40,6 +40,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, UsageStatistics
|
||||
from letta.schemas.step import Step, StepProgression
|
||||
from letta.schemas.step_metrics import StepMetrics
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
@@ -954,8 +955,10 @@ class LettaAgentV2(BaseAgentV2):
|
||||
else:
|
||||
# Track tool execution time
|
||||
tool_start_time = get_utc_timestamp_ns()
|
||||
target_tool = next((x for x in agent_state.tools if x.name == tool_call_name), None)
|
||||
|
||||
tool_execution_result = await self._execute_tool(
|
||||
tool_name=tool_call_name,
|
||||
target_tool=target_tool,
|
||||
tool_args=tool_args,
|
||||
agent_state=agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
@@ -1076,20 +1079,20 @@ class LettaAgentV2(BaseAgentV2):
|
||||
@trace_method
|
||||
async def _execute_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
target_tool: Tool,
|
||||
tool_args: JsonDict,
|
||||
agent_state: AgentState,
|
||||
agent_step_span: Span | None = None,
|
||||
step_id: str | None = None,
|
||||
run_id: str = None,
|
||||
) -> "ToolExecutionResult":
|
||||
"""
|
||||
Executes a tool and returns the ToolExecutionResult.
|
||||
"""
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
tool_name = target_tool.name
|
||||
|
||||
# Special memory case
|
||||
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
|
||||
if not target_tool:
|
||||
# TODO: fix this error message
|
||||
return ToolExecutionResult(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncGenerator, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
@@ -796,47 +796,47 @@ class LettaAgentV3(LettaAgentV2):
|
||||
)
|
||||
|
||||
# 5c. Execute tools (sequentially for single, parallel for multiple)
|
||||
if len(exec_specs) == 1:
|
||||
# Single tool - execute directly without asyncio.gather overhead
|
||||
spec = exec_specs[0]
|
||||
async def _run_one(spec: Dict[str, Any]):
|
||||
if spec.get("error"):
|
||||
# Prefill arg validation error
|
||||
result = ToolExecutionResult(status="error", func_return=spec["error"])
|
||||
exec_time = 0
|
||||
elif spec["violated"]:
|
||||
return ToolExecutionResult(status="error", func_return=spec["error"]), 0
|
||||
if spec["violated"]:
|
||||
result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver)
|
||||
exec_time = 0
|
||||
else:
|
||||
t0 = get_utc_timestamp_ns()
|
||||
result = await self._execute_tool(
|
||||
tool_name=spec["name"],
|
||||
tool_args=spec["args"],
|
||||
agent_state=agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
step_id=step_id,
|
||||
)
|
||||
exec_time = get_utc_timestamp_ns() - t0
|
||||
results = [(result, exec_time)]
|
||||
else:
|
||||
# Multiple tools - execute in parallel
|
||||
async def _run_one(spec):
|
||||
if spec.get("error"):
|
||||
return ToolExecutionResult(status="error", func_return=spec["error"]), 0
|
||||
if spec["violated"]:
|
||||
result = _build_rule_violation_result(spec["name"], valid_tool_names, tool_rules_solver)
|
||||
return result, 0
|
||||
t0 = get_utc_timestamp_ns()
|
||||
res = await self._execute_tool(
|
||||
tool_name=spec["name"],
|
||||
tool_args=spec["args"],
|
||||
agent_state=agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
step_id=step_id,
|
||||
)
|
||||
dt = get_utc_timestamp_ns() - t0
|
||||
return res, dt
|
||||
return result, 0
|
||||
t0 = get_utc_timestamp_ns()
|
||||
target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None)
|
||||
res = await self._execute_tool(
|
||||
target_tool=target_tool,
|
||||
tool_args=spec["args"],
|
||||
agent_state=agent_state,
|
||||
agent_step_span=agent_step_span,
|
||||
step_id=step_id,
|
||||
)
|
||||
dt = get_utc_timestamp_ns() - t0
|
||||
return res, dt
|
||||
|
||||
results = await asyncio.gather(*[_run_one(s) for s in exec_specs])
|
||||
if len(exec_specs) == 1:
|
||||
results = [await _run_one(exec_specs[0])]
|
||||
else:
|
||||
# separate tools by parallel execution capability
|
||||
parallel_items = []
|
||||
serial_items = []
|
||||
|
||||
for idx, spec in enumerate(exec_specs):
|
||||
target_tool = next((x for x in agent_state.tools if x.name == spec["name"]), None)
|
||||
if target_tool and target_tool.enable_parallel_execution:
|
||||
parallel_items.append((idx, spec))
|
||||
else:
|
||||
serial_items.append((idx, spec))
|
||||
|
||||
# execute all parallel tools concurrently and all serial tools sequentially
|
||||
results = [None] * len(exec_specs)
|
||||
|
||||
parallel_results = await asyncio.gather(*[_run_one(spec) for _, spec in parallel_items]) if parallel_items else []
|
||||
for (idx, _), result in zip(parallel_items, parallel_results):
|
||||
results[idx] = result
|
||||
|
||||
for idx, spec in serial_items:
|
||||
results[idx] = await _run_one(spec)
|
||||
|
||||
# 5d. Update metrics with execution time
|
||||
if step_metrics is not None and results:
|
||||
|
||||
Reference in New Issue
Block a user