feat: parallel tool calling for openai non-streaming [LET-4593] (#5773)

* first hack

* clean up

* first implementation working

* revert package-lock

* remove openai test

* error throw

* typo

* Update integration_test_send_message_v2.py

* Update integration_test_send_message_v2.py

* refine test

* Only make changes for openai non-streaming

* Add tests

---------

Co-authored-by: Ari Webb <ari@letta.com>
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Ari Webb
2025-10-30 15:52:11 -07:00
committed by Caren Thomas
parent 7cc9471f40
commit 48cc73175b
6 changed files with 84 additions and 24 deletions

View File

@@ -78,6 +78,13 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
use_responses = "input" in request_data and "messages" not in request_data
# No support for Responses API proxy
is_proxy = self.llm_config.provider_name == "lmstudio_openai"
# Use parallel tool calling interface if enabled in config
use_parallel = self.llm_config.parallel_tool_calls and tools and not use_responses and not is_proxy
# TODO: Temp, remove
if use_parallel:
raise RuntimeError("Parallel tool calling not supported for OpenAI streaming")
if use_responses and not is_proxy:
self.interface = SimpleOpenAIResponsesStreamingInterface(
is_openai_proxy=False,

View File

@@ -364,13 +364,15 @@ class LettaAgentV3(LettaAgentV2):
requires_subsequent_tool_call=self._require_tool_call,
)
# TODO: Extend to more providers, and also approval tool rules
# Enable Anthropic parallel tool use when no tool rules are attached
# Enable parallel tool use when no tool rules are attached
try:
no_tool_rules = (
not self.agent_state.tool_rules
or len([t for t in self.agent_state.tool_rules if t.type != "requires_approval"]) == 0
)
# Anthropic/Bedrock parallel tool use
if self.agent_state.llm_config.model_endpoint_type in ["anthropic", "bedrock"]:
no_tool_rules = (
not self.agent_state.tool_rules
or len([t for t in self.agent_state.tool_rules if t.type != "requires_approval"]) == 0
)
if (
isinstance(request_data.get("tool_choice"), dict)
and "disable_parallel_tool_use" in request_data["tool_choice"]
@@ -381,6 +383,16 @@ class LettaAgentV3(LettaAgentV2):
else:
# Explicitly disable when tool rules present or llm_config toggled off
request_data["tool_choice"]["disable_parallel_tool_use"] = True
# OpenAI parallel tool use
elif self.agent_state.llm_config.model_endpoint_type == "openai":
# For OpenAI, we control parallel tool calling via parallel_tool_calls field
# Only allow parallel tool calls when no tool rules and enabled in config
if "parallel_tool_calls" in request_data:
if no_tool_rules and self.agent_state.llm_config.parallel_tool_calls:
request_data["parallel_tool_calls"] = True
else:
request_data["parallel_tool_calls"] = False
except Exception:
# if this fails, we simply don't enable parallel tool use
pass
@@ -435,11 +447,13 @@ class LettaAgentV3(LettaAgentV2):
self._update_global_usage_stats(llm_adapter.usage)
# Handle the AI response with the extracted data (supports multiple tool calls)
# Gather tool calls. Approval paths specify a single tool call.
# Gather tool calls - check for multi-call API first, then fall back to single
if hasattr(llm_adapter, "tool_calls") and llm_adapter.tool_calls:
tool_calls = llm_adapter.tool_calls
elif llm_adapter.tool_call is not None:
tool_calls = [llm_adapter.tool_call]
else:
tool_calls = []
aggregated_persisted: list[Message] = []
persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
@@ -931,15 +945,25 @@ class LettaAgentV3(LettaAgentV2):
)
# 5g. Aggregate continuation decisions
# For multiple tools: continue if ANY says continue, use last non-None stop_reason
# For single tool: use its decision directly
aggregate_continue = any(persisted_continue_flags) if persisted_continue_flags else False
aggregate_continue = aggregate_continue or tool_call_denials or tool_returns # continue if any tool call was denied or returned
aggregate_continue = aggregate_continue or tool_call_denials or tool_returns
# Determine aggregate stop reason
aggregate_stop_reason = None
for sr in persisted_stop_reasons:
if sr is not None:
aggregate_stop_reason = sr
# For parallel tool calls, always continue to allow the agent to process/summarize results
# unless a terminal tool was called or we hit max steps
if len(exec_specs) > 1:
has_terminal = any(sr and sr.stop_reason == StopReasonType.tool_rule.value for sr in persisted_stop_reasons)
is_max_steps = any(sr and sr.stop_reason == StopReasonType.max_steps.value for sr in persisted_stop_reasons)
if not has_terminal and not is_max_steps:
# Force continuation for parallel tool execution
aggregate_continue = True
aggregate_stop_reason = None
return persisted_messages, aggregate_continue, aggregate_stop_reason
@trace_method

View File

@@ -624,8 +624,8 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest):
data = chat_completion_request.model_dump(exclude_none=True)
# add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified."
if chat_completion_request.tools is not None:
data["parallel_tool_calls"] = False
if chat_completion_request.tools is not None and chat_completion_request.parallel_tool_calls is not None:
data["parallel_tool_calls"] = chat_completion_request.parallel_tool_calls
# If functions == None, strip from the payload
if "functions" in data and data["functions"] is None:

View File

@@ -324,7 +324,7 @@ class OpenAIClient(LLMClientBase):
tool_choice=tool_choice,
max_output_tokens=llm_config.max_tokens,
temperature=llm_config.temperature if supports_temperature_param(model) else None,
parallel_tool_calls=False,
parallel_tool_calls=llm_config.parallel_tool_calls if tools and supports_parallel_tool_calling(model) else False,
)
# Add verbosity control for GPT-5 models
@@ -349,7 +349,7 @@ class OpenAIClient(LLMClientBase):
# Add parallel tool calling
if tools and supports_parallel_tool_calling(model):
data.parallel_tool_calls = False
data.parallel_tool_calls = llm_config.parallel_tool_calls
# always set user id for openai requests
if self.actor:

View File

@@ -1138,8 +1138,14 @@ class Message(BaseMessage):
assert self.tool_calls is not None or text_content is not None, vars(self)
except AssertionError as e:
# relax check if this message only contains reasoning content
if self.content is not None and len(self.content) > 0 and isinstance(self.content[0], ReasoningContent):
return None
if self.content is not None and len(self.content) > 0:
# Check if all non-empty content is reasoning-related
all_reasoning = all(
isinstance(c, (ReasoningContent, SummarizedReasoningContent, OmittedReasoningContent, RedactedReasoningContent))
for c in self.content
)
if all_reasoning:
return None
raise e
# if native content, then put it directly inside the content
@@ -1228,15 +1234,32 @@ class Message(BaseMessage):
use_developer_message: bool = False,
) -> List[dict]:
messages = Message.filter_messages_for_llm_api(messages)
result = [
m.to_openai_dict(
result: List[dict] = []
for m in messages:
# Special case: OpenAI Chat Completions requires a separate tool message per tool_call_id
# If we have multiple explicit tool_returns on a single Message, expand into one dict per return
if m.role == MessageRole.tool and m.tool_returns and len(m.tool_returns) > 0:
for tr in m.tool_returns:
if not tr.tool_call_id:
raise TypeError("ToolReturn came back without a tool_call_id.")
result.append(
{
"content": tr.func_response,
"role": "tool",
"tool_call_id": tr.tool_call_id[:max_tool_id_length] if max_tool_id_length else tr.tool_call_id,
}
)
continue
d = m.to_openai_dict(
max_tool_id_length=max_tool_id_length,
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
use_developer_message=use_developer_message,
)
for m in messages
]
result = [m for m in result if m is not None]
if d is not None:
result.append(d)
return result
def to_openai_responses_dicts(

View File

@@ -533,15 +533,18 @@ async def test_greeting(
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_parallel_tool_call_anthropic(
async def test_parallel_tool_calls(
disable_e2b_api_key: Any,
client: AsyncLetta,
agent_state: AgentState,
llm_config: LLMConfig,
send_type: str,
) -> None:
if llm_config.model_endpoint_type != "anthropic":
pytest.skip("Parallel tool calling test only applies to Anthropic models.")
if llm_config.model_endpoint_type != "anthropic" and llm_config.model_endpoint_type != "openai":
pytest.skip("Parallel tool calling test only applies to Anthropic and OpenAI models.")
if llm_config.model_endpoint_type == "openai" and send_type not in {"step", "stream_steps"}:
pytest.skip(f"OpenAI reasoning model {llm_config.model} does not support streaming parallel tool calling for now.")
# change llm_config to support parallel tool calling
llm_config.parallel_tool_calls = True
@@ -587,7 +590,10 @@ async def test_parallel_tool_call_anthropic(
# verify each tool call
for tc in tool_call_msg.tool_calls:
assert tc["name"] == "roll_dice"
assert tc["tool_call_id"].startswith("toolu_")
# Support both Anthropic (toolu_) and OpenAI (call_) tool call ID formats
assert tc["tool_call_id"].startswith("toolu_") or tc["tool_call_id"].startswith("call_"), (
f"Unexpected tool call ID format: {tc['tool_call_id']}"
)
assert "num_sides" in tc["arguments"]
# assert tool returns match the tool calls