feat: consolidate reasoning model checks (#3862)
This commit is contained in:
2
.github/scripts/model-sweep/model_sweep.py
vendored
2
.github/scripts/model-sweep/model_sweep.py
vendored
@@ -96,7 +96,7 @@ all_configs = [
|
|||||||
"openai-gpt-4o-mini.json",
|
"openai-gpt-4o-mini.json",
|
||||||
# "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop
|
# "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop
|
||||||
"claude-3-5-sonnet.json",
|
"claude-3-5-sonnet.json",
|
||||||
"claude-3-7-sonnet.json",
|
"claude-4-sonnet-extended.json",
|
||||||
"claude-3-7-sonnet-extended.json",
|
"claude-3-7-sonnet-extended.json",
|
||||||
"gemini-1.5-pro.json",
|
"gemini-1.5-pro.json",
|
||||||
"gemini-2.5-flash-vertex.json",
|
"gemini-2.5-flash-vertex.json",
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- "openai-gpt-4o-mini.json"
|
- "openai-gpt-4o-mini.json"
|
||||||
- "azure-gpt-4o-mini.json"
|
- "azure-gpt-4o-mini.json"
|
||||||
- "claude-3-5-sonnet.json"
|
- "claude-3-5-sonnet.json"
|
||||||
- "claude-3-7-sonnet.json"
|
- "claude-4-sonnet-extended.json"
|
||||||
- "claude-3-7-sonnet-extended.json"
|
- "claude-3-7-sonnet-extended.json"
|
||||||
- "gemini-pro.json"
|
- "gemini-pro.json"
|
||||||
- "gemini-vertex.json"
|
- "gemini-vertex.json"
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class AnthropicClient(LLMClientBase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Extended Thinking
|
# Extended Thinking
|
||||||
if llm_config.enable_reasoner:
|
if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
|
||||||
data["thinking"] = {
|
data["thinking"] = {
|
||||||
"type": "enabled",
|
"type": "enabled",
|
||||||
"budget_tokens": llm_config.max_reasoning_tokens,
|
"budget_tokens": llm_config.max_reasoning_tokens,
|
||||||
@@ -200,7 +200,7 @@ class AnthropicClient(LLMClientBase):
|
|||||||
# Special case for summarization path
|
# Special case for summarization path
|
||||||
tools_for_request = None
|
tools_for_request = None
|
||||||
tool_choice = None
|
tool_choice = None
|
||||||
elif llm_config.enable_reasoner:
|
elif self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
|
||||||
# NOTE: reasoning models currently do not allow for `any`
|
# NOTE: reasoning models currently do not allow for `any`
|
||||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||||
tools_for_request = [OpenAITool(function=f) for f in tools]
|
tools_for_request = [OpenAITool(function=f) for f in tools]
|
||||||
@@ -296,6 +296,13 @@ class AnthropicClient(LLMClientBase):
|
|||||||
token_count -= 8
|
token_count -= 8
|
||||||
return token_count
|
return token_count
|
||||||
|
|
||||||
|
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
|
||||||
|
return (
|
||||||
|
llm_config.model.startswith("claude-3-7-sonnet")
|
||||||
|
or llm_config.model.startswith("claude-sonnet-4")
|
||||||
|
or llm_config.model.startswith("claude-opus-4")
|
||||||
|
)
|
||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
def handle_llm_error(self, e: Exception) -> Exception:
|
def handle_llm_error(self, e: Exception) -> Exception:
|
||||||
if isinstance(e, anthropic.APITimeoutError):
|
if isinstance(e, anthropic.APITimeoutError):
|
||||||
|
|||||||
@@ -504,6 +504,9 @@ class GoogleVertexClient(LLMClientBase):
|
|||||||
return 1
|
return 1
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
|
||||||
|
return llm_config.model.startswith("gemini-2.5-flash") or llm_config.model.startswith("gemini-2.5-pro")
|
||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
def handle_llm_error(self, e: Exception) -> Exception:
|
def handle_llm_error(self, e: Exception) -> Exception:
|
||||||
# Fallback to base implementation
|
# Fallback to base implementation
|
||||||
|
|||||||
@@ -174,6 +174,10 @@ class LLMClientBase:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
|
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def handle_llm_error(self, e: Exception) -> Exception:
|
def handle_llm_error(self, e: Exception) -> Exception:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -276,6 +276,9 @@ class OpenAIClient(LLMClientBase):
|
|||||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||||
return response.model_dump()
|
return response.model_dump()
|
||||||
|
|
||||||
|
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
|
||||||
|
return is_openai_reasoning_model(llm_config.model)
|
||||||
|
|
||||||
@trace_method
|
@trace_method
|
||||||
def convert_response_to_chat_completion(
|
def convert_response_to_chat_completion(
|
||||||
self,
|
self,
|
||||||
@@ -298,7 +301,7 @@ class OpenAIClient(LLMClientBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If we used a reasoning model, create a content part for the ommitted reasoning
|
# If we used a reasoning model, create a content part for the ommitted reasoning
|
||||||
if is_openai_reasoning_model(llm_config.model):
|
if self.is_reasoning_model(llm_config):
|
||||||
chat_completion_response.choices[0].message.omitted_reasoning_content = True
|
chat_completion_response.choices[0].message.omitted_reasoning_content = True
|
||||||
|
|
||||||
return chat_completion_response
|
return chat_completion_response
|
||||||
|
|||||||
@@ -94,6 +94,9 @@ class LLMConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
model = values.get("model")
|
model = values.get("model")
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
return values
|
||||||
|
|
||||||
# Define models where we want put_inner_thoughts_in_kwargs to be False
|
# Define models where we want put_inner_thoughts_in_kwargs to be False
|
||||||
avoid_put_inner_thoughts_in_kwargs = ["gpt-4"]
|
avoid_put_inner_thoughts_in_kwargs = ["gpt-4"]
|
||||||
|
|
||||||
@@ -107,25 +110,13 @@ class LLMConfig(BaseModel):
|
|||||||
if is_openai_reasoning_model(model):
|
if is_openai_reasoning_model(model):
|
||||||
values["put_inner_thoughts_in_kwargs"] = False
|
values["put_inner_thoughts_in_kwargs"] = False
|
||||||
|
|
||||||
if values.get("enable_reasoner") and values.get("model_endpoint_type") == "anthropic":
|
if values.get("model_endpoint_type") == "anthropic" and (
|
||||||
|
model.startswith("claude-3-7-sonnet") or model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")
|
||||||
|
):
|
||||||
values["put_inner_thoughts_in_kwargs"] = False
|
values["put_inner_thoughts_in_kwargs"] = False
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def issue_warning_for_reasoning_constraints(self) -> "LLMConfig":
|
|
||||||
if self.enable_reasoner:
|
|
||||||
if self.max_reasoning_tokens is None:
|
|
||||||
logger.warning("max_reasoning_tokens must be set when enable_reasoner is True")
|
|
||||||
if self.max_tokens is not None and self.max_reasoning_tokens >= self.max_tokens:
|
|
||||||
logger.warning("max_tokens must be greater than max_reasoning_tokens (thinking budget)")
|
|
||||||
if self.put_inner_thoughts_in_kwargs:
|
|
||||||
logger.debug("Extended thinking is not compatible with put_inner_thoughts_in_kwargs")
|
|
||||||
elif self.max_reasoning_tokens and not self.enable_reasoner:
|
|
||||||
logger.warning("model will not use reasoning unless enable_reasoner is set to True")
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_config(cls, model_name: str):
|
def default_config(cls, model_name: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
{
|
|
||||||
"model": "claude-3-7-sonnet-20250219",
|
|
||||||
"model_endpoint_type": "anthropic",
|
|
||||||
"model_endpoint": "https://api.anthropic.com/v1",
|
|
||||||
"model_wrapper": null,
|
|
||||||
"context_window": 200000,
|
|
||||||
"put_inner_thoughts_in_kwargs": true
|
|
||||||
}
|
|
||||||
@@ -4,5 +4,7 @@
|
|||||||
"model_endpoint": "https://api.anthropic.com/v1",
|
"model_endpoint": "https://api.anthropic.com/v1",
|
||||||
"model_wrapper": null,
|
"model_wrapper": null,
|
||||||
"context_window": 200000,
|
"context_window": 200000,
|
||||||
"put_inner_thoughts_in_kwargs": true
|
"put_inner_thoughts_in_kwargs": false,
|
||||||
|
"enable_reasoner": true,
|
||||||
|
"max_reasoning_tokens": 1024
|
||||||
}
|
}
|
||||||
@@ -118,9 +118,8 @@ all_configs = [
|
|||||||
"openai-o3.json",
|
"openai-o3.json",
|
||||||
"openai-o4-mini.json",
|
"openai-o4-mini.json",
|
||||||
"azure-gpt-4o-mini.json",
|
"azure-gpt-4o-mini.json",
|
||||||
"claude-4-sonnet.json",
|
"claude-4-sonnet-extended.json",
|
||||||
"claude-3-5-sonnet.json",
|
"claude-3-5-sonnet.json",
|
||||||
"claude-3-7-sonnet.json",
|
|
||||||
"claude-3-7-sonnet-extended.json",
|
"claude-3-7-sonnet-extended.json",
|
||||||
"bedrock-claude-4-sonnet.json",
|
"bedrock-claude-4-sonnet.json",
|
||||||
"gemini-1.5-pro.json",
|
"gemini-1.5-pro.json",
|
||||||
|
|||||||
Reference in New Issue
Block a user