feat: parallel tool calling for OpenAI non-streaming [LET-4593] (#5773)
* first hack * clean up * first implementation working * revert package-lock * remove OpenAI test * error throw * typo * Update integration_test_send_message_v2.py * Update integration_test_send_message_v2.py * refine test * Only make changes for OpenAI non-streaming * Add tests --------- Co-authored-by: Ari Webb <ari@letta.com> Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -78,6 +78,13 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
use_responses = "input" in request_data and "messages" not in request_data
|
||||
# No support for Responses API proxy
|
||||
is_proxy = self.llm_config.provider_name == "lmstudio_openai"
|
||||
# Use parallel tool calling interface if enabled in config
|
||||
use_parallel = self.llm_config.parallel_tool_calls and tools and not use_responses and not is_proxy
|
||||
|
||||
# TODO: Temp, remove
|
||||
if use_parallel:
|
||||
raise RuntimeError("Parallel tool calling not supported for OpenAI streaming")
|
||||
|
||||
if use_responses and not is_proxy:
|
||||
self.interface = SimpleOpenAIResponsesStreamingInterface(
|
||||
is_openai_proxy=False,
|
||||
|
||||
@@ -364,13 +364,15 @@ class LettaAgentV3(LettaAgentV2):
|
||||
requires_subsequent_tool_call=self._require_tool_call,
|
||||
)
|
||||
# TODO: Extend to more providers, and also approval tool rules
|
||||
# Enable Anthropic parallel tool use when no tool rules are attached
|
||||
# Enable parallel tool use when no tool rules are attached
|
||||
try:
|
||||
no_tool_rules = (
|
||||
not self.agent_state.tool_rules
|
||||
or len([t for t in self.agent_state.tool_rules if t.type != "requires_approval"]) == 0
|
||||
)
|
||||
|
||||
# Anthropic/Bedrock parallel tool use
|
||||
if self.agent_state.llm_config.model_endpoint_type in ["anthropic", "bedrock"]:
|
||||
no_tool_rules = (
|
||||
not self.agent_state.tool_rules
|
||||
or len([t for t in self.agent_state.tool_rules if t.type != "requires_approval"]) == 0
|
||||
)
|
||||
if (
|
||||
isinstance(request_data.get("tool_choice"), dict)
|
||||
and "disable_parallel_tool_use" in request_data["tool_choice"]
|
||||
@@ -381,6 +383,16 @@ class LettaAgentV3(LettaAgentV2):
|
||||
else:
|
||||
# Explicitly disable when tool rules present or llm_config toggled off
|
||||
request_data["tool_choice"]["disable_parallel_tool_use"] = True
|
||||
|
||||
# OpenAI parallel tool use
|
||||
elif self.agent_state.llm_config.model_endpoint_type == "openai":
|
||||
# For OpenAI, we control parallel tool calling via parallel_tool_calls field
|
||||
# Only allow parallel tool calls when no tool rules and enabled in config
|
||||
if "parallel_tool_calls" in request_data:
|
||||
if no_tool_rules and self.agent_state.llm_config.parallel_tool_calls:
|
||||
request_data["parallel_tool_calls"] = True
|
||||
else:
|
||||
request_data["parallel_tool_calls"] = False
|
||||
except Exception:
|
||||
# if this fails, we simply don't enable parallel tool use
|
||||
pass
|
||||
@@ -435,11 +447,13 @@ class LettaAgentV3(LettaAgentV2):
|
||||
self._update_global_usage_stats(llm_adapter.usage)
|
||||
|
||||
# Handle the AI response with the extracted data (supports multiple tool calls)
|
||||
# Gather tool calls. Approval paths specify a single tool call.
|
||||
# Gather tool calls - check for multi-call API first, then fall back to single
|
||||
if hasattr(llm_adapter, "tool_calls") and llm_adapter.tool_calls:
|
||||
tool_calls = llm_adapter.tool_calls
|
||||
elif llm_adapter.tool_call is not None:
|
||||
tool_calls = [llm_adapter.tool_call]
|
||||
else:
|
||||
tool_calls = []
|
||||
|
||||
aggregated_persisted: list[Message] = []
|
||||
persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
|
||||
@@ -931,15 +945,25 @@ class LettaAgentV3(LettaAgentV2):
|
||||
)
|
||||
|
||||
# 5g. Aggregate continuation decisions
|
||||
# For multiple tools: continue if ANY says continue, use last non-None stop_reason
|
||||
# For single tool: use its decision directly
|
||||
aggregate_continue = any(persisted_continue_flags) if persisted_continue_flags else False
|
||||
aggregate_continue = aggregate_continue or tool_call_denials or tool_returns # continue if any tool call was denied or returned
|
||||
aggregate_continue = aggregate_continue or tool_call_denials or tool_returns
|
||||
|
||||
# Determine aggregate stop reason
|
||||
aggregate_stop_reason = None
|
||||
for sr in persisted_stop_reasons:
|
||||
if sr is not None:
|
||||
aggregate_stop_reason = sr
|
||||
|
||||
# For parallel tool calls, always continue to allow the agent to process/summarize results
|
||||
# unless a terminal tool was called or we hit max steps
|
||||
if len(exec_specs) > 1:
|
||||
has_terminal = any(sr and sr.stop_reason == StopReasonType.tool_rule.value for sr in persisted_stop_reasons)
|
||||
is_max_steps = any(sr and sr.stop_reason == StopReasonType.max_steps.value for sr in persisted_stop_reasons)
|
||||
|
||||
if not has_terminal and not is_max_steps:
|
||||
# Force continuation for parallel tool execution
|
||||
aggregate_continue = True
|
||||
aggregate_stop_reason = None
|
||||
return persisted_messages, aggregate_continue, aggregate_stop_reason
|
||||
|
||||
@trace_method
|
||||
|
||||
@@ -624,8 +624,8 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
|
||||
data = chat_completion_request.model_dump(exclude_none=True)
|
||||
|
||||
# add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified."
|
||||
if chat_completion_request.tools is not None:
|
||||
data["parallel_tool_calls"] = False
|
||||
if chat_completion_request.tools is not None and chat_completion_request.parallel_tool_calls is not None:
|
||||
data["parallel_tool_calls"] = chat_completion_request.parallel_tool_calls
|
||||
|
||||
# If functions == None, strip from the payload
|
||||
if "functions" in data and data["functions"] is None:
|
||||
|
||||
@@ -324,7 +324,7 @@ class OpenAIClient(LLMClientBase):
|
||||
tool_choice=tool_choice,
|
||||
max_output_tokens=llm_config.max_tokens,
|
||||
temperature=llm_config.temperature if supports_temperature_param(model) else None,
|
||||
parallel_tool_calls=False,
|
||||
parallel_tool_calls=llm_config.parallel_tool_calls if tools and supports_parallel_tool_calling(model) else False,
|
||||
)
|
||||
|
||||
# Add verbosity control for GPT-5 models
|
||||
@@ -349,7 +349,7 @@ class OpenAIClient(LLMClientBase):
|
||||
|
||||
# Add parallel tool calling
|
||||
if tools and supports_parallel_tool_calling(model):
|
||||
data.parallel_tool_calls = False
|
||||
data.parallel_tool_calls = llm_config.parallel_tool_calls
|
||||
|
||||
# always set user id for openai requests
|
||||
if self.actor:
|
||||
|
||||
@@ -1138,8 +1138,14 @@ class Message(BaseMessage):
|
||||
assert self.tool_calls is not None or text_content is not None, vars(self)
|
||||
except AssertionError as e:
|
||||
# relax check if this message only contains reasoning content
|
||||
if self.content is not None and len(self.content) > 0 and isinstance(self.content[0], ReasoningContent):
|
||||
return None
|
||||
if self.content is not None and len(self.content) > 0:
|
||||
# Check if all non-empty content is reasoning-related
|
||||
all_reasoning = all(
|
||||
isinstance(c, (ReasoningContent, SummarizedReasoningContent, OmittedReasoningContent, RedactedReasoningContent))
|
||||
for c in self.content
|
||||
)
|
||||
if all_reasoning:
|
||||
return None
|
||||
raise e
|
||||
|
||||
# if native content, then put it directly inside the content
|
||||
@@ -1228,15 +1234,32 @@ class Message(BaseMessage):
|
||||
use_developer_message: bool = False,
|
||||
) -> List[dict]:
|
||||
messages = Message.filter_messages_for_llm_api(messages)
|
||||
result = [
|
||||
m.to_openai_dict(
|
||||
result: List[dict] = []
|
||||
|
||||
for m in messages:
|
||||
# Special case: OpenAI Chat Completions requires a separate tool message per tool_call_id
|
||||
# If we have multiple explicit tool_returns on a single Message, expand into one dict per return
|
||||
if m.role == MessageRole.tool and m.tool_returns and len(m.tool_returns) > 0:
|
||||
for tr in m.tool_returns:
|
||||
if not tr.tool_call_id:
|
||||
raise TypeError("ToolReturn came back without a tool_call_id.")
|
||||
result.append(
|
||||
{
|
||||
"content": tr.func_response,
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id[:max_tool_id_length] if max_tool_id_length else tr.tool_call_id,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
d = m.to_openai_dict(
|
||||
max_tool_id_length=max_tool_id_length,
|
||||
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
|
||||
use_developer_message=use_developer_message,
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
result = [m for m in result if m is not None]
|
||||
if d is not None:
|
||||
result.append(d)
|
||||
|
||||
return result
|
||||
|
||||
def to_openai_responses_dicts(
|
||||
|
||||
@@ -533,15 +533,18 @@ async def test_greeting(
|
||||
)
|
||||
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_parallel_tool_call_anthropic(
|
||||
async def test_parallel_tool_calls(
|
||||
disable_e2b_api_key: Any,
|
||||
client: AsyncLetta,
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
send_type: str,
|
||||
) -> None:
|
||||
if llm_config.model_endpoint_type != "anthropic":
|
||||
pytest.skip("Parallel tool calling test only applies to Anthropic models.")
|
||||
if llm_config.model_endpoint_type != "anthropic" and llm_config.model_endpoint_type != "openai":
|
||||
pytest.skip("Parallel tool calling test only applies to Anthropic and OpenAI models.")
|
||||
|
||||
if llm_config.model_endpoint_type == "openai" and send_type not in {"step", "stream_steps"}:
|
||||
pytest.skip(f"OpenAI reasoning model {llm_config.model} does not support streaming parallel tool calling for now.")
|
||||
|
||||
# change llm_config to support parallel tool calling
|
||||
llm_config.parallel_tool_calls = True
|
||||
@@ -587,7 +590,10 @@ async def test_parallel_tool_call_anthropic(
|
||||
# verify each tool call
|
||||
for tc in tool_call_msg.tool_calls:
|
||||
assert tc["name"] == "roll_dice"
|
||||
assert tc["tool_call_id"].startswith("toolu_")
|
||||
# Support both Anthropic (toolu_) and OpenAI (call_) tool call ID formats
|
||||
assert tc["tool_call_id"].startswith("toolu_") or tc["tool_call_id"].startswith("call_"), (
|
||||
f"Unexpected tool call ID format: {tc['tool_call_id']}"
|
||||
)
|
||||
assert "num_sides" in tc["arguments"]
|
||||
|
||||
# assert tool returns match the tool calls
|
||||
|
||||
Reference in New Issue
Block a user