feat: consolidate reasoning model checks (#3862)
This commit is contained in:
2
.github/scripts/model-sweep/model_sweep.py
vendored
2
.github/scripts/model-sweep/model_sweep.py
vendored
@@ -96,7 +96,7 @@ all_configs = [
|
||||
"openai-gpt-4o-mini.json",
|
||||
# "azure-gpt-4o-mini.json", # TODO: Re-enable on new agent loop
|
||||
"claude-3-5-sonnet.json",
|
||||
"claude-3-7-sonnet.json",
|
||||
"claude-4-sonnet-extended.json",
|
||||
"claude-3-7-sonnet-extended.json",
|
||||
"gemini-1.5-pro.json",
|
||||
"gemini-2.5-flash-vertex.json",
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
- "openai-gpt-4o-mini.json"
|
||||
- "azure-gpt-4o-mini.json"
|
||||
- "claude-3-5-sonnet.json"
|
||||
- "claude-3-7-sonnet.json"
|
||||
- "claude-4-sonnet-extended.json"
|
||||
- "claude-3-7-sonnet-extended.json"
|
||||
- "gemini-pro.json"
|
||||
- "gemini-vertex.json"
|
||||
|
||||
@@ -182,7 +182,7 @@ class AnthropicClient(LLMClientBase):
|
||||
}
|
||||
|
||||
# Extended Thinking
|
||||
if llm_config.enable_reasoner:
|
||||
if self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
|
||||
data["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": llm_config.max_reasoning_tokens,
|
||||
@@ -200,7 +200,7 @@ class AnthropicClient(LLMClientBase):
|
||||
# Special case for summarization path
|
||||
tools_for_request = None
|
||||
tool_choice = None
|
||||
elif llm_config.enable_reasoner:
|
||||
elif self.is_reasoning_model(llm_config) and llm_config.enable_reasoner:
|
||||
# NOTE: reasoning models currently do not allow for `any`
|
||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||
tools_for_request = [OpenAITool(function=f) for f in tools]
|
||||
@@ -296,6 +296,13 @@ class AnthropicClient(LLMClientBase):
|
||||
token_count -= 8
|
||||
return token_count
|
||||
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Report whether the configured model name matches a reasoning Claude prefix.

    Args:
        llm_config: Configuration whose ``model`` string is inspected.

    Returns:
        True when ``llm_config.model`` starts with ``claude-3-7-sonnet``,
        ``claude-sonnet-4`` or ``claude-opus-4``; False otherwise.
    """
    # str.startswith accepts a tuple of prefixes, so a single call replaces
    # the chained `or` of three separate startswith checks.
    return llm_config.model.startswith(("claude-3-7-sonnet", "claude-sonnet-4", "claude-opus-4"))
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
if isinstance(e, anthropic.APITimeoutError):
|
||||
|
||||
@@ -504,6 +504,9 @@ class GoogleVertexClient(LLMClientBase):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Report whether the configured model name matches a reasoning Gemini prefix.

    Args:
        llm_config: Configuration whose ``model`` string is inspected.

    Returns:
        True when ``llm_config.model`` starts with ``gemini-2.5-flash`` or
        ``gemini-2.5-pro``; False otherwise.
    """
    # One tuple-prefix startswith call instead of two calls joined by `or`.
    return llm_config.model.startswith(("gemini-2.5-flash", "gemini-2.5-pro"))
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
# Fallback to base implementation
|
||||
|
||||
@@ -174,6 +174,10 @@ class LLMClientBase:
|
||||
"""
|
||||
raise NotImplementedError(f"Streaming is not supported for {llm_config.model_endpoint_type}")
|
||||
|
||||
@abstractmethod
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Tell whether the model described by ``llm_config`` is a reasoning model.

    Declared abstract; this base version carries no logic and always raises
    ``NotImplementedError``.

    Args:
        llm_config: Configuration of the model being queried.

    Raises:
        NotImplementedError: Always, when the base version is invoked.
    """
    raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
"""
|
||||
|
||||
@@ -276,6 +276,9 @@ class OpenAIClient(LLMClientBase):
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
    """Delegate the reasoning-model check to ``is_openai_reasoning_model``.

    Args:
        llm_config: Configuration whose ``model`` string is passed to the helper.

    Returns:
        Whatever ``is_openai_reasoning_model`` returns for that model name.
    """
    model_name = llm_config.model
    return is_openai_reasoning_model(model_name)
|
||||
|
||||
@trace_method
|
||||
def convert_response_to_chat_completion(
|
||||
self,
|
||||
@@ -298,7 +301,7 @@ class OpenAIClient(LLMClientBase):
|
||||
)
|
||||
|
||||
# If we used a reasoning model, create a content part for the omitted reasoning
|
||||
if is_openai_reasoning_model(llm_config.model):
|
||||
if self.is_reasoning_model(llm_config):
|
||||
chat_completion_response.choices[0].message.omitted_reasoning_content = True
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
@@ -94,6 +94,9 @@ class LLMConfig(BaseModel):
|
||||
"""
|
||||
model = values.get("model")
|
||||
|
||||
if model is None:
|
||||
return values
|
||||
|
||||
# Define models where we want put_inner_thoughts_in_kwargs to be False
|
||||
avoid_put_inner_thoughts_in_kwargs = ["gpt-4"]
|
||||
|
||||
@@ -107,25 +110,13 @@ class LLMConfig(BaseModel):
|
||||
if is_openai_reasoning_model(model):
|
||||
values["put_inner_thoughts_in_kwargs"] = False
|
||||
|
||||
if values.get("enable_reasoner") and values.get("model_endpoint_type") == "anthropic":
|
||||
if values.get("model_endpoint_type") == "anthropic" and (
|
||||
model.startswith("claude-3-7-sonnet") or model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")
|
||||
):
|
||||
values["put_inner_thoughts_in_kwargs"] = False
|
||||
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
def issue_warning_for_reasoning_constraints(self) -> "LLMConfig":
    """Log warnings about inconsistent reasoning settings and return the config unchanged.

    This validator only logs; it does not modify any field.

    Checks performed:
      * ``enable_reasoner`` is set but ``max_reasoning_tokens`` is None.
      * ``max_reasoning_tokens`` is not smaller than ``max_tokens``.
      * ``enable_reasoner`` is set together with ``put_inner_thoughts_in_kwargs``
        (logged at debug level).
      * ``max_reasoning_tokens`` is set while ``enable_reasoner`` is not.

    Returns:
        The same ``LLMConfig`` instance.
    """
    if self.enable_reasoner:
        if self.max_reasoning_tokens is None:
            logger.warning("max_reasoning_tokens must be set when enable_reasoner is True")
        # `elif` (not a second `if`): comparing None >= int raises TypeError,
        # which would turn a warnings-only check into a validation crash.
        elif self.max_tokens is not None and self.max_reasoning_tokens >= self.max_tokens:
            logger.warning("max_tokens must be greater than max_reasoning_tokens (thinking budget)")
        if self.put_inner_thoughts_in_kwargs:
            logger.debug("Extended thinking is not compatible with put_inner_thoughts_in_kwargs")
    elif self.max_reasoning_tokens and not self.enable_reasoner:
        logger.warning("model will not use reasoning unless enable_reasoner is set to True")

    return self
|
||||
|
||||
@classmethod
|
||||
def default_config(cls, model_name: str):
|
||||
"""
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
{
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"model_endpoint_type": "anthropic",
|
||||
"model_endpoint": "https://api.anthropic.com/v1",
|
||||
"model_wrapper": null,
|
||||
"context_window": 200000,
|
||||
"put_inner_thoughts_in_kwargs": true
|
||||
}
|
||||
@@ -4,5 +4,7 @@
|
||||
"model_endpoint": "https://api.anthropic.com/v1",
|
||||
"model_wrapper": null,
|
||||
"context_window": 200000,
|
||||
"put_inner_thoughts_in_kwargs": true
|
||||
"put_inner_thoughts_in_kwargs": false,
|
||||
"enable_reasoner": true,
|
||||
"max_reasoning_tokens": 1024
|
||||
}
|
||||
@@ -118,9 +118,8 @@ all_configs = [
|
||||
"openai-o3.json",
|
||||
"openai-o4-mini.json",
|
||||
"azure-gpt-4o-mini.json",
|
||||
"claude-4-sonnet.json",
|
||||
"claude-4-sonnet-extended.json",
|
||||
"claude-3-5-sonnet.json",
|
||||
"claude-3-7-sonnet.json",
|
||||
"claude-3-7-sonnet-extended.json",
|
||||
"bedrock-claude-4-sonnet.json",
|
||||
"gemini-1.5-pro.json",
|
||||
|
||||
Reference in New Issue
Block a user