feat: Enable dynamic toggling of tool choice in v3 agent loop for OpenAI [LET-4564] (#5042)
* Add subsequent flag
* Finish integrating constrained/unconstrained toggling on v3 agent loop
* Update tests to run on v3
* Run lint
This commit is contained in:
Committed by: Caren Thomas
Parent: c465da27e6
Commit: df5c997da0
@@ -183,6 +183,7 @@ class AnthropicClient(LLMClientBase):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
|
||||
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise
|
||||
|
||||
@@ -70,8 +70,9 @@ class BedrockClient(AnthropicClient):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
# remove disallowed fields
|
||||
if "tool_choice" in data:
|
||||
del data["tool_choice"]["disable_parallel_tool_use"]
|
||||
|
||||
@@ -337,11 +337,12 @@ class DeepseekClient(OpenAIClient):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
# Override put_inner_thoughts_in_kwargs to False for DeepSeek
|
||||
llm_config.put_inner_thoughts_in_kwargs = False
|
||||
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
def add_functions_to_system_message(system_message: ChatMessage):
|
||||
system_message.content += f"<available functions> {''.join(json.dumps(f) for f in tools)} </available functions>"
|
||||
|
||||
@@ -280,6 +280,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
llm_config: LLMConfig,
|
||||
tools: List[dict],
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
|
||||
@@ -29,8 +29,9 @@ class GroqClient(OpenAIClient):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
# Groq validation - these fields are not supported and will cause 400 errors
|
||||
# https://console.groq.com/docs/openai
|
||||
|
||||
@@ -127,6 +127,7 @@ class LLMClientBase:
|
||||
llm_config: LLMConfig,
|
||||
tools: List[dict],
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for this client.
|
||||
|
||||
@@ -206,6 +206,7 @@ class OpenAIClient(LLMClientBase):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for the OpenAI Responses API.
|
||||
@@ -224,14 +225,15 @@ class OpenAIClient(LLMClientBase):
|
||||
logger.warning(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}")
|
||||
model = None
|
||||
|
||||
# Default to auto, unless there's a forced tool call coming from above
|
||||
# Default to auto, unless there's a forced tool call coming from above or requires_subsequent_tool_call is True
|
||||
tool_choice = None
|
||||
if tools: # only set tool_choice if tools exist
|
||||
tool_choice = (
|
||||
"auto"
|
||||
if force_tool_call is None
|
||||
else ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
)
|
||||
if force_tool_call is not None:
|
||||
tool_choice = {"type": "function", "name": force_tool_call}
|
||||
elif requires_subsequent_tool_call:
|
||||
tool_choice = "required"
|
||||
else:
|
||||
tool_choice = "auto"
|
||||
|
||||
# Convert the tools from the ChatCompletions style to the Responses style
|
||||
if tools:
|
||||
@@ -352,6 +354,7 @@ class OpenAIClient(LLMClientBase):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None, # Keep as dict for now as per base class
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Constructs a request object in the expected data format for the OpenAI API.
|
||||
@@ -364,6 +367,7 @@ class OpenAIClient(LLMClientBase):
|
||||
llm_config=llm_config,
|
||||
tools=tools,
|
||||
force_tool_call=force_tool_call,
|
||||
requires_subsequent_tool_call=requires_subsequent_tool_call,
|
||||
)
|
||||
|
||||
if agent_type == AgentType.letta_v1_agent:
|
||||
@@ -407,15 +411,16 @@ class OpenAIClient(LLMClientBase):
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
tool_choice = None
|
||||
if tools: # only set tool_choice if tools exist
|
||||
if self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
|
||||
if force_tool_call is not None:
|
||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
elif requires_subsequent_tool_call:
|
||||
tool_choice = "required"
|
||||
elif self.requires_auto_tool_choice(llm_config) or agent_type == AgentType.letta_v1_agent:
|
||||
tool_choice = "auto"
|
||||
else:
|
||||
# only set if tools is non-Null
|
||||
tool_choice = "required"
|
||||
|
||||
if force_tool_call is not None:
|
||||
tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=force_tool_call))
|
||||
|
||||
data = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=fill_image_content_in_messages(openai_message_list, messages),
|
||||
|
||||
@@ -29,8 +29,9 @@ class XAIClient(OpenAIClient):
|
||||
llm_config: LLMConfig,
|
||||
tools: Optional[List[dict]] = None,
|
||||
force_tool_call: Optional[str] = None,
|
||||
requires_subsequent_tool_call: bool = False,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
# Specific bug for the mini models (as of Apr 14, 2025)
|
||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
|
||||
|
||||
Reference in New Issue
Block a user